From 6873c202b3fd2e2264b2566d60cf079ee9322264 Mon Sep 17 00:00:00 2001 From: Jason Rhinelander Date: Mon, 24 Oct 2016 21:58:22 -0400 Subject: [PATCH] Prevent overwriting previous declarations Currently pybind11 doesn't check when you define a new object (e.g. a class, function, or exception) that overwrites an existing one. If the thing being overwritten is a class, this leads to a segfault (because pybind still thinks the type is defined, even though Python no longer has the type). In other cases this is harmless (e.g. replacing a function with an exception), but even in that case it's most likely a bug. This code doesn't prevent you from actively doing something harmful, like deliberately overwriting a previous definition, but detects overwriting with a run-time error if it occurs in the standard class/function/exception/def registration interfaces. All of the additions are in non-template code; the result is actually a tiny decrease in .so size compared to master without the new test code (977304 to 977272 bytes), and about 4K higher with the new tests. --- include/pybind11/pybind11.h | 46 +++++++++++++++++++++++-------- tests/test_issues.cpp | 54 +++++++++++++++++++++++++++++++++++++ tests/test_issues.py | 5 ++++ 3 files changed, 94 insertions(+), 11 deletions(-) diff --git a/include/pybind11/pybind11.h b/include/pybind11/pybind11.h index 114ae971a..989219611 100644 --- a/include/pybind11/pybind11.h +++ b/include/pybind11/pybind11.h @@ -253,13 +253,19 @@ protected: #endif detail::function_record *chain = nullptr, *chain_start = rec; - if (rec->sibling && PyCFunction_Check(rec->sibling.ptr())) { - capsule rec_capsule(PyCFunction_GetSelf(rec->sibling.ptr()), true); - chain = (detail::function_record *) rec_capsule; - /* Never append a method to an overload chain of a parent class; - instead, hide the parent's overloads in this case */ - if (chain->class_ != rec->class_) - chain = nullptr; + if (rec->sibling) { + if (PyCFunction_Check(rec->sibling.ptr())) { + capsule rec_capsule(PyCFunction_GetSelf(rec->sibling.ptr()), true); + chain = (detail::function_record *) rec_capsule; + /* Never append a method to an overload chain of a parent class; + instead, hide the parent's overloads in this case */ + if (chain->class_ != rec->class_) + chain = nullptr; + } + // Don't trigger for things like the default __init__, which are wrapper_descriptors that we are intentionally replacing + else if (!rec->sibling.is_none() && rec->name[0] != '_') + pybind11_fail("Cannot overload existing non-function object \"" + std::string(rec->name) + + "\" with a function of the same name"); } if (!chain) { @@ -546,8 +552,9 @@ public: module &def(const char *name_, Func &&f, const Extra& ... extra) { cpp_function func(std::forward(f), name(name_), scope(*this), sibling(getattr(*this, name_, none())), extra...); - /* PyModule_AddObject steals a reference to 'func' */ - PyModule_AddObject(ptr(), name_, func.inc_ref().ptr()); + // NB: allow overwriting here because cpp_function sets up a chain with the intention of + // overwriting (and has already checked internally that it isn't overwriting non-functions). + add_object(name_, func, true /* overwrite */); return *this; } @@ -567,6 +574,20 @@ public: throw import_error("Module \"" + std::string(name) + "\" not found!"); return module(obj, false); } + + // Adds an object to the module using the given name. Throws if an object with the given name + // already exists. + // + // overwrite should almost always be false: attempting to overwrite objects that pybind11 has + // established will, in most cases, break things. + PYBIND11_NOINLINE void add_object(const char *name, object &obj, bool overwrite = false) { + if (!overwrite && hasattr(*this, name)) + pybind11_fail("Error during initialization: multiple incompatible definitions with name \"" + + std::string(name) + "\""); + + obj.inc_ref(); // PyModule_AddObject() steals a reference + PyModule_AddObject(ptr(), name, obj.ptr()); + } }; NAMESPACE_BEGIN(detail) @@ -614,6 +635,10 @@ protected: object name(PYBIND11_FROM_STRING(rec->name), false); object scope_module; if (rec->scope) { + if (hasattr(rec->scope, rec->name)) + pybind11_fail("generic_type: cannot initialize type \"" + std::string(rec->name) + + "\": an object with that name is already defined"); + if (hasattr(rec->scope, "__module__")) { scope_module = rec->scope.attr("__module__"); } else if (hasattr(rec->scope, "__name__")) { @@ -1357,8 +1382,7 @@ public: + std::string(".") + name; char* exception_name = const_cast(full_name.c_str()); m_ptr = PyErr_NewException(exception_name, base, NULL); - inc_ref(); // PyModule_AddObject() steals a reference - PyModule_AddObject(m.ptr(), name.c_str(), m_ptr); + m.add_object(name.c_str(), *this); } // Sets the current python exception to this exception object with the given message diff --git a/tests/test_issues.cpp b/tests/test_issues.cpp index 29c4057f1..f5467cb14 100644 --- a/tests/test_issues.cpp +++ b/tests/test_issues.cpp @@ -36,6 +36,18 @@ OpTest2 operator+(const OpTest2 &, const OpTest1 &) { return OpTest2(); } +// #461 +class Dupe1 { +public: + Dupe1(int v) : v_{v} {} + int get_value() const { return v_; } +private: + int v_; +}; +class Dupe2 {}; +class Dupe3 {}; +class DupeException : public std::runtime_error {}; + void init_issues(py::module &m) { py::module m2 = m.def_submodule("issues"); @@ -237,7 +249,49 @@ void init_issues(py::module &m) { static std::vector list = { 1, 2, 3 }; m2.def("make_iterator_1", []() { return py::make_iterator(list); }); m2.def("make_iterator_2", []() { return py::make_iterator(list); }); + + static std::vector nothrows; + // Issue 461: registering two things with the same name: + py::class_(m2, "Dupe1") + .def("get_value", &Dupe1::get_value) + ; + m2.def("dupe1_factory", [](int v) { return new Dupe1(v); }); + + py::class_(m2, "Dupe2"); + py::exception(m2, "DupeException"); + + try { + m2.def("Dupe1", [](int v) { return new Dupe1(v); }); + nothrows.emplace_back("Dupe1"); + } + catch (std::runtime_error &) {} + try { + py::class_(m2, "dupe1_factory"); + nothrows.emplace_back("dupe1_factory"); + } + catch (std::runtime_error &) {} + try { + py::exception(m2, "Dupe2"); + nothrows.emplace_back("Dupe2"); + } + catch (std::runtime_error &) {} + try { + m2.def("DupeException", []() { return 30; }); + nothrows.emplace_back("DupeException1"); + } + catch (std::runtime_error &) {} + try { + py::class_(m2, "DupeException"); + nothrows.emplace_back("DupeException2"); + } + catch (std::runtime_error &) {} + m2.def("dupe_exception_failures", []() { + py::list l; + for (auto &e : nothrows) l.append(py::cast(e)); + return l; + }); } + // MSVC workaround: trying to use a lambda here crashes MSCV test_initializer issues(&init_issues); diff --git a/tests/test_issues.py b/tests/test_issues.py index e2ab0b45c..cf645ef9f 100644 --- a/tests/test_issues.py +++ b/tests/test_issues.py @@ -182,3 +182,8 @@ def test_iterator_rvpolicy(): assert list(make_iterator_1()) == [1, 2, 3] assert list(make_iterator_2()) == [1, 2, 3] assert(type(make_iterator_1()) != type(make_iterator_2())) + +def test_dupe_assignment(): + """ Issue 461: overwriting a class with a function """ + from pybind11_tests.issues import dupe_exception_failures + assert dupe_exception_failures() == []