Require existing typeinfo for direct conversions

This avoid a hashmap lookup since the pointer to the list of
direct converters is now cached in the typeinfo.
This commit is contained in:
Ivan Smirnov 2016-10-23 15:27:13 +01:00
parent 7edd72db24
commit a6e6a8b108
5 changed files with 25 additions and 35 deletions

View File

@ -26,6 +26,7 @@ struct type_info {
void (*init_holder)(PyObject *, const void *);
std::vector<PyObject *(*)(PyObject *, PyTypeObject *)> implicit_conversions;
std::vector<std::pair<const std::type_info *, void *(*)(void *)>> implicit_casts;
std::vector<bool (*)(PyObject *, void *&)> *direct_conversions;
buffer_info *(*get_buffer)(PyObject *, void *) = nullptr;
void *get_buffer_data = nullptr;
/** A simple type never occurs as a (direct or indirect) parent
@ -157,8 +158,7 @@ inline void keep_alive_impl(handle nurse, handle patient);
class type_caster_generic {
public:
PYBIND11_NOINLINE type_caster_generic(const std::type_info &type_info)
: typeinfo(get_type_info(type_info, false)),
direct_conversions(get_internals().direct_conversions[std::type_index(type_info)]) { }
: typeinfo(get_type_info(type_info, false)) { }
PYBIND11_NOINLINE bool load(handle src, bool convert) {
if (!src)
@ -167,14 +167,12 @@ public:
}
bool load(handle src, bool convert, PyTypeObject *tobj) {
if (!src)
if (!src || !typeinfo)
return false;
if (src.is_none()) {
value = nullptr;
return true;
}
if (!typeinfo)
return load_direct(src, convert);
if (typeinfo->simple_type) { /* Case 1: no multiple inheritance etc. involved */
/* Check if we can safely perform a reinterpret-style cast */
@ -218,9 +216,12 @@ public:
if (load(temp, false))
return true;
}
for (auto &converter : *typeinfo->direct_conversions) {
if (converter(src.ptr(), value))
return true;
}
}
return load_direct(src, convert);
return false;
}
PYBIND11_NOINLINE static handle cast(const void *_src, return_value_policy policy, handle parent,
@ -298,19 +299,8 @@ public:
protected:
const type_info *typeinfo = nullptr;
const std::vector<bool (*)(PyObject *, void *&)>& direct_conversions;
void *value = nullptr;
object temp;
bool load_direct(handle src, bool convert) {
if (convert) {
for (auto& converter : direct_conversions) {
if (converter(src.ptr(), value))
return true;
}
}
return false;
}
};
/* Determine suitable casting operator */

View File

@ -721,20 +721,21 @@ private:
static std::string format_str;
static PyObject* dtype_ptr;
static void register_direct_converter() {
auto converter = [=](PyObject *obj, void*& value) {
auto& api = npy_api::get();
if (!PyObject_TypeCheck(obj, api.PyVoidArrType_Type_))
return false;
if (auto descr = object(api.PyArray_DescrFromScalar_(obj), false)) {
if (api.PyArray_EquivTypes_(dtype_ptr, descr.ptr())) {
value = ((PyVoidScalarObject_Proxy *) obj)->obval;
return true;
}
}
static bool direct_converter(PyObject *obj, void*& value) {
auto& api = npy_api::get();
if (!PyObject_TypeCheck(obj, api.PyVoidArrType_Type_))
return false;
};
get_internals().direct_conversions[std::type_index(typeid(T))].push_back(converter);
if (auto descr = object(api.PyArray_DescrFromScalar_(obj), false)) {
if (api.PyArray_EquivTypes_(dtype_ptr, descr.ptr())) {
value = ((PyVoidScalarObject_Proxy *) obj)->obval;
return true;
}
}
return false;
}
static void register_direct_converter() {
get_internals().direct_conversions[std::type_index(typeid(T))].push_back(direct_converter);
}
};

View File

@ -672,6 +672,7 @@ protected:
tinfo->type = (PyTypeObject *) type;
tinfo->type_size = rec->type_size;
tinfo->init_holder = rec->init_holder;
tinfo->direct_conversions = &internals.direct_conversions[tindex];
internals.registered_types_cpp[tindex] = tinfo;
internals.registered_types_py[type] = tinfo;

View File

@ -309,11 +309,9 @@ test_initializer numpy_dtypes([](py::module &m) {
PYBIND11_NUMPY_DTYPE(StringStruct, a, b);
PYBIND11_NUMPY_DTYPE(EnumStruct, e1, e2);
// ... or after...
// ... or after
py::class_<PackedStruct>(m, "PackedStruct");
// ... or not at all
m.def("create_rec_simple", &create_recarray<SimpleStruct>);
m.def("create_rec_packed", &create_recarray<PackedStruct>);
m.def("create_rec_nested", &create_nested);

View File

@ -190,7 +190,7 @@ def test_scalar_conversion():
for i, func in enumerate(funcs):
for j, arr in enumerate(arrays):
if i == j:
if i == j and i < 2:
assert [func(arr[k]) for k in range(n)] == [k * 10 for k in range(n)]
else:
with pytest.raises(TypeError) as excinfo: