From a6e6a8b108f49f8e990045e020e2058b15dfcf5b Mon Sep 17 00:00:00 2001 From: Ivan Smirnov Date: Sun, 23 Oct 2016 15:27:13 +0100 Subject: [PATCH] Require existing typeinfo for direct conversions This avoid a hashmap lookup since the pointer to the list of direct converters is now cached in the typeinfo. --- include/pybind11/cast.h | 26 ++++++++------------------ include/pybind11/numpy.h | 27 ++++++++++++++------------- include/pybind11/pybind11.h | 1 + tests/test_numpy_dtypes.cpp | 4 +--- tests/test_numpy_dtypes.py | 2 +- 5 files changed, 25 insertions(+), 35 deletions(-) diff --git a/include/pybind11/cast.h b/include/pybind11/cast.h index 1d6f605dd..1b82d44f4 100644 --- a/include/pybind11/cast.h +++ b/include/pybind11/cast.h @@ -26,6 +26,7 @@ struct type_info { void (*init_holder)(PyObject *, const void *); std::vector implicit_conversions; std::vector> implicit_casts; + std::vector *direct_conversions; buffer_info *(*get_buffer)(PyObject *, void *) = nullptr; void *get_buffer_data = nullptr; /** A simple type never occurs as a (direct or indirect) parent @@ -157,8 +158,7 @@ inline void keep_alive_impl(handle nurse, handle patient); class type_caster_generic { public: PYBIND11_NOINLINE type_caster_generic(const std::type_info &type_info) - : typeinfo(get_type_info(type_info, false)), - direct_conversions(get_internals().direct_conversions[std::type_index(type_info)]) { } + : typeinfo(get_type_info(type_info, false)) { } PYBIND11_NOINLINE bool load(handle src, bool convert) { if (!src) @@ -167,14 +167,12 @@ public: } bool load(handle src, bool convert, PyTypeObject *tobj) { - if (!src) + if (!src || !typeinfo) return false; if (src.is_none()) { value = nullptr; return true; } - if (!typeinfo) - return load_direct(src, convert); if (typeinfo->simple_type) { /* Case 1: no multiple inheritance etc. involved */ /* Check if we can safely perform a reinterpret-style cast */ @@ -218,9 +216,12 @@ public: if (load(temp, false)) return true; } + for (auto &converter : *typeinfo->direct_conversions) { + if (converter(src.ptr(), value)) + return true; + } } - - return load_direct(src, convert); + return false; } PYBIND11_NOINLINE static handle cast(const void *_src, return_value_policy policy, handle parent, @@ -298,19 +299,8 @@ public: protected: const type_info *typeinfo = nullptr; - const std::vector& direct_conversions; void *value = nullptr; object temp; - - bool load_direct(handle src, bool convert) { - if (convert) { - for (auto& converter : direct_conversions) { - if (converter(src.ptr(), value)) - return true; - } - } - return false; - } }; /* Determine suitable casting operator */ diff --git a/include/pybind11/numpy.h b/include/pybind11/numpy.h index 602d703cb..2db3de273 100644 --- a/include/pybind11/numpy.h +++ b/include/pybind11/numpy.h @@ -721,20 +721,21 @@ private: static std::string format_str; static PyObject* dtype_ptr; - static void register_direct_converter() { - auto converter = [=](PyObject *obj, void*& value) { - auto& api = npy_api::get(); - if (!PyObject_TypeCheck(obj, api.PyVoidArrType_Type_)) - return false; - if (auto descr = object(api.PyArray_DescrFromScalar_(obj), false)) { - if (api.PyArray_EquivTypes_(dtype_ptr, descr.ptr())) { - value = ((PyVoidScalarObject_Proxy *) obj)->obval; - return true; - } - } + static bool direct_converter(PyObject *obj, void*& value) { + auto& api = npy_api::get(); + if (!PyObject_TypeCheck(obj, api.PyVoidArrType_Type_)) return false; - }; - get_internals().direct_conversions[std::type_index(typeid(T))].push_back(converter); + if (auto descr = object(api.PyArray_DescrFromScalar_(obj), false)) { + if (api.PyArray_EquivTypes_(dtype_ptr, descr.ptr())) { + value = ((PyVoidScalarObject_Proxy *) obj)->obval; + return true; + } + } + return false; + } + + static void register_direct_converter() { + get_internals().direct_conversions[std::type_index(typeid(T))].push_back(direct_converter); } }; diff --git a/include/pybind11/pybind11.h b/include/pybind11/pybind11.h index 752f77564..ffa25801b 100644 --- a/include/pybind11/pybind11.h +++ b/include/pybind11/pybind11.h @@ -672,6 +672,7 @@ protected: tinfo->type = (PyTypeObject *) type; tinfo->type_size = rec->type_size; tinfo->init_holder = rec->init_holder; + tinfo->direct_conversions = &internals.direct_conversions[tindex]; internals.registered_types_cpp[tindex] = tinfo; internals.registered_types_py[type] = tinfo; diff --git a/tests/test_numpy_dtypes.cpp b/tests/test_numpy_dtypes.cpp index f85195353..40aca0c3c 100644 --- a/tests/test_numpy_dtypes.cpp +++ b/tests/test_numpy_dtypes.cpp @@ -309,11 +309,9 @@ test_initializer numpy_dtypes([](py::module &m) { PYBIND11_NUMPY_DTYPE(StringStruct, a, b); PYBIND11_NUMPY_DTYPE(EnumStruct, e1, e2); - // ... or after... + // ... or after py::class_(m, "PackedStruct"); - // ... or not at all - m.def("create_rec_simple", &create_recarray); m.def("create_rec_packed", &create_recarray); m.def("create_rec_nested", &create_nested); diff --git a/tests/test_numpy_dtypes.py b/tests/test_numpy_dtypes.py index 0503ef1a9..47d7c3bd7 100644 --- a/tests/test_numpy_dtypes.py +++ b/tests/test_numpy_dtypes.py @@ -190,7 +190,7 @@ def test_scalar_conversion(): for i, func in enumerate(funcs): for j, arr in enumerate(arrays): - if i == j: + if i == j and i < 2: assert [func(arr[k]) for k in range(n)] == [k * 10 for k in range(n)] else: with pytest.raises(TypeError) as excinfo: