diff --git a/docs/advanced/misc.rst b/docs/advanced/misc.rst index 2968f8ac1..b0719065f 100644 --- a/docs/advanced/misc.rst +++ b/docs/advanced/misc.rst @@ -149,6 +149,25 @@ accessed by multiple extension modules: ... }; +Note also that it is possible (although would rarely be required) to share arbitrary +C++ objects between extension modules at runtime. Internal library data is shared +between modules using capsule machinery [#f6]_ which can be also utilized for +storing, modifying and accessing user-defined data. Note that an extension module +will "see" other extensions' data if and only if they were built with the same +pybind11 version. Consider the following example: + +.. code-block:: cpp + + auto data = (MyData *) py::get_shared_data("mydata"); + if (!data) + data = (MyData *) py::set_shared_data("mydata", new MyData(42)); + +If the above snippet was used in several separately compiled extension modules, +the first one to be imported would create a ``MyData`` instance and associate +a ``"mydata"`` key with a pointer to it. Extensions that are imported later +would be then able to access the data behind the same pointer. + +.. [#f6] https://docs.python.org/3/extending/extending.html#using-capsules Generating documentation using Sphinx diff --git a/include/pybind11/common.h b/include/pybind11/common.h index b5434d04a..62198c341 100644 --- a/include/pybind11/common.h +++ b/include/pybind11/common.h @@ -323,6 +323,7 @@ struct internals { std::unordered_set, overload_hash> inactive_overload_cache; std::unordered_map> direct_conversions; std::forward_list registered_exception_translators; + std::unordered_map shared_data; // Custom data to be shared across extensions #if defined(WITH_THREAD) decltype(PyThread_create_key()) tstate = 0; // Usually an int but a long on Cygwin64 with Python 3.x PyInterpreterState *istate = nullptr; @@ -427,6 +428,35 @@ inline void ignore_unused(const int *) { } NAMESPACE_END(detail) +/// Returns a named pointer that is shared among all extension modules (using the same +/// pybind11 version) running in the current interpreter. Names starting with underscores +/// are reserved for internal usage. Returns `nullptr` if no matching entry was found. +inline PYBIND11_NOINLINE void* get_shared_data(const std::string& name) { + auto& internals = detail::get_internals(); + auto it = internals.shared_data.find(name); + return it != internals.shared_data.end() ? it->second : nullptr; +} + +/// Set the shared data that can be later recovered by `get_shared_data()`. +inline PYBIND11_NOINLINE void *set_shared_data(const std::string& name, void *data) { + detail::get_internals().shared_data[name] = data; + return data; +} + +/// Returns a typed reference to a shared data entry (by using `get_shared_data()`) if +/// such entry exists. Otherwise, a new object of default-constructible type `T` is +/// added to the shared data under the given name and a reference to it is returned. +template T& get_or_create_shared_data(const std::string& name) { + auto& internals = detail::get_internals(); + auto it = internals.shared_data.find(name); + T* ptr = (T*) (it != internals.shared_data.end() ? it->second : nullptr); + if (!ptr) { + ptr = new T(); + internals.shared_data[name] = ptr; + } + return *ptr; +} + /// Fetch and hold an error which was already set in Python class error_already_set : public std::runtime_error { public: diff --git a/include/pybind11/numpy.h b/include/pybind11/numpy.h index da04c62a8..af465a17d 100644 --- a/include/pybind11/numpy.h +++ b/include/pybind11/numpy.h @@ -21,6 +21,7 @@ #include #include #include +#include #if defined(_MSC_VER) # pragma warning(push) @@ -72,6 +73,39 @@ struct PyVoidScalarObject_Proxy { PyObject *base; }; +struct numpy_type_info { + PyObject* dtype_ptr; + std::string format_str; +}; + +struct numpy_internals { + std::unordered_map registered_dtypes; + + numpy_type_info *get_type_info(const std::type_info& tinfo, bool throw_if_missing = true) { + auto it = registered_dtypes.find(std::type_index(tinfo)); + if (it != registered_dtypes.end()) + return &(it->second); + if (throw_if_missing) + pybind11_fail(std::string("NumPy type info missing for ") + tinfo.name()); + return nullptr; + } + + template numpy_type_info *get_type_info(bool throw_if_missing = true) { + return get_type_info(typeid(typename std::remove_cv::type), throw_if_missing); + } +}; + +inline PYBIND11_NOINLINE void load_numpy_internals(numpy_internals* &ptr) { + ptr = &get_or_create_shared_data("_numpy_internals"); +} + +inline numpy_internals& get_numpy_internals() { + static numpy_internals* ptr = nullptr; + if (!ptr) + load_numpy_internals(ptr); + return *ptr; +} + struct npy_api { enum constants { NPY_C_CONTIGUOUS_ = 0x0001, @@ -656,99 +690,100 @@ struct field_descriptor { dtype descr; }; +inline PYBIND11_NOINLINE void register_structured_dtype( + const std::initializer_list& fields, + const std::type_info& tinfo, size_t itemsize, + bool (*direct_converter)(PyObject *, void *&)) +{ + auto& numpy_internals = get_numpy_internals(); + if (numpy_internals.get_type_info(tinfo, false)) + pybind11_fail("NumPy: dtype is already registered"); + + list names, formats, offsets; + for (auto field : fields) { + if (!field.descr) + pybind11_fail(std::string("NumPy: unsupported field dtype: `") + + field.name + "` @ " + tinfo.name()); + names.append(PYBIND11_STR_TYPE(field.name)); + formats.append(field.descr); + offsets.append(pybind11::int_(field.offset)); + } + auto dtype_ptr = pybind11::dtype(names, formats, offsets, itemsize).release().ptr(); + + // There is an existing bug in NumPy (as of v1.11): trailing bytes are + // not encoded explicitly into the format string. This will supposedly + // get fixed in v1.12; for further details, see these: + // - https://github.com/numpy/numpy/issues/7797 + // - https://github.com/numpy/numpy/pull/7798 + // Because of this, we won't use numpy's logic to generate buffer format + // strings and will just do it ourselves. + std::vector ordered_fields(fields); + std::sort(ordered_fields.begin(), ordered_fields.end(), + [](const field_descriptor &a, const field_descriptor &b) { return a.offset < b.offset; }); + size_t offset = 0; + std::ostringstream oss; + oss << "T{"; + for (auto& field : ordered_fields) { + if (field.offset > offset) + oss << (field.offset - offset) << 'x'; + // note that '=' is required to cover the case of unaligned fields + oss << '=' << field.format << ':' << field.name << ':'; + offset = field.offset + field.size; + } + if (itemsize > offset) + oss << (itemsize - offset) << 'x'; + oss << '}'; + auto format_str = oss.str(); + + // Sanity check: verify that NumPy properly parses our buffer format string + auto& api = npy_api::get(); + auto arr = array(buffer_info(nullptr, itemsize, format_str, 1)); + if (!api.PyArray_EquivTypes_(dtype_ptr, arr.dtype().ptr())) + pybind11_fail("NumPy: invalid buffer descriptor!"); + + auto tindex = std::type_index(tinfo); + numpy_internals.registered_dtypes[tindex] = { dtype_ptr, format_str }; + get_internals().direct_conversions[tindex].push_back(direct_converter); +} + template struct npy_format_descriptor::value>> { static PYBIND11_DESCR name() { return _("struct"); } static pybind11::dtype dtype() { - if (!dtype_ptr) - pybind11_fail("NumPy: unsupported buffer format!"); - return object(dtype_ptr, true); + return object(dtype_ptr(), true); } static std::string format() { - if (!dtype_ptr) - pybind11_fail("NumPy: unsupported buffer format!"); + static auto format_str = get_numpy_internals().get_type_info(true)->format_str; return format_str; } - static void register_dtype(std::initializer_list fields) { - if (dtype_ptr) - pybind11_fail("NumPy: dtype is already registered"); - - list names, formats, offsets; - for (auto field : fields) { - if (!field.descr) - pybind11_fail("NumPy: unsupported field dtype"); - names.append(PYBIND11_STR_TYPE(field.name)); - formats.append(field.descr); - offsets.append(pybind11::int_(field.offset)); - } - dtype_ptr = pybind11::dtype(names, formats, offsets, sizeof(T)).release().ptr(); - - // There is an existing bug in NumPy (as of v1.11): trailing bytes are - // not encoded explicitly into the format string. This will supposedly - // get fixed in v1.12; for further details, see these: - // - https://github.com/numpy/numpy/issues/7797 - // - https://github.com/numpy/numpy/pull/7798 - // Because of this, we won't use numpy's logic to generate buffer format - // strings and will just do it ourselves. - std::vector ordered_fields(fields); - std::sort(ordered_fields.begin(), ordered_fields.end(), - [](const field_descriptor &a, const field_descriptor &b) { - return a.offset < b.offset; - }); - size_t offset = 0; - std::ostringstream oss; - oss << "T{"; - for (auto& field : ordered_fields) { - if (field.offset > offset) - oss << (field.offset - offset) << 'x'; - // note that '=' is required to cover the case of unaligned fields - oss << '=' << field.format << ':' << field.name << ':'; - offset = field.offset + field.size; - } - if (sizeof(T) > offset) - oss << (sizeof(T) - offset) << 'x'; - oss << '}'; - format_str = oss.str(); - - // Sanity check: verify that NumPy properly parses our buffer format string - auto& api = npy_api::get(); - auto arr = array(buffer_info(nullptr, sizeof(T), format(), 1)); - if (!api.PyArray_EquivTypes_(dtype_ptr, arr.dtype().ptr())) - pybind11_fail("NumPy: invalid buffer descriptor!"); - - register_direct_converter(); + static void register_dtype(const std::initializer_list& fields) { + register_structured_dtype(fields, typeid(typename std::remove_cv::type), + sizeof(T), &direct_converter); } private: - static std::string format_str; - static PyObject* dtype_ptr; + static PyObject* dtype_ptr() { + static PyObject* ptr = get_numpy_internals().get_type_info(true)->dtype_ptr; + return ptr; + } static bool direct_converter(PyObject *obj, void*& value) { auto& api = npy_api::get(); if (!PyObject_TypeCheck(obj, api.PyVoidArrType_Type_)) return false; if (auto descr = object(api.PyArray_DescrFromScalar_(obj), false)) { - if (api.PyArray_EquivTypes_(dtype_ptr, descr.ptr())) { + if (api.PyArray_EquivTypes_(dtype_ptr(), descr.ptr())) { value = ((PyVoidScalarObject_Proxy *) obj)->obval; return true; } } return false; } - - static void register_direct_converter() { - get_internals().direct_conversions[std::type_index(typeid(T))].push_back(direct_converter); - } }; -template -std::string npy_format_descriptor::value>>::format_str; -template -PyObject* npy_format_descriptor::value>>::dtype_ptr = nullptr; - #define PYBIND11_FIELD_DESCRIPTOR_EX(T, Field, Name) \ ::pybind11::detail::field_descriptor { \ Name, offsetof(T, Field), sizeof(decltype(std::declval().Field)), \ diff --git a/tests/test_numpy_dtypes.py b/tests/test_numpy_dtypes.py index b4e6d71f2..2ef6f4d0b 100644 --- a/tests/test_numpy_dtypes.py +++ b/tests/test_numpy_dtypes.py @@ -1,11 +1,20 @@ +import re import pytest + with pytest.suppress(ImportError): import numpy as np - simple_dtype = np.dtype({'names': ['x', 'y', 'z'], - 'formats': ['?', 'u4', 'f4'], - 'offsets': [0, 4, 8]}) - packed_dtype = np.dtype([('x', '?'), ('y', 'u4'), ('z', 'f4')]) + +@pytest.fixture(scope='module') +def simple_dtype(): + return np.dtype({'names': ['x', 'y', 'z'], + 'formats': ['?', 'u4', 'f4'], + 'offsets': [0, 4, 8]}) + + +@pytest.fixture(scope='module') +def packed_dtype(): + return np.dtype([('x', '?'), ('y', 'u4'), ('z', 'f4')]) def assert_equal(actual, expected_data, expected_dtype): @@ -18,7 +27,7 @@ def test_format_descriptors(): with pytest.raises(RuntimeError) as excinfo: get_format_unbound() - assert 'unsupported buffer format' in str(excinfo.value) + assert re.match('^NumPy type info missing for .*UnboundStruct.*$', str(excinfo.value)) assert print_format_descriptors() == [ "T{=?:x:3x=I:y:=f:z:}", @@ -32,7 +41,7 @@ def test_format_descriptors(): @pytest.requires_numpy -def test_dtype(): +def test_dtype(simple_dtype): from pybind11_tests import print_dtypes, test_dtype_ctors, test_dtype_methods assert print_dtypes() == [ @@ -57,7 +66,7 @@ def test_dtype(): @pytest.requires_numpy -def test_recarray(): +def test_recarray(simple_dtype, packed_dtype): from pybind11_tests import (create_rec_simple, create_rec_packed, create_rec_nested, print_rec_simple, print_rec_packed, print_rec_nested, create_rec_partial, create_rec_partial_nested)