From 2184f6d4d64b4631c943b4c92d35bcd849da51ad Mon Sep 17 00:00:00 2001 From: Ivan Smirnov Date: Mon, 31 Oct 2016 13:52:32 +0000 Subject: [PATCH] NumPy dtypes are now shared across extensions --- include/pybind11/common.h | 1 + include/pybind11/numpy.h | 78 +++++++++++++++++++++++++------------- tests/test_numpy_dtypes.py | 2 +- 3 files changed, 54 insertions(+), 27 deletions(-) diff --git a/include/pybind11/common.h b/include/pybind11/common.h index b5434d04a..27cd47bef 100644 --- a/include/pybind11/common.h +++ b/include/pybind11/common.h @@ -323,6 +323,7 @@ struct internals { std::unordered_set, overload_hash> inactive_overload_cache; std::unordered_map> direct_conversions; std::forward_list registered_exception_translators; + std::unordered_map shared_data; #if defined(WITH_THREAD) decltype(PyThread_create_key()) tstate = 0; // Usually an int but a long on Cygwin64 with Python 3.x PyInterpreterState *istate = nullptr; diff --git a/include/pybind11/numpy.h b/include/pybind11/numpy.h index da04c62a8..19bff6359 100644 --- a/include/pybind11/numpy.h +++ b/include/pybind11/numpy.h @@ -21,6 +21,7 @@ #include #include #include +#include #if defined(_MSC_VER) # pragma warning(push) @@ -72,6 +73,39 @@ struct PyVoidScalarObject_Proxy { PyObject *base; }; +struct numpy_type_info { + PyObject* dtype_ptr; + std::string format_str; +}; + +struct numpy_internals { + std::unordered_map registered_dtypes; + + template numpy_type_info *get_type_info(bool throw_if_missing = true) { + auto it = registered_dtypes.find(std::type_index(typeid(T))); + if (it != registered_dtypes.end()) + return &(it->second); + if (throw_if_missing) + pybind11_fail(std::string("NumPy type info missing for ") + typeid(T).name()); + return nullptr; + } +}; + +inline PYBIND11_NOINLINE numpy_internals* load_numpy_internals() { + auto& shared_data = detail::get_internals().shared_data; + auto it = shared_data.find("numpy_internals"); + if (it != shared_data.end()) + return (numpy_internals *)it->second; + auto ptr = new numpy_internals(); + shared_data["numpy_internals"] = ptr; + return ptr; +} + +inline numpy_internals& get_numpy_internals() { + static numpy_internals* ptr = load_numpy_internals(); + return *ptr; +} + struct npy_api { enum constants { NPY_C_CONTIGUOUS_ = 0x0001, @@ -661,30 +695,29 @@ struct npy_format_descriptor::value>> { static PYBIND11_DESCR name() { return _("struct"); } static pybind11::dtype dtype() { - if (!dtype_ptr) - pybind11_fail("NumPy: unsupported buffer format!"); - return object(dtype_ptr, true); + return object(dtype_ptr(), true); } static std::string format() { - if (!dtype_ptr) - pybind11_fail("NumPy: unsupported buffer format!"); + static auto format_str = get_numpy_internals().get_type_info(true)->format_str; return format_str; } static void register_dtype(std::initializer_list fields) { - if (dtype_ptr) + auto& numpy_internals = get_numpy_internals(); + if (numpy_internals.get_type_info(false)) pybind11_fail("NumPy: dtype is already registered"); list names, formats, offsets; for (auto field : fields) { if (!field.descr) - pybind11_fail("NumPy: unsupported field dtype"); + pybind11_fail(std::string("NumPy: unsupported field dtype: `") + + field.name + "` @ " + typeid(T).name()); names.append(PYBIND11_STR_TYPE(field.name)); formats.append(field.descr); offsets.append(pybind11::int_(field.offset)); } - dtype_ptr = pybind11::dtype(names, formats, offsets, sizeof(T)).release().ptr(); + auto dtype_ptr = pybind11::dtype(names, formats, offsets, sizeof(T)).release().ptr(); // There is an existing bug in NumPy (as of v1.11): trailing bytes are // not encoded explicitly into the format string. This will supposedly @@ -695,9 +728,7 @@ struct npy_format_descriptor::value>> { // strings and will just do it ourselves. std::vector ordered_fields(fields); std::sort(ordered_fields.begin(), ordered_fields.end(), - [](const field_descriptor &a, const field_descriptor &b) { - return a.offset < b.offset; - }); + [](const field_descriptor &a, const field_descriptor &b) { return a.offset < b.offset; }); size_t offset = 0; std::ostringstream oss; oss << "T{"; @@ -711,44 +742,39 @@ struct npy_format_descriptor::value>> { if (sizeof(T) > offset) oss << (sizeof(T) - offset) << 'x'; oss << '}'; - format_str = oss.str(); + auto format_str = oss.str(); // Sanity check: verify that NumPy properly parses our buffer format string auto& api = npy_api::get(); - auto arr = array(buffer_info(nullptr, sizeof(T), format(), 1)); + auto arr = array(buffer_info(nullptr, sizeof(T), format_str, 1)); if (!api.PyArray_EquivTypes_(dtype_ptr, arr.dtype().ptr())) pybind11_fail("NumPy: invalid buffer descriptor!"); - register_direct_converter(); + auto tindex = std::type_index(typeid(T)); + numpy_internals.registered_dtypes[tindex] = { dtype_ptr, format_str }; + get_internals().direct_conversions[tindex].push_back(direct_converter); } private: - static std::string format_str; - static PyObject* dtype_ptr; + static PyObject* dtype_ptr() { + static PyObject* ptr = get_numpy_internals().get_type_info(true)->dtype_ptr; + return ptr; + } static bool direct_converter(PyObject *obj, void*& value) { auto& api = npy_api::get(); if (!PyObject_TypeCheck(obj, api.PyVoidArrType_Type_)) return false; if (auto descr = object(api.PyArray_DescrFromScalar_(obj), false)) { - if (api.PyArray_EquivTypes_(dtype_ptr, descr.ptr())) { + if (api.PyArray_EquivTypes_(dtype_ptr(), descr.ptr())) { value = ((PyVoidScalarObject_Proxy *) obj)->obval; return true; } } return false; } - - static void register_direct_converter() { - get_internals().direct_conversions[std::type_index(typeid(T))].push_back(direct_converter); - } }; -template -std::string npy_format_descriptor::value>>::format_str; -template -PyObject* npy_format_descriptor::value>>::dtype_ptr = nullptr; - #define PYBIND11_FIELD_DESCRIPTOR_EX(T, Field, Name) \ ::pybind11::detail::field_descriptor { \ Name, offsetof(T, Field), sizeof(decltype(std::declval().Field)), \ diff --git a/tests/test_numpy_dtypes.py b/tests/test_numpy_dtypes.py index b4e6d71f2..c0d6ec292 100644 --- a/tests/test_numpy_dtypes.py +++ b/tests/test_numpy_dtypes.py @@ -18,7 +18,7 @@ def test_format_descriptors(): with pytest.raises(RuntimeError) as excinfo: get_format_unbound() - assert 'unsupported buffer format' in str(excinfo.value) + assert re.match('^NumPy type info missing for .*UnboundStruct.*$', str(excinfo.value)) assert print_format_descriptors() == [ "T{=?:x:3x=I:y:=f:z:}",