diff --git a/include/pybind11/numpy.h b/include/pybind11/numpy.h index b180cb296..af465a17d 100644 --- a/include/pybind11/numpy.h +++ b/include/pybind11/numpy.h @@ -81,14 +81,18 @@ struct numpy_type_info { struct numpy_internals { std::unordered_map registered_dtypes; - template numpy_type_info *get_type_info(bool throw_if_missing = true) { - auto it = registered_dtypes.find(std::type_index(typeid(T))); + numpy_type_info *get_type_info(const std::type_info& tinfo, bool throw_if_missing = true) { + auto it = registered_dtypes.find(std::type_index(tinfo)); if (it != registered_dtypes.end()) return &(it->second); if (throw_if_missing) - pybind11_fail(std::string("NumPy type info missing for ") + typeid(T).name()); + pybind11_fail(std::string("NumPy type info missing for ") + tinfo.name()); return nullptr; } + + template numpy_type_info *get_type_info(bool throw_if_missing = true) { + return get_type_info(typeid(typename std::remove_cv::type), throw_if_missing); + } }; inline PYBIND11_NOINLINE void load_numpy_internals(numpy_internals* &ptr) { @@ -686,6 +690,62 @@ struct field_descriptor { dtype descr; }; +inline PYBIND11_NOINLINE void register_structured_dtype( + const std::initializer_list& fields, + const std::type_info& tinfo, size_t itemsize, + bool (*direct_converter)(PyObject *, void *&)) +{ + auto& numpy_internals = get_numpy_internals(); + if (numpy_internals.get_type_info(tinfo, false)) + pybind11_fail("NumPy: dtype is already registered"); + + list names, formats, offsets; + for (auto field : fields) { + if (!field.descr) + pybind11_fail(std::string("NumPy: unsupported field dtype: `") + + field.name + "` @ " + tinfo.name()); + names.append(PYBIND11_STR_TYPE(field.name)); + formats.append(field.descr); + offsets.append(pybind11::int_(field.offset)); + } + auto dtype_ptr = pybind11::dtype(names, formats, offsets, itemsize).release().ptr(); + + // There is an existing bug in NumPy (as of v1.11): trailing bytes are + // not encoded explicitly into the format string. This will supposedly + // get fixed in v1.12; for further details, see these: + // - https://github.com/numpy/numpy/issues/7797 + // - https://github.com/numpy/numpy/pull/7798 + // Because of this, we won't use numpy's logic to generate buffer format + // strings and will just do it ourselves. + std::vector ordered_fields(fields); + std::sort(ordered_fields.begin(), ordered_fields.end(), + [](const field_descriptor &a, const field_descriptor &b) { return a.offset < b.offset; }); + size_t offset = 0; + std::ostringstream oss; + oss << "T{"; + for (auto& field : ordered_fields) { + if (field.offset > offset) + oss << (field.offset - offset) << 'x'; + // note that '=' is required to cover the case of unaligned fields + oss << '=' << field.format << ':' << field.name << ':'; + offset = field.offset + field.size; + } + if (itemsize > offset) + oss << (itemsize - offset) << 'x'; + oss << '}'; + auto format_str = oss.str(); + + // Sanity check: verify that NumPy properly parses our buffer format string + auto& api = npy_api::get(); + auto arr = array(buffer_info(nullptr, itemsize, format_str, 1)); + if (!api.PyArray_EquivTypes_(dtype_ptr, arr.dtype().ptr())) + pybind11_fail("NumPy: invalid buffer descriptor!"); + + auto tindex = std::type_index(tinfo); + numpy_internals.registered_dtypes[tindex] = { dtype_ptr, format_str }; + get_internals().direct_conversions[tindex].push_back(direct_converter); +} + template struct npy_format_descriptor::value>> { static PYBIND11_DESCR name() { return _("struct"); } @@ -699,56 +759,9 @@ struct npy_format_descriptor::value>> { return format_str; } - static void register_dtype(std::initializer_list fields) { - auto& numpy_internals = get_numpy_internals(); - if (numpy_internals.get_type_info(false)) - pybind11_fail("NumPy: dtype is already registered"); - - list names, formats, offsets; - for (auto field : fields) { - if (!field.descr) - pybind11_fail(std::string("NumPy: unsupported field dtype: `") + - field.name + "` @ " + typeid(T).name()); - names.append(PYBIND11_STR_TYPE(field.name)); - formats.append(field.descr); - offsets.append(pybind11::int_(field.offset)); - } - auto dtype_ptr = pybind11::dtype(names, formats, offsets, sizeof(T)).release().ptr(); - - // There is an existing bug in NumPy (as of v1.11): trailing bytes are - // not encoded explicitly into the format string. This will supposedly - // get fixed in v1.12; for further details, see these: - // - https://github.com/numpy/numpy/issues/7797 - // - https://github.com/numpy/numpy/pull/7798 - // Because of this, we won't use numpy's logic to generate buffer format - // strings and will just do it ourselves. - std::vector ordered_fields(fields); - std::sort(ordered_fields.begin(), ordered_fields.end(), - [](const field_descriptor &a, const field_descriptor &b) { return a.offset < b.offset; }); - size_t offset = 0; - std::ostringstream oss; - oss << "T{"; - for (auto& field : ordered_fields) { - if (field.offset > offset) - oss << (field.offset - offset) << 'x'; - // note that '=' is required to cover the case of unaligned fields - oss << '=' << field.format << ':' << field.name << ':'; - offset = field.offset + field.size; - } - if (sizeof(T) > offset) - oss << (sizeof(T) - offset) << 'x'; - oss << '}'; - auto format_str = oss.str(); - - // Sanity check: verify that NumPy properly parses our buffer format string - auto& api = npy_api::get(); - auto arr = array(buffer_info(nullptr, sizeof(T), format_str, 1)); - if (!api.PyArray_EquivTypes_(dtype_ptr, arr.dtype().ptr())) - pybind11_fail("NumPy: invalid buffer descriptor!"); - - auto tindex = std::type_index(typeid(T)); - numpy_internals.registered_dtypes[tindex] = { dtype_ptr, format_str }; - get_internals().direct_conversions[tindex].push_back(direct_converter); + static void register_dtype(const std::initializer_list& fields) { + register_structured_dtype(fields, typeid(typename std::remove_cv::type), + sizeof(T), &direct_converter); } private: