From 1b05ce5bc062d15940b10451fcbdc4bf436dc584 Mon Sep 17 00:00:00 2001 From: Jason Rhinelander Date: Tue, 9 Aug 2016 17:57:59 -0400 Subject: [PATCH] Track registered instances that share a pointer address The pointer to the first member of a class instance is the same as the pointer to instance itself; pybind11 has some workarounds for this to not track registered instances that have a registered parent with the same address. This doesn't work everywhere, however: issue #328 is a failure of this for a mutator operator which resolves its argument to the parent rather than the child, as is needed in #328. This commit resolves the issue (and restores tracking of same-address instances) by changing registered_instances from an unordered_map to an unordered_multimap that allows duplicate instances for the same pointer to be recorded, then resolves these differences by checking the type of each matched instance when looking up an instance. (A unordered_multimap seems cleaner for this than a unordered_map or similar because, the vast majority of the time, the instance will be unique). --- example/issues.cpp | 13 +++++++++++++ example/issues.py | 17 +++++++++++++++++ example/issues.ref | 18 ++++++++++++++++++ include/pybind11/cast.h | 36 ++++++++++++++++++++---------------- include/pybind11/common.h | 6 +++--- include/pybind11/pybind11.h | 34 ++++++++++++++++++++++++---------- 6 files changed, 95 insertions(+), 29 deletions(-) diff --git a/example/issues.cpp b/example/issues.cpp index 55fc3f3cb..66934ee06 100644 --- a/example/issues.cpp +++ b/example/issues.cpp @@ -9,6 +9,7 @@ #include "example.h" #include +#include PYBIND11_DECLARE_HOLDER_TYPE(T, std::shared_ptr); @@ -157,4 +158,16 @@ void init_issues(py::module &m) { }) ; + // Issue #328: first member in a class can't be used in operators +#define TRACKERS(CLASS) CLASS() { std::cout << #CLASS "@" << this << " constructor\n"; } \ + ~CLASS() { std::cout << #CLASS "@" << this << " destructor\n"; } + struct NestA { int value = 3; NestA& operator+=(int i) { value += i; return *this; } TRACKERS(NestA) }; + struct NestB { NestA a; int value = 4; NestB& operator-=(int i) { value -= i; return *this; } TRACKERS(NestB) }; + struct NestC { NestB b; int value = 5; NestC& operator*=(int i) { value *= i; return *this; } TRACKERS(NestC) }; + py::class_(m2, "NestA").def(py::init<>()).def(py::self += int()); + py::class_(m2, "NestB").def(py::init<>()).def(py::self -= int()).def_readwrite("a", &NestB::a); + py::class_(m2, "NestC").def(py::init<>()).def(py::self *= int()).def_readwrite("b", &NestC::b); + m2.def("print_NestA", [](const NestA &a) { std::cout << a.value << std::endl; }); + m2.def("print_NestB", [](const NestB &b) { std::cout << b.value << std::endl; }); + m2.def("print_NestC", [](const NestC &c) { std::cout << c.value << std::endl; }); } diff --git a/example/issues.py b/example/issues.py index 716a1b2fd..12b46b7f5 100644 --- a/example/issues.py +++ b/example/issues.py @@ -11,6 +11,7 @@ from example.issues import ElementList, ElementA, print_element from example.issues import expect_float, expect_int from example.issues import A, call_f from example.issues import StrIssue +from example.issues import NestA, NestB, NestC, print_NestA, print_NestB, print_NestC import gc print_cchar("const char *") @@ -78,3 +79,19 @@ try: print(StrIssue("no", "such", "constructor")) except TypeError as e: print("Failed as expected: " + str(e)) + +a = NestA() +b = NestB() +c = NestC() +a += 10 +b.a += 100 +c.b.a += 1000 +b -= 1 +c.b -= 3 +c *= 7 +print_NestA(a) +print_NestA(b.a) +print_NestA(c.b.a) +print_NestB(b) +print_NestB(c.b) +print_NestC(c) diff --git a/example/issues.ref b/example/issues.ref index acb1ed08e..f25ab298b 100644 --- a/example/issues.ref +++ b/example/issues.ref @@ -24,3 +24,21 @@ Failed as expected: Incompatible constructor arguments. The following argument t 1. example.issues.StrIssue(arg0: int) 2. example.issues.StrIssue() Invoked with: no, such, constructor +NestA@0x1152940 constructor +NestA@0x11f9350 constructor +NestB@0x11f9350 constructor +NestA@0x112d0d0 constructor +NestB@0x112d0d0 constructor +NestC@0x112d0d0 constructor +13 +103 +1003 +3 +1 +35 +NestC@0x112d0d0 destructor +NestB@0x112d0d0 destructor +NestA@0x112d0d0 destructor +NestB@0x11f9350 destructor +NestA@0x11f9350 destructor +NestA@0x1152940 destructor diff --git a/include/pybind11/cast.h b/include/pybind11/cast.h index 34dbb28e1..ad8f2759d 100644 --- a/include/pybind11/cast.h +++ b/include/pybind11/cast.h @@ -119,12 +119,15 @@ PYBIND11_NOINLINE inline std::string error_string() { return errorString; } -PYBIND11_NOINLINE inline handle get_object_handle(const void *ptr) { - auto instances = get_internals().registered_instances; - auto it = instances.find(ptr); - if (it == instances.end()) - return handle(); - return handle((PyObject *) it->second); +PYBIND11_NOINLINE inline handle get_object_handle(const void *ptr, const detail::type_info *type ) { + auto &instances = get_internals().registered_instances; + auto range = instances.equal_range(ptr); + for (auto it = range.first; it != range.second; ++it) { + auto instance_type = detail::get_type_info(Py_TYPE(it->second), false); + if (instance_type && instance_type == type) + return handle((PyObject *) it->second); + } + return handle(); } inline PyThreadState *get_thread_state_unchecked() { @@ -174,14 +177,7 @@ public: if (src == nullptr) return handle(Py_None).inc_ref(); - // avoid an issue with internal references matching their parent's address - bool dont_cache = policy == return_value_policy::reference_internal && - parent && ((instance *) parent.ptr())->value == (void *) src; - - auto& internals = get_internals(); - auto it_instance = internals.registered_instances.find(src); - if (it_instance != internals.registered_instances.end() && !dont_cache) - return handle((PyObject *) it_instance->second).inc_ref(); + auto &internals = get_internals(); auto it = internals.registered_types_cpp.find(std::type_index(*type_info)); if (it == internals.registered_types_cpp.end()) { @@ -198,6 +194,14 @@ public: } auto tinfo = (const detail::type_info *) it->second; + + auto it_instances = internals.registered_instances.equal_range(src); + for (auto it = it_instances.first; it != it_instances.second; ++it) { + auto instance_type = detail::get_type_info(Py_TYPE(it->second), false); + if (instance_type && instance_type == tinfo) + return handle((PyObject *) it->second).inc_ref(); + } + object inst(PyType_GenericAlloc(tinfo->type, 0), false); auto wrapper = (instance *) inst.ptr(); @@ -229,8 +233,8 @@ public: } tinfo->init_holder(inst.ptr(), existing_holder); - if (!dont_cache) - internals.registered_instances[wrapper->value] = inst.ptr(); + + internals.registered_instances.emplace(wrapper->value, inst.ptr()); return inst.release(); } diff --git a/include/pybind11/common.h b/include/pybind11/common.h index a278a9f0e..c0d248250 100644 --- a/include/pybind11/common.h +++ b/include/pybind11/common.h @@ -266,9 +266,9 @@ struct overload_hash { /// Internal data struture used to track registered instances and types struct internals { - std::unordered_map registered_types_cpp; // std::type_index -> type_info - std::unordered_map registered_types_py; // PyTypeObject* -> type_info - std::unordered_map registered_instances; // void * -> PyObject* + std::unordered_map registered_types_cpp; // std::type_index -> type_info + std::unordered_map registered_types_py; // PyTypeObject* -> type_info + std::unordered_multimap registered_instances; // void * -> PyObject* std::unordered_set, overload_hash> inactive_overload_cache; std::forward_list registered_exception_translators; #if defined(WITH_THREAD) diff --git a/include/pybind11/pybind11.h b/include/pybind11/pybind11.h index 0016f3718..632f955e0 100644 --- a/include/pybind11/pybind11.h +++ b/include/pybind11/pybind11.h @@ -731,20 +731,26 @@ protected: self->owned = true; self->parent = nullptr; self->constructed = false; - detail::get_internals().registered_instances[self->value] = (PyObject *) self; + detail::get_internals().registered_instances.emplace(self->value, (PyObject *) self); return (PyObject *) self; } static void dealloc(instance *self) { if (self->value) { - bool dont_cache = self->parent && ((instance *) self->parent)->value == self->value; - if (!dont_cache) { // avoid an issue with internal references matching their parent's address - auto ®istered_instances = detail::get_internals().registered_instances; - auto it = registered_instances.find(self->value); - if (it == registered_instances.end()) - pybind11_fail("generic_type::dealloc(): Tried to deallocate unregistered instance!"); - registered_instances.erase(it); + auto instance_type = Py_TYPE(self); + auto ®istered_instances = detail::get_internals().registered_instances; + auto range = registered_instances.equal_range(self->value); + bool found = false; + for (auto it = range.first; it != range.second; ++it) { + if (instance_type == Py_TYPE(it->second)) { + registered_instances.erase(it); + found = true; + break; + } } + if (!found) + pybind11_fail("generic_type::dealloc(): Tried to deallocate unregistered instance!"); + Py_XDECREF(self->parent); if (self->weakrefs) PyObject_ClearWeakRefs((PyObject *) self); @@ -1316,8 +1322,8 @@ class gil_scoped_acquire { }; class gil_scoped_release { }; #endif -inline function get_overload(const void *this_ptr, const char *name) { - handle py_object = detail::get_object_handle(this_ptr); +inline function get_type_overload(const void *this_ptr, const detail::type_info *this_type, const char *name) { + handle py_object = detail::get_object_handle(this_ptr, this_type); if (!py_object) return function(); handle type = py_object.get_type(); @@ -1348,6 +1354,14 @@ inline function get_overload(const void *this_ptr, const char *name) { return overload; } +template function get_overload(const T *this_ptr, const char *name) { + auto &cpp_types = detail::get_internals().registered_types_cpp; + auto it = cpp_types.find(typeid(T)); + if (it == cpp_types.end()) + return function(); + return get_type_overload(this_ptr, (const detail::type_info *) it->second, name); +} + #define PYBIND11_OVERLOAD_INT(ret_type, name, ...) { \ pybind11::gil_scoped_acquire gil; \ pybind11::function overload = pybind11::get_overload(this, name); \