diff --git a/include/pybind11/detail/common.h b/include/pybind11/detail/common.h index ea09bb3fd..b49313e44 100644 --- a/include/pybind11/detail/common.h +++ b/include/pybind11/detail/common.h @@ -992,6 +992,7 @@ constexpr const char struct error_scope { PyObject *type, *value, *trace; error_scope() { PyErr_Fetch(&type, &value, &trace); } + error_scope(const error_scope &) = delete; ~error_scope() { PyErr_Restore(type, value, trace); } }; diff --git a/include/pybind11/numpy.h b/include/pybind11/numpy.h index a456ff985..dc8817532 100644 --- a/include/pybind11/numpy.h +++ b/include/pybind11/numpy.h @@ -689,18 +689,16 @@ public: PYBIND11_OBJECT_DEFAULT(dtype, object, detail::npy_api::get().PyArrayDescr_Check_); explicit dtype(const buffer_info &info) { - dtype descr(_dtype_from_pep3118()(PYBIND11_STR_TYPE(info.format))); + dtype descr(_dtype_from_pep3118()(pybind11::str(info.format))); // If info.itemsize == 0, use the value calculated from the format string m_ptr = descr.strip_padding(info.itemsize != 0 ? info.itemsize : descr.itemsize()) .release() .ptr(); } - explicit dtype(const std::string &format) { - m_ptr = from_args(pybind11::str(format)).release().ptr(); - } + explicit dtype(const std::string &format) : dtype(from_args(pybind11::str(format))) {} - explicit dtype(const char *format) : dtype(std::string(format)) {} + explicit dtype(const char *format) : dtype(from_args(pybind11::str(format))) {} dtype(list names, list formats, list offsets, ssize_t itemsize) { dict args; @@ -711,6 +709,13 @@ public: m_ptr = from_args(std::move(args)).release().ptr(); } + explicit dtype(int typenum) + : object(detail::npy_api::get().PyArray_DescrFromType_(typenum), stolen_t{}) { + if (m_ptr == nullptr) { + throw error_already_set(); + } + } + /// This is essentially the same as calling numpy.dtype(args) in Python. static dtype from_args(object args) { PyObject *ptr = nullptr; @@ -745,6 +750,23 @@ public: return detail::array_descriptor_proxy(m_ptr)->type; } + /// type number of dtype. + ssize_t num() const { + // Note: The signature, `dtype::num` follows the naming of NumPy's public + // Python API (i.e., ``dtype.num``), rather than its internal + // C API (``PyArray_Descr::type_num``). + return detail::array_descriptor_proxy(m_ptr)->type_num; + } + + /// Single character for byteorder + char byteorder() const { return detail::array_descriptor_proxy(m_ptr)->byteorder; } + + /// Alignment of the data type + int alignment() const { return detail::array_descriptor_proxy(m_ptr)->alignment; } + + /// Flags for the array descriptor + char flags() const { return detail::array_descriptor_proxy(m_ptr)->flags; } + private: static object _dtype_from_pep3118() { static PyObject *obj = module_::import("numpy.core._internal") @@ -763,7 +785,7 @@ private: } struct field_descr { - PYBIND11_STR_TYPE name; + pybind11::str name; object format; pybind11::int_ offset; }; @@ -778,7 +800,7 @@ private: continue; } field_descriptors.push_back( - {(PYBIND11_STR_TYPE) name, format.strip_padding(format.itemsize()), offset}); + {(pybind11::str) name, format.strip_padding(format.itemsize()), offset}); } std::sort(field_descriptors.begin(), @@ -1452,7 +1474,7 @@ PYBIND11_NOINLINE void register_structured_dtype(any_container pybind11_fail(std::string("NumPy: unsupported field dtype: `") + field.name + "` @ " + tinfo.name()); } - names.append(PYBIND11_STR_TYPE(field.name)); + names.append(pybind11::str(field.name)); formats.append(field.descr); offsets.append(pybind11::int_(field.offset)); } diff --git a/include/pybind11/pytypes.h b/include/pybind11/pytypes.h index ba0fda0a8..7d52774c8 100644 --- a/include/pybind11/pytypes.h +++ b/include/pybind11/pytypes.h @@ -1588,7 +1588,8 @@ public: } pybind11_fail("Unable to get capsule context"); } - void *ptr = PyCapsule_GetPointer(o, nullptr); + const char *name = get_name_in_error_scope(o); + void *ptr = PyCapsule_GetPointer(o, name); if (ptr == nullptr) { throw error_already_set(); } @@ -1602,7 +1603,8 @@ public: explicit capsule(void (*destructor)()) { m_ptr = PyCapsule_New(reinterpret_cast(destructor), nullptr, [](PyObject *o) { - auto destructor = reinterpret_cast(PyCapsule_GetPointer(o, nullptr)); + const char *name = get_name_in_error_scope(o); + auto destructor = reinterpret_cast(PyCapsule_GetPointer(o, name)); if (destructor == nullptr) { throw error_already_set(); } @@ -1637,7 +1639,33 @@ public: } } - const char *name() const { return PyCapsule_GetName(m_ptr); } + const char *name() const { + const char *name = PyCapsule_GetName(m_ptr); + if ((name == nullptr) && PyErr_Occurred()) { + throw error_already_set(); + } + return name; + } + + /// Replaces a capsule's name *without* calling the destructor on the existing one. + void set_name(const char *new_name) { + if (PyCapsule_SetName(m_ptr, new_name) != 0) { + throw error_already_set(); + } + } + +private: + static const char *get_name_in_error_scope(PyObject *o) { + error_scope error_guard; + + const char *name = PyCapsule_GetName(o); + if ((name == nullptr) && PyErr_Occurred()) { + // write out and consume error raised by call to PyCapsule_GetName + PyErr_WriteUnraisable(o); + } + + return name; + } }; class tuple : public object { diff --git a/tests/test_numpy_dtypes.cpp b/tests/test_numpy_dtypes.cpp index dd5b123dc..7de36f2fe 100644 --- a/tests/test_numpy_dtypes.cpp +++ b/tests/test_numpy_dtypes.cpp @@ -291,6 +291,7 @@ py::list test_dtype_ctors() { list.append(py::dtype(names, formats, offsets, 20)); list.append(py::dtype(py::buffer_info((void *) 0, sizeof(unsigned int), "I", 1))); list.append(py::dtype(py::buffer_info((void *) 0, 0, "T{i:a:f:b:}", 1))); + list.append(py::dtype(py::detail::npy_api::NPY_DOUBLE_)); return list; } @@ -440,6 +441,34 @@ TEST_SUBMODULE(numpy_dtypes, m) { } return list; }); + m.def("test_dtype_num", [dtype_names]() { + py::list list; + for (const auto &dt_name : dtype_names) { + list.append(py::dtype(dt_name).num()); + } + return list; + }); + m.def("test_dtype_byteorder", [dtype_names]() { + py::list list; + for (const auto &dt_name : dtype_names) { + list.append(py::dtype(dt_name).byteorder()); + } + return list; + }); + m.def("test_dtype_alignment", [dtype_names]() { + py::list list; + for (const auto &dt_name : dtype_names) { + list.append(py::dtype(dt_name).alignment()); + } + return list; + }); + m.def("test_dtype_flags", [dtype_names]() { + py::list list; + for (const auto &dt_name : dtype_names) { + list.append(py::dtype(dt_name).flags()); + } + return list; + }); m.def("test_dtype_methods", []() { py::list list; auto dt1 = py::dtype::of(); diff --git a/tests/test_numpy_dtypes.py b/tests/test_numpy_dtypes.py index 7df60583f..fcfd587b1 100644 --- a/tests/test_numpy_dtypes.py +++ b/tests/test_numpy_dtypes.py @@ -160,6 +160,7 @@ def test_dtype(simple_dtype): d1, np.dtype("uint32"), d2, + np.dtype("d"), ] assert m.test_dtype_methods() == [ @@ -175,8 +176,13 @@ def test_dtype(simple_dtype): np.zeros(1, m.trailing_padding_dtype()) ) + expected_chars = "bhilqBHILQefdgFDG?MmO" assert m.test_dtype_kind() == list("iiiiiuuuuuffffcccbMmO") - assert m.test_dtype_char_() == list("bhilqBHILQefdgFDG?MmO") + assert m.test_dtype_char_() == list(expected_chars) + assert m.test_dtype_num() == [np.dtype(ch).num for ch in expected_chars] + assert m.test_dtype_byteorder() == [np.dtype(ch).byteorder for ch in expected_chars] + assert m.test_dtype_alignment() == [np.dtype(ch).alignment for ch in expected_chars] + assert m.test_dtype_flags() == [chr(np.dtype(ch).flags) for ch in expected_chars] def test_recarray(simple_dtype, packed_dtype): diff --git a/tests/test_pytypes.cpp b/tests/test_pytypes.cpp index b859497b8..d1e9b81a7 100644 --- a/tests/test_pytypes.cpp +++ b/tests/test_pytypes.cpp @@ -159,6 +159,15 @@ TEST_SUBMODULE(pytypes, m) { return py::capsule([]() { py::print("destructing capsule"); }); }); + m.def("return_renamed_capsule_with_destructor", []() { + py::print("creating capsule"); + auto cap = py::capsule([]() { py::print("destructing capsule"); }); + static const char *capsule_name = "test_name1"; + py::print("renaming capsule"); + cap.set_name(capsule_name); + return cap; + }); + m.def("return_capsule_with_destructor_2", []() { py::print("creating capsule"); return py::capsule((void *) 1234, [](void *ptr) { @@ -166,6 +175,17 @@ TEST_SUBMODULE(pytypes, m) { }); }); + m.def("return_renamed_capsule_with_destructor_2", []() { + py::print("creating capsule"); + auto cap = py::capsule((void *) 1234, [](void *ptr) { + py::print("destructing capsule: {}"_s.format((size_t) ptr)); + }); + static const char *capsule_name = "test_name2"; + py::print("renaming capsule"); + cap.set_name(capsule_name); + return cap; + }); + m.def("return_capsule_with_name_and_destructor", []() { auto capsule = py::capsule((void *) 12345, "pointer type description", [](PyObject *ptr) { if (ptr) { diff --git a/tests/test_pytypes.py b/tests/test_pytypes.py index 85afb9423..5c715ada6 100644 --- a/tests/test_pytypes.py +++ b/tests/test_pytypes.py @@ -195,6 +195,19 @@ def test_capsule(capture): """ ) + with capture: + a = m.return_renamed_capsule_with_destructor() + del a + pytest.gc_collect() + assert ( + capture.unordered + == """ + creating capsule + renaming capsule + destructing capsule + """ + ) + with capture: a = m.return_capsule_with_destructor_2() del a @@ -207,6 +220,19 @@ def test_capsule(capture): """ ) + with capture: + a = m.return_renamed_capsule_with_destructor_2() + del a + pytest.gc_collect() + assert ( + capture.unordered + == """ + creating capsule + renaming capsule + destructing capsule: 1234 + """ + ) + with capture: a = m.return_capsule_with_name_and_destructor() del a