mirror of
https://github.com/pybind/pybind11.git
synced 2025-01-19 01:15:52 +00:00
Merge pull request #453 from aldanor/feature/numpy-scalars
NumPy scalars to ctypes conversion support
This commit is contained in:
commit
dd9bd7778f
@ -26,6 +26,7 @@ struct type_info {
|
||||
void (*init_holder)(PyObject *, const void *);
|
||||
std::vector<PyObject *(*)(PyObject *, PyTypeObject *)> implicit_conversions;
|
||||
std::vector<std::pair<const std::type_info *, void *(*)(void *)>> implicit_casts;
|
||||
std::vector<bool (*)(PyObject *, void *&)> *direct_conversions;
|
||||
buffer_info *(*get_buffer)(PyObject *, void *) = nullptr;
|
||||
void *get_buffer_data = nullptr;
|
||||
/** A simple type never occurs as a (direct or indirect) parent
|
||||
@ -90,7 +91,8 @@ PYBIND11_NOINLINE inline detail::type_info* get_type_info(PyTypeObject *type) {
|
||||
} while (true);
|
||||
}
|
||||
|
||||
PYBIND11_NOINLINE inline detail::type_info *get_type_info(const std::type_info &tp, bool throw_if_missing) {
|
||||
PYBIND11_NOINLINE inline detail::type_info *get_type_info(const std::type_info &tp,
|
||||
bool throw_if_missing = false) {
|
||||
auto &types = get_internals().registered_types_cpp;
|
||||
|
||||
auto it = types.find(std::type_index(tp));
|
||||
@ -157,7 +159,7 @@ inline void keep_alive_impl(handle nurse, handle patient);
|
||||
class type_caster_generic {
|
||||
public:
|
||||
PYBIND11_NOINLINE type_caster_generic(const std::type_info &type_info)
|
||||
: typeinfo(get_type_info(type_info, false)) { }
|
||||
: typeinfo(get_type_info(type_info)) { }
|
||||
|
||||
PYBIND11_NOINLINE bool load(handle src, bool convert) {
|
||||
if (!src)
|
||||
@ -215,6 +217,10 @@ public:
|
||||
if (load(temp, false))
|
||||
return true;
|
||||
}
|
||||
for (auto &converter : *typeinfo->direct_conversions) {
|
||||
if (converter(src.ptr(), value))
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
@ -321,6 +321,7 @@ struct internals {
|
||||
std::unordered_map<const void *, void*> registered_types_py; // PyTypeObject* -> type_info
|
||||
std::unordered_multimap<const void *, void*> registered_instances; // void * -> PyObject*
|
||||
std::unordered_set<std::pair<const PyObject *, const char *>, overload_hash> inactive_overload_cache;
|
||||
std::unordered_map<std::type_index, std::vector<bool (*)(PyObject *, void *&)>> direct_conversions;
|
||||
std::forward_list<void (*) (std::exception_ptr)> registered_exception_translators;
|
||||
#if defined(WITH_THREAD)
|
||||
decltype(PyThread_create_key()) tstate = 0; // Usually an int but a long on Cygwin64 with Python 3.x
|
||||
|
@ -63,6 +63,14 @@ struct PyArray_Proxy {
|
||||
int flags;
|
||||
};
|
||||
|
||||
struct PyVoidScalarObject_Proxy {
|
||||
PyObject_VAR_HEAD
|
||||
char *obval;
|
||||
PyArrayDescr_Proxy *descr;
|
||||
int flags;
|
||||
PyObject *base;
|
||||
};
|
||||
|
||||
struct npy_api {
|
||||
enum constants {
|
||||
NPY_C_CONTIGUOUS_ = 0x0001,
|
||||
@ -103,7 +111,9 @@ struct npy_api {
|
||||
PyObject *(*PyArray_DescrNewFromType_)(int);
|
||||
PyObject *(*PyArray_NewCopy_)(PyObject *, int);
|
||||
PyTypeObject *PyArray_Type_;
|
||||
PyTypeObject *PyVoidArrType_Type_;
|
||||
PyTypeObject *PyArrayDescr_Type_;
|
||||
PyObject *(*PyArray_DescrFromScalar_)(PyObject *);
|
||||
PyObject *(*PyArray_FromAny_) (PyObject *, PyObject *, int, int, int, PyObject *);
|
||||
int (*PyArray_DescrConverter_) (PyObject *, PyObject **);
|
||||
bool (*PyArray_EquivTypes_) (PyObject *, PyObject *);
|
||||
@ -114,7 +124,9 @@ private:
|
||||
enum functions {
|
||||
API_PyArray_Type = 2,
|
||||
API_PyArrayDescr_Type = 3,
|
||||
API_PyVoidArrType_Type = 39,
|
||||
API_PyArray_DescrFromType = 45,
|
||||
API_PyArray_DescrFromScalar = 57,
|
||||
API_PyArray_FromAny = 69,
|
||||
API_PyArray_NewCopy = 85,
|
||||
API_PyArray_NewFromDescr = 94,
|
||||
@ -136,8 +148,10 @@ private:
|
||||
npy_api api;
|
||||
#define DECL_NPY_API(Func) api.Func##_ = (decltype(api.Func##_)) api_ptr[API_##Func];
|
||||
DECL_NPY_API(PyArray_Type);
|
||||
DECL_NPY_API(PyVoidArrType_Type);
|
||||
DECL_NPY_API(PyArrayDescr_Type);
|
||||
DECL_NPY_API(PyArray_DescrFromType);
|
||||
DECL_NPY_API(PyArray_DescrFromScalar);
|
||||
DECL_NPY_API(PyArray_FromAny);
|
||||
DECL_NPY_API(PyArray_NewCopy);
|
||||
DECL_NPY_API(PyArray_NewFromDescr);
|
||||
@ -658,6 +672,9 @@ struct npy_format_descriptor<T, enable_if_t<is_pod_struct<T>::value>> {
|
||||
}
|
||||
|
||||
static void register_dtype(std::initializer_list<field_descriptor> fields) {
|
||||
if (dtype_ptr)
|
||||
pybind11_fail("NumPy: dtype is already registered");
|
||||
|
||||
list names, formats, offsets;
|
||||
for (auto field : fields) {
|
||||
if (!field.descr)
|
||||
@ -700,11 +717,30 @@ struct npy_format_descriptor<T, enable_if_t<is_pod_struct<T>::value>> {
|
||||
auto arr = array(buffer_info(nullptr, sizeof(T), format(), 1));
|
||||
if (!api.PyArray_EquivTypes_(dtype_ptr, arr.dtype().ptr()))
|
||||
pybind11_fail("NumPy: invalid buffer descriptor!");
|
||||
|
||||
register_direct_converter();
|
||||
}
|
||||
|
||||
private:
|
||||
static std::string format_str;
|
||||
static PyObject* dtype_ptr;
|
||||
|
||||
static bool direct_converter(PyObject *obj, void*& value) {
|
||||
auto& api = npy_api::get();
|
||||
if (!PyObject_TypeCheck(obj, api.PyVoidArrType_Type_))
|
||||
return false;
|
||||
if (auto descr = object(api.PyArray_DescrFromScalar_(obj), false)) {
|
||||
if (api.PyArray_EquivTypes_(dtype_ptr, descr.ptr())) {
|
||||
value = ((PyVoidScalarObject_Proxy *) obj)->obval;
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
static void register_direct_converter() {
|
||||
get_internals().direct_conversions[std::type_index(typeid(T))].push_back(direct_converter);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
|
@ -180,8 +180,6 @@ protected:
|
||||
a.descr = strdup(a.value.attr("__repr__")().cast<std::string>().c_str());
|
||||
}
|
||||
|
||||
auto const ®istered_types = detail::get_internals().registered_types_cpp;
|
||||
|
||||
/* Generate a proper function signature */
|
||||
std::string signature;
|
||||
size_t type_depth = 0, char_index = 0, type_index = 0, arg_index = 0;
|
||||
@ -216,9 +214,8 @@ protected:
|
||||
const std::type_info *t = types[type_index++];
|
||||
if (!t)
|
||||
pybind11_fail("Internal error while parsing type signature (1)");
|
||||
auto it = registered_types.find(std::type_index(*t));
|
||||
if (it != registered_types.end()) {
|
||||
signature += ((const detail::type_info *) it->second)->type->tp_name;
|
||||
if (auto tinfo = detail::get_type_info(*t)) {
|
||||
signature += tinfo->type->tp_name;
|
||||
} else {
|
||||
std::string tname(t->name());
|
||||
detail::clean_type_id(tname);
|
||||
@ -610,8 +607,7 @@ protected:
|
||||
auto &internals = get_internals();
|
||||
auto tindex = std::type_index(*(rec->type));
|
||||
|
||||
if (internals.registered_types_cpp.find(tindex) !=
|
||||
internals.registered_types_cpp.end())
|
||||
if (get_type_info(*(rec->type)))
|
||||
pybind11_fail("generic_type: type \"" + std::string(rec->name) +
|
||||
"\" is already registered!");
|
||||
|
||||
@ -672,6 +668,7 @@ protected:
|
||||
tinfo->type = (PyTypeObject *) type;
|
||||
tinfo->type_size = rec->type_size;
|
||||
tinfo->init_holder = rec->init_holder;
|
||||
tinfo->direct_conversions = &internals.direct_conversions[tindex];
|
||||
internals.registered_types_cpp[tindex] = tinfo;
|
||||
internals.registered_types_py[type] = tinfo;
|
||||
|
||||
@ -1333,11 +1330,11 @@ template <typename InputType, typename OutputType> void implicitly_convertible()
|
||||
PyErr_Clear();
|
||||
return result;
|
||||
};
|
||||
auto ®istered_types = detail::get_internals().registered_types_cpp;
|
||||
auto it = registered_types.find(std::type_index(typeid(OutputType)));
|
||||
if (it == registered_types.end())
|
||||
|
||||
if (auto tinfo = detail::get_type_info(typeid(OutputType)))
|
||||
tinfo->implicit_conversions.push_back(implicit_caster);
|
||||
else
|
||||
pybind11_fail("implicitly_convertible: Unable to find type " + type_id<OutputType>());
|
||||
((detail::type_info *) it->second)->implicit_conversions.push_back(implicit_caster);
|
||||
}
|
||||
|
||||
template <typename ExceptionTranslator>
|
||||
@ -1589,11 +1586,8 @@ inline function get_type_overload(const void *this_ptr, const detail::type_info
|
||||
}
|
||||
|
||||
template <class T> function get_overload(const T *this_ptr, const char *name) {
|
||||
auto &cpp_types = detail::get_internals().registered_types_cpp;
|
||||
auto it = cpp_types.find(typeid(T));
|
||||
if (it == cpp_types.end())
|
||||
return function();
|
||||
return get_type_overload(this_ptr, (const detail::type_info *) it->second, name);
|
||||
auto tinfo = detail::get_type_info(typeid(T));
|
||||
return tinfo ? get_type_overload(this_ptr, tinfo, name) : function();
|
||||
}
|
||||
|
||||
#define PYBIND11_OVERLOAD_INT(ret_type, cname, name, ...) { \
|
||||
|
@ -298,6 +298,9 @@ test_initializer numpy_dtypes([](py::module &m) {
|
||||
return;
|
||||
}
|
||||
|
||||
// typeinfo may be registered before the dtype descriptor for scalar casts to work...
|
||||
py::class_<SimpleStruct>(m, "SimpleStruct");
|
||||
|
||||
PYBIND11_NUMPY_DTYPE(SimpleStruct, x, y, z);
|
||||
PYBIND11_NUMPY_DTYPE(PackedStruct, x, y, z);
|
||||
PYBIND11_NUMPY_DTYPE(NestedStruct, a, b);
|
||||
@ -306,6 +309,9 @@ test_initializer numpy_dtypes([](py::module &m) {
|
||||
PYBIND11_NUMPY_DTYPE(StringStruct, a, b);
|
||||
PYBIND11_NUMPY_DTYPE(EnumStruct, e1, e2);
|
||||
|
||||
// ... or after
|
||||
py::class_<PackedStruct>(m, "PackedStruct");
|
||||
|
||||
m.def("create_rec_simple", &create_recarray<SimpleStruct>);
|
||||
m.def("create_rec_packed", &create_recarray<PackedStruct>);
|
||||
m.def("create_rec_nested", &create_nested);
|
||||
@ -324,6 +330,10 @@ test_initializer numpy_dtypes([](py::module &m) {
|
||||
m.def("test_array_ctors", &test_array_ctors);
|
||||
m.def("test_dtype_ctors", &test_dtype_ctors);
|
||||
m.def("test_dtype_methods", &test_dtype_methods);
|
||||
m.def("f_simple", [](SimpleStruct s) { return s.y * 10; });
|
||||
m.def("f_packed", [](PackedStruct s) { return s.y * 10; });
|
||||
m.def("f_nested", [](NestedStruct s) { return s.a.y * 10; });
|
||||
m.def("register_dtype", []() { PYBIND11_NUMPY_DTYPE(SimpleStruct, x, y, z); });
|
||||
});
|
||||
|
||||
#undef PYBIND11_PACKED
|
||||
|
@ -174,3 +174,34 @@ def test_signature(doc):
|
||||
from pybind11_tests import create_rec_nested
|
||||
|
||||
assert doc(create_rec_nested) == "create_rec_nested(arg0: int) -> numpy.ndarray[NestedStruct]"
|
||||
|
||||
|
||||
@pytest.requires_numpy
|
||||
def test_scalar_conversion():
|
||||
from pybind11_tests import (create_rec_simple, f_simple,
|
||||
create_rec_packed, f_packed,
|
||||
create_rec_nested, f_nested,
|
||||
create_enum_array)
|
||||
|
||||
n = 3
|
||||
arrays = [create_rec_simple(n), create_rec_packed(n),
|
||||
create_rec_nested(n), create_enum_array(n)]
|
||||
funcs = [f_simple, f_packed, f_nested]
|
||||
|
||||
for i, func in enumerate(funcs):
|
||||
for j, arr in enumerate(arrays):
|
||||
if i == j and i < 2:
|
||||
assert [func(arr[k]) for k in range(n)] == [k * 10 for k in range(n)]
|
||||
else:
|
||||
with pytest.raises(TypeError) as excinfo:
|
||||
func(arr[0])
|
||||
assert 'incompatible function arguments' in str(excinfo.value)
|
||||
|
||||
|
||||
@pytest.requires_numpy
|
||||
def test_register_dtype():
|
||||
from pybind11_tests import register_dtype
|
||||
|
||||
with pytest.raises(RuntimeError) as excinfo:
|
||||
register_dtype()
|
||||
assert 'dtype is already registered' in str(excinfo.value)
|
||||
|
Loading…
Reference in New Issue
Block a user