mirror of
https://github.com/pybind/pybind11.git
synced 2024-11-11 08:03:55 +00:00
Merge pull request #472 from aldanor/feature/shared-dtypes
Support for sharing dtypes across extensions + public shared data API
This commit is contained in:
commit
0a9ef9c300
@ -149,6 +149,25 @@ accessed by multiple extension modules:
|
|||||||
...
|
...
|
||||||
};
|
};
|
||||||
|
|
||||||
|
Note also that it is possible (although would rarely be required) to share arbitrary
|
||||||
|
C++ objects between extension modules at runtime. Internal library data is shared
|
||||||
|
between modules using capsule machinery [#f6]_ which can be also utilized for
|
||||||
|
storing, modifying and accessing user-defined data. Note that an extension module
|
||||||
|
will "see" other extensions' data if and only if they were built with the same
|
||||||
|
pybind11 version. Consider the following example:
|
||||||
|
|
||||||
|
.. code-block:: cpp
|
||||||
|
|
||||||
|
auto data = (MyData *) py::get_shared_data("mydata");
|
||||||
|
if (!data)
|
||||||
|
data = (MyData *) py::set_shared_data("mydata", new MyData(42));
|
||||||
|
|
||||||
|
If the above snippet was used in several separately compiled extension modules,
|
||||||
|
the first one to be imported would create a ``MyData`` instance and associate
|
||||||
|
a ``"mydata"`` key with a pointer to it. Extensions that are imported later
|
||||||
|
would be then able to access the data behind the same pointer.
|
||||||
|
|
||||||
|
.. [#f6] https://docs.python.org/3/extending/extending.html#using-capsules
|
||||||
|
|
||||||
|
|
||||||
Generating documentation using Sphinx
|
Generating documentation using Sphinx
|
||||||
|
@ -323,6 +323,7 @@ struct internals {
|
|||||||
std::unordered_set<std::pair<const PyObject *, const char *>, overload_hash> inactive_overload_cache;
|
std::unordered_set<std::pair<const PyObject *, const char *>, overload_hash> inactive_overload_cache;
|
||||||
std::unordered_map<std::type_index, std::vector<bool (*)(PyObject *, void *&)>> direct_conversions;
|
std::unordered_map<std::type_index, std::vector<bool (*)(PyObject *, void *&)>> direct_conversions;
|
||||||
std::forward_list<void (*) (std::exception_ptr)> registered_exception_translators;
|
std::forward_list<void (*) (std::exception_ptr)> registered_exception_translators;
|
||||||
|
std::unordered_map<std::string, void *> shared_data; // Custom data to be shared across extensions
|
||||||
#if defined(WITH_THREAD)
|
#if defined(WITH_THREAD)
|
||||||
decltype(PyThread_create_key()) tstate = 0; // Usually an int but a long on Cygwin64 with Python 3.x
|
decltype(PyThread_create_key()) tstate = 0; // Usually an int but a long on Cygwin64 with Python 3.x
|
||||||
PyInterpreterState *istate = nullptr;
|
PyInterpreterState *istate = nullptr;
|
||||||
@ -427,6 +428,35 @@ inline void ignore_unused(const int *) { }
|
|||||||
|
|
||||||
NAMESPACE_END(detail)
|
NAMESPACE_END(detail)
|
||||||
|
|
||||||
|
/// Returns a named pointer that is shared among all extension modules (using the same
|
||||||
|
/// pybind11 version) running in the current interpreter. Names starting with underscores
|
||||||
|
/// are reserved for internal usage. Returns `nullptr` if no matching entry was found.
|
||||||
|
inline PYBIND11_NOINLINE void* get_shared_data(const std::string& name) {
|
||||||
|
auto& internals = detail::get_internals();
|
||||||
|
auto it = internals.shared_data.find(name);
|
||||||
|
return it != internals.shared_data.end() ? it->second : nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Set the shared data that can be later recovered by `get_shared_data()`.
|
||||||
|
inline PYBIND11_NOINLINE void *set_shared_data(const std::string& name, void *data) {
|
||||||
|
detail::get_internals().shared_data[name] = data;
|
||||||
|
return data;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns a typed reference to a shared data entry (by using `get_shared_data()`) if
|
||||||
|
/// such entry exists. Otherwise, a new object of default-constructible type `T` is
|
||||||
|
/// added to the shared data under the given name and a reference to it is returned.
|
||||||
|
template<typename T> T& get_or_create_shared_data(const std::string& name) {
|
||||||
|
auto& internals = detail::get_internals();
|
||||||
|
auto it = internals.shared_data.find(name);
|
||||||
|
T* ptr = (T*) (it != internals.shared_data.end() ? it->second : nullptr);
|
||||||
|
if (!ptr) {
|
||||||
|
ptr = new T();
|
||||||
|
internals.shared_data[name] = ptr;
|
||||||
|
}
|
||||||
|
return *ptr;
|
||||||
|
}
|
||||||
|
|
||||||
/// Fetch and hold an error which was already set in Python
|
/// Fetch and hold an error which was already set in Python
|
||||||
class error_already_set : public std::runtime_error {
|
class error_already_set : public std::runtime_error {
|
||||||
public:
|
public:
|
||||||
|
@ -21,6 +21,7 @@
|
|||||||
#include <initializer_list>
|
#include <initializer_list>
|
||||||
#include <functional>
|
#include <functional>
|
||||||
#include <utility>
|
#include <utility>
|
||||||
|
#include <typeindex>
|
||||||
|
|
||||||
#if defined(_MSC_VER)
|
#if defined(_MSC_VER)
|
||||||
# pragma warning(push)
|
# pragma warning(push)
|
||||||
@ -72,6 +73,39 @@ struct PyVoidScalarObject_Proxy {
|
|||||||
PyObject *base;
|
PyObject *base;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
struct numpy_type_info {
|
||||||
|
PyObject* dtype_ptr;
|
||||||
|
std::string format_str;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct numpy_internals {
|
||||||
|
std::unordered_map<std::type_index, numpy_type_info> registered_dtypes;
|
||||||
|
|
||||||
|
numpy_type_info *get_type_info(const std::type_info& tinfo, bool throw_if_missing = true) {
|
||||||
|
auto it = registered_dtypes.find(std::type_index(tinfo));
|
||||||
|
if (it != registered_dtypes.end())
|
||||||
|
return &(it->second);
|
||||||
|
if (throw_if_missing)
|
||||||
|
pybind11_fail(std::string("NumPy type info missing for ") + tinfo.name());
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
template<typename T> numpy_type_info *get_type_info(bool throw_if_missing = true) {
|
||||||
|
return get_type_info(typeid(typename std::remove_cv<T>::type), throw_if_missing);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
inline PYBIND11_NOINLINE void load_numpy_internals(numpy_internals* &ptr) {
|
||||||
|
ptr = &get_or_create_shared_data<numpy_internals>("_numpy_internals");
|
||||||
|
}
|
||||||
|
|
||||||
|
inline numpy_internals& get_numpy_internals() {
|
||||||
|
static numpy_internals* ptr = nullptr;
|
||||||
|
if (!ptr)
|
||||||
|
load_numpy_internals(ptr);
|
||||||
|
return *ptr;
|
||||||
|
}
|
||||||
|
|
||||||
struct npy_api {
|
struct npy_api {
|
||||||
enum constants {
|
enum constants {
|
||||||
NPY_C_CONTIGUOUS_ = 0x0001,
|
NPY_C_CONTIGUOUS_ = 0x0001,
|
||||||
@ -656,99 +690,100 @@ struct field_descriptor {
|
|||||||
dtype descr;
|
dtype descr;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
inline PYBIND11_NOINLINE void register_structured_dtype(
|
||||||
|
const std::initializer_list<field_descriptor>& fields,
|
||||||
|
const std::type_info& tinfo, size_t itemsize,
|
||||||
|
bool (*direct_converter)(PyObject *, void *&))
|
||||||
|
{
|
||||||
|
auto& numpy_internals = get_numpy_internals();
|
||||||
|
if (numpy_internals.get_type_info(tinfo, false))
|
||||||
|
pybind11_fail("NumPy: dtype is already registered");
|
||||||
|
|
||||||
|
list names, formats, offsets;
|
||||||
|
for (auto field : fields) {
|
||||||
|
if (!field.descr)
|
||||||
|
pybind11_fail(std::string("NumPy: unsupported field dtype: `") +
|
||||||
|
field.name + "` @ " + tinfo.name());
|
||||||
|
names.append(PYBIND11_STR_TYPE(field.name));
|
||||||
|
formats.append(field.descr);
|
||||||
|
offsets.append(pybind11::int_(field.offset));
|
||||||
|
}
|
||||||
|
auto dtype_ptr = pybind11::dtype(names, formats, offsets, itemsize).release().ptr();
|
||||||
|
|
||||||
|
// There is an existing bug in NumPy (as of v1.11): trailing bytes are
|
||||||
|
// not encoded explicitly into the format string. This will supposedly
|
||||||
|
// get fixed in v1.12; for further details, see these:
|
||||||
|
// - https://github.com/numpy/numpy/issues/7797
|
||||||
|
// - https://github.com/numpy/numpy/pull/7798
|
||||||
|
// Because of this, we won't use numpy's logic to generate buffer format
|
||||||
|
// strings and will just do it ourselves.
|
||||||
|
std::vector<field_descriptor> ordered_fields(fields);
|
||||||
|
std::sort(ordered_fields.begin(), ordered_fields.end(),
|
||||||
|
[](const field_descriptor &a, const field_descriptor &b) { return a.offset < b.offset; });
|
||||||
|
size_t offset = 0;
|
||||||
|
std::ostringstream oss;
|
||||||
|
oss << "T{";
|
||||||
|
for (auto& field : ordered_fields) {
|
||||||
|
if (field.offset > offset)
|
||||||
|
oss << (field.offset - offset) << 'x';
|
||||||
|
// note that '=' is required to cover the case of unaligned fields
|
||||||
|
oss << '=' << field.format << ':' << field.name << ':';
|
||||||
|
offset = field.offset + field.size;
|
||||||
|
}
|
||||||
|
if (itemsize > offset)
|
||||||
|
oss << (itemsize - offset) << 'x';
|
||||||
|
oss << '}';
|
||||||
|
auto format_str = oss.str();
|
||||||
|
|
||||||
|
// Sanity check: verify that NumPy properly parses our buffer format string
|
||||||
|
auto& api = npy_api::get();
|
||||||
|
auto arr = array(buffer_info(nullptr, itemsize, format_str, 1));
|
||||||
|
if (!api.PyArray_EquivTypes_(dtype_ptr, arr.dtype().ptr()))
|
||||||
|
pybind11_fail("NumPy: invalid buffer descriptor!");
|
||||||
|
|
||||||
|
auto tindex = std::type_index(tinfo);
|
||||||
|
numpy_internals.registered_dtypes[tindex] = { dtype_ptr, format_str };
|
||||||
|
get_internals().direct_conversions[tindex].push_back(direct_converter);
|
||||||
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
struct npy_format_descriptor<T, enable_if_t<is_pod_struct<T>::value>> {
|
struct npy_format_descriptor<T, enable_if_t<is_pod_struct<T>::value>> {
|
||||||
static PYBIND11_DESCR name() { return _("struct"); }
|
static PYBIND11_DESCR name() { return _("struct"); }
|
||||||
|
|
||||||
static pybind11::dtype dtype() {
|
static pybind11::dtype dtype() {
|
||||||
if (!dtype_ptr)
|
return object(dtype_ptr(), true);
|
||||||
pybind11_fail("NumPy: unsupported buffer format!");
|
|
||||||
return object(dtype_ptr, true);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
static std::string format() {
|
static std::string format() {
|
||||||
if (!dtype_ptr)
|
static auto format_str = get_numpy_internals().get_type_info<T>(true)->format_str;
|
||||||
pybind11_fail("NumPy: unsupported buffer format!");
|
|
||||||
return format_str;
|
return format_str;
|
||||||
}
|
}
|
||||||
|
|
||||||
static void register_dtype(std::initializer_list<field_descriptor> fields) {
|
static void register_dtype(const std::initializer_list<field_descriptor>& fields) {
|
||||||
if (dtype_ptr)
|
register_structured_dtype(fields, typeid(typename std::remove_cv<T>::type),
|
||||||
pybind11_fail("NumPy: dtype is already registered");
|
sizeof(T), &direct_converter);
|
||||||
|
|
||||||
list names, formats, offsets;
|
|
||||||
for (auto field : fields) {
|
|
||||||
if (!field.descr)
|
|
||||||
pybind11_fail("NumPy: unsupported field dtype");
|
|
||||||
names.append(PYBIND11_STR_TYPE(field.name));
|
|
||||||
formats.append(field.descr);
|
|
||||||
offsets.append(pybind11::int_(field.offset));
|
|
||||||
}
|
|
||||||
dtype_ptr = pybind11::dtype(names, formats, offsets, sizeof(T)).release().ptr();
|
|
||||||
|
|
||||||
// There is an existing bug in NumPy (as of v1.11): trailing bytes are
|
|
||||||
// not encoded explicitly into the format string. This will supposedly
|
|
||||||
// get fixed in v1.12; for further details, see these:
|
|
||||||
// - https://github.com/numpy/numpy/issues/7797
|
|
||||||
// - https://github.com/numpy/numpy/pull/7798
|
|
||||||
// Because of this, we won't use numpy's logic to generate buffer format
|
|
||||||
// strings and will just do it ourselves.
|
|
||||||
std::vector<field_descriptor> ordered_fields(fields);
|
|
||||||
std::sort(ordered_fields.begin(), ordered_fields.end(),
|
|
||||||
[](const field_descriptor &a, const field_descriptor &b) {
|
|
||||||
return a.offset < b.offset;
|
|
||||||
});
|
|
||||||
size_t offset = 0;
|
|
||||||
std::ostringstream oss;
|
|
||||||
oss << "T{";
|
|
||||||
for (auto& field : ordered_fields) {
|
|
||||||
if (field.offset > offset)
|
|
||||||
oss << (field.offset - offset) << 'x';
|
|
||||||
// note that '=' is required to cover the case of unaligned fields
|
|
||||||
oss << '=' << field.format << ':' << field.name << ':';
|
|
||||||
offset = field.offset + field.size;
|
|
||||||
}
|
|
||||||
if (sizeof(T) > offset)
|
|
||||||
oss << (sizeof(T) - offset) << 'x';
|
|
||||||
oss << '}';
|
|
||||||
format_str = oss.str();
|
|
||||||
|
|
||||||
// Sanity check: verify that NumPy properly parses our buffer format string
|
|
||||||
auto& api = npy_api::get();
|
|
||||||
auto arr = array(buffer_info(nullptr, sizeof(T), format(), 1));
|
|
||||||
if (!api.PyArray_EquivTypes_(dtype_ptr, arr.dtype().ptr()))
|
|
||||||
pybind11_fail("NumPy: invalid buffer descriptor!");
|
|
||||||
|
|
||||||
register_direct_converter();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
static std::string format_str;
|
static PyObject* dtype_ptr() {
|
||||||
static PyObject* dtype_ptr;
|
static PyObject* ptr = get_numpy_internals().get_type_info<T>(true)->dtype_ptr;
|
||||||
|
return ptr;
|
||||||
|
}
|
||||||
|
|
||||||
static bool direct_converter(PyObject *obj, void*& value) {
|
static bool direct_converter(PyObject *obj, void*& value) {
|
||||||
auto& api = npy_api::get();
|
auto& api = npy_api::get();
|
||||||
if (!PyObject_TypeCheck(obj, api.PyVoidArrType_Type_))
|
if (!PyObject_TypeCheck(obj, api.PyVoidArrType_Type_))
|
||||||
return false;
|
return false;
|
||||||
if (auto descr = object(api.PyArray_DescrFromScalar_(obj), false)) {
|
if (auto descr = object(api.PyArray_DescrFromScalar_(obj), false)) {
|
||||||
if (api.PyArray_EquivTypes_(dtype_ptr, descr.ptr())) {
|
if (api.PyArray_EquivTypes_(dtype_ptr(), descr.ptr())) {
|
||||||
value = ((PyVoidScalarObject_Proxy *) obj)->obval;
|
value = ((PyVoidScalarObject_Proxy *) obj)->obval;
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
static void register_direct_converter() {
|
|
||||||
get_internals().direct_conversions[std::type_index(typeid(T))].push_back(direct_converter);
|
|
||||||
}
|
|
||||||
};
|
};
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
std::string npy_format_descriptor<T, enable_if_t<is_pod_struct<T>::value>>::format_str;
|
|
||||||
template <typename T>
|
|
||||||
PyObject* npy_format_descriptor<T, enable_if_t<is_pod_struct<T>::value>>::dtype_ptr = nullptr;
|
|
||||||
|
|
||||||
#define PYBIND11_FIELD_DESCRIPTOR_EX(T, Field, Name) \
|
#define PYBIND11_FIELD_DESCRIPTOR_EX(T, Field, Name) \
|
||||||
::pybind11::detail::field_descriptor { \
|
::pybind11::detail::field_descriptor { \
|
||||||
Name, offsetof(T, Field), sizeof(decltype(std::declval<T>().Field)), \
|
Name, offsetof(T, Field), sizeof(decltype(std::declval<T>().Field)), \
|
||||||
|
@ -1,11 +1,20 @@
|
|||||||
|
import re
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
with pytest.suppress(ImportError):
|
with pytest.suppress(ImportError):
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
simple_dtype = np.dtype({'names': ['x', 'y', 'z'],
|
|
||||||
'formats': ['?', 'u4', 'f4'],
|
@pytest.fixture(scope='module')
|
||||||
'offsets': [0, 4, 8]})
|
def simple_dtype():
|
||||||
packed_dtype = np.dtype([('x', '?'), ('y', 'u4'), ('z', 'f4')])
|
return np.dtype({'names': ['x', 'y', 'z'],
|
||||||
|
'formats': ['?', 'u4', 'f4'],
|
||||||
|
'offsets': [0, 4, 8]})
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope='module')
|
||||||
|
def packed_dtype():
|
||||||
|
return np.dtype([('x', '?'), ('y', 'u4'), ('z', 'f4')])
|
||||||
|
|
||||||
|
|
||||||
def assert_equal(actual, expected_data, expected_dtype):
|
def assert_equal(actual, expected_data, expected_dtype):
|
||||||
@ -18,7 +27,7 @@ def test_format_descriptors():
|
|||||||
|
|
||||||
with pytest.raises(RuntimeError) as excinfo:
|
with pytest.raises(RuntimeError) as excinfo:
|
||||||
get_format_unbound()
|
get_format_unbound()
|
||||||
assert 'unsupported buffer format' in str(excinfo.value)
|
assert re.match('^NumPy type info missing for .*UnboundStruct.*$', str(excinfo.value))
|
||||||
|
|
||||||
assert print_format_descriptors() == [
|
assert print_format_descriptors() == [
|
||||||
"T{=?:x:3x=I:y:=f:z:}",
|
"T{=?:x:3x=I:y:=f:z:}",
|
||||||
@ -32,7 +41,7 @@ def test_format_descriptors():
|
|||||||
|
|
||||||
|
|
||||||
@pytest.requires_numpy
|
@pytest.requires_numpy
|
||||||
def test_dtype():
|
def test_dtype(simple_dtype):
|
||||||
from pybind11_tests import print_dtypes, test_dtype_ctors, test_dtype_methods
|
from pybind11_tests import print_dtypes, test_dtype_ctors, test_dtype_methods
|
||||||
|
|
||||||
assert print_dtypes() == [
|
assert print_dtypes() == [
|
||||||
@ -57,7 +66,7 @@ def test_dtype():
|
|||||||
|
|
||||||
|
|
||||||
@pytest.requires_numpy
|
@pytest.requires_numpy
|
||||||
def test_recarray():
|
def test_recarray(simple_dtype, packed_dtype):
|
||||||
from pybind11_tests import (create_rec_simple, create_rec_packed, create_rec_nested,
|
from pybind11_tests import (create_rec_simple, create_rec_packed, create_rec_nested,
|
||||||
print_rec_simple, print_rec_packed, print_rec_nested,
|
print_rec_simple, print_rec_packed, print_rec_nested,
|
||||||
create_rec_partial, create_rec_partial_nested)
|
create_rec_partial, create_rec_partial_nested)
|
||||||
|
Loading…
Reference in New Issue
Block a user