mirror of
https://github.com/pybind/pybind11.git
synced 2025-01-19 01:15:52 +00:00
Make register_dtype() accept any field containers (#1225)
* Make register_dtype() accept any field containers * Add a test for programmatic dtype registration
This commit is contained in:
parent
b48d4a01ca
commit
d1db2ccfdf
@ -18,9 +18,9 @@
|
||||
#include <cstring>
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
#include <initializer_list>
|
||||
#include <functional>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
#include <typeindex>
|
||||
|
||||
#if defined(_MSC_VER)
|
||||
@ -1057,7 +1057,7 @@ struct field_descriptor {
|
||||
};
|
||||
|
||||
inline PYBIND11_NOINLINE void register_structured_dtype(
|
||||
const std::initializer_list<field_descriptor>& fields,
|
||||
any_container<field_descriptor> fields,
|
||||
const std::type_info& tinfo, ssize_t itemsize,
|
||||
bool (*direct_converter)(PyObject *, void *&)) {
|
||||
|
||||
@ -1066,7 +1066,7 @@ inline PYBIND11_NOINLINE void register_structured_dtype(
|
||||
pybind11_fail("NumPy: dtype is already registered");
|
||||
|
||||
list names, formats, offsets;
|
||||
for (auto field : fields) {
|
||||
for (auto field : *fields) {
|
||||
if (!field.descr)
|
||||
pybind11_fail(std::string("NumPy: unsupported field dtype: `") +
|
||||
field.name + "` @ " + tinfo.name());
|
||||
@ -1083,7 +1083,7 @@ inline PYBIND11_NOINLINE void register_structured_dtype(
|
||||
// - https://github.com/numpy/numpy/pull/7798
|
||||
// Because of this, we won't use numpy's logic to generate buffer format
|
||||
// strings and will just do it ourselves.
|
||||
std::vector<field_descriptor> ordered_fields(fields);
|
||||
std::vector<field_descriptor> ordered_fields(std::move(fields));
|
||||
std::sort(ordered_fields.begin(), ordered_fields.end(),
|
||||
[](const field_descriptor &a, const field_descriptor &b) { return a.offset < b.offset; });
|
||||
ssize_t offset = 0;
|
||||
@ -1130,8 +1130,8 @@ template <typename T, typename SFINAE> struct npy_format_descriptor {
|
||||
return format_str;
|
||||
}
|
||||
|
||||
static void register_dtype(const std::initializer_list<field_descriptor>& fields) {
|
||||
register_structured_dtype(fields, typeid(typename std::remove_cv<T>::type),
|
||||
static void register_dtype(any_container<field_descriptor> fields) {
|
||||
register_structured_dtype(std::move(fields), typeid(typename std::remove_cv<T>::type),
|
||||
sizeof(T), &direct_converter);
|
||||
}
|
||||
|
||||
@ -1204,7 +1204,8 @@ private:
|
||||
|
||||
#define PYBIND11_NUMPY_DTYPE(Type, ...) \
|
||||
::pybind11::detail::npy_format_descriptor<Type>::register_dtype \
|
||||
({PYBIND11_MAP_LIST (PYBIND11_FIELD_DESCRIPTOR, Type, __VA_ARGS__)})
|
||||
(::std::vector<::pybind11::detail::field_descriptor> \
|
||||
{PYBIND11_MAP_LIST (PYBIND11_FIELD_DESCRIPTOR, Type, __VA_ARGS__)})
|
||||
|
||||
#ifdef _MSC_VER
|
||||
#define PYBIND11_MAP2_LIST_NEXT1(test, next) \
|
||||
@ -1225,7 +1226,8 @@ private:
|
||||
|
||||
#define PYBIND11_NUMPY_DTYPE_EX(Type, ...) \
|
||||
::pybind11::detail::npy_format_descriptor<Type>::register_dtype \
|
||||
({PYBIND11_MAP2_LIST (PYBIND11_FIELD_DESCRIPTOR_EX, Type, __VA_ARGS__)})
|
||||
(::std::vector<::pybind11::detail::field_descriptor> \
|
||||
{PYBIND11_MAP2_LIST (PYBIND11_FIELD_DESCRIPTOR_EX, Type, __VA_ARGS__)})
|
||||
|
||||
#endif // __CLION_IDE__
|
||||
|
||||
|
@ -244,6 +244,9 @@ py::list test_dtype_ctors() {
|
||||
return list;
|
||||
}
|
||||
|
||||
struct A {};
|
||||
struct B {};
|
||||
|
||||
TEST_SUBMODULE(numpy_dtypes, m) {
|
||||
try { py::module::import("numpy"); }
|
||||
catch (...) { return; }
|
||||
@ -271,6 +274,15 @@ TEST_SUBMODULE(numpy_dtypes, m) {
|
||||
// struct NotPOD { std::string v; NotPOD() : v("hi") {}; };
|
||||
// PYBIND11_NUMPY_DTYPE(NotPOD, v);
|
||||
|
||||
// Check that dtypes can be registered programmatically, both from
|
||||
// initializer lists of field descriptors and from other containers.
|
||||
py::detail::npy_format_descriptor<A>::register_dtype(
|
||||
{}
|
||||
);
|
||||
py::detail::npy_format_descriptor<B>::register_dtype(
|
||||
std::vector<py::detail::field_descriptor>{}
|
||||
);
|
||||
|
||||
// test_recarray, test_scalar_conversion
|
||||
m.def("create_rec_simple", &create_recarray<SimpleStruct>);
|
||||
m.def("create_rec_packed", &create_recarray<PackedStruct>);
|
||||
|
Loading…
Reference in New Issue
Block a user