Prevent overwriting previous declarations

Currently pybind11 doesn't check when you define a new object (e.g. a
class, function, or exception) that overwrites an existing one.  If the
thing being overwritten is a class, this leads to a segfault (because
pybind still thinks the type is defined, even though Python no longer
has the type).  In other cases this is harmless (e.g. replacing a
function with an exception), but even in that case it's most likely a
bug.

This code doesn't prevent you from actively doing something harmful,
like deliberately overwriting a previous definition, but detects
overwriting with a run-time error if it occurs in the standard
class/function/exception/def registration interfaces.

All of the additions are in non-template code; the result is actually a
tiny decrease in .so size compared to master without the new test code
(977304 to 977272 bytes), and about 4K higher with the new tests.
This commit is contained in:
Jason Rhinelander 2016-10-24 21:58:22 -04:00
parent dd9bd7778f
commit 6873c202b3
3 changed files with 94 additions and 11 deletions

View File

@ -253,13 +253,19 @@ protected:
#endif #endif
detail::function_record *chain = nullptr, *chain_start = rec; detail::function_record *chain = nullptr, *chain_start = rec;
if (rec->sibling && PyCFunction_Check(rec->sibling.ptr())) { if (rec->sibling) {
capsule rec_capsule(PyCFunction_GetSelf(rec->sibling.ptr()), true); if (PyCFunction_Check(rec->sibling.ptr())) {
chain = (detail::function_record *) rec_capsule; capsule rec_capsule(PyCFunction_GetSelf(rec->sibling.ptr()), true);
/* Never append a method to an overload chain of a parent class; chain = (detail::function_record *) rec_capsule;
instead, hide the parent's overloads in this case */ /* Never append a method to an overload chain of a parent class;
if (chain->class_ != rec->class_) instead, hide the parent's overloads in this case */
chain = nullptr; if (chain->class_ != rec->class_)
chain = nullptr;
}
// Don't trigger for things like the default __init__, which are wrapper_descriptors that we are intentionally replacing
else if (!rec->sibling.is_none() && rec->name[0] != '_')
pybind11_fail("Cannot overload existing non-function object \"" + std::string(rec->name) +
"\" with a function of the same name");
} }
if (!chain) { if (!chain) {
@ -546,8 +552,9 @@ public:
module &def(const char *name_, Func &&f, const Extra& ... extra) { module &def(const char *name_, Func &&f, const Extra& ... extra) {
cpp_function func(std::forward<Func>(f), name(name_), scope(*this), cpp_function func(std::forward<Func>(f), name(name_), scope(*this),
sibling(getattr(*this, name_, none())), extra...); sibling(getattr(*this, name_, none())), extra...);
/* PyModule_AddObject steals a reference to 'func' */ // NB: allow overwriting here because cpp_function sets up a chain with the intention of
PyModule_AddObject(ptr(), name_, func.inc_ref().ptr()); // overwriting (and has already checked internally that it isn't overwriting non-functions).
add_object(name_, func, true /* overwrite */);
return *this; return *this;
} }
@ -567,6 +574,20 @@ public:
throw import_error("Module \"" + std::string(name) + "\" not found!"); throw import_error("Module \"" + std::string(name) + "\" not found!");
return module(obj, false); return module(obj, false);
} }
// Adds an object to the module using the given name. Throws if an object with the given name
// already exists.
//
// overwrite should almost always be false: attempting to overwrite objects that pybind11 has
// established will, in most cases, break things.
PYBIND11_NOINLINE void add_object(const char *name, object &obj, bool overwrite = false) {
if (!overwrite && hasattr(*this, name))
pybind11_fail("Error during initialization: multiple incompatible definitions with name \"" +
std::string(name) + "\"");
obj.inc_ref(); // PyModule_AddObject() steals a reference
PyModule_AddObject(ptr(), name, obj.ptr());
}
}; };
NAMESPACE_BEGIN(detail) NAMESPACE_BEGIN(detail)
@ -614,6 +635,10 @@ protected:
object name(PYBIND11_FROM_STRING(rec->name), false); object name(PYBIND11_FROM_STRING(rec->name), false);
object scope_module; object scope_module;
if (rec->scope) { if (rec->scope) {
if (hasattr(rec->scope, rec->name))
pybind11_fail("generic_type: cannot initialize type \"" + std::string(rec->name) +
"\": an object with that name is already defined");
if (hasattr(rec->scope, "__module__")) { if (hasattr(rec->scope, "__module__")) {
scope_module = rec->scope.attr("__module__"); scope_module = rec->scope.attr("__module__");
} else if (hasattr(rec->scope, "__name__")) { } else if (hasattr(rec->scope, "__name__")) {
@ -1357,8 +1382,7 @@ public:
+ std::string(".") + name; + std::string(".") + name;
char* exception_name = const_cast<char*>(full_name.c_str()); char* exception_name = const_cast<char*>(full_name.c_str());
m_ptr = PyErr_NewException(exception_name, base, NULL); m_ptr = PyErr_NewException(exception_name, base, NULL);
inc_ref(); // PyModule_AddObject() steals a reference m.add_object(name.c_str(), *this);
PyModule_AddObject(m.ptr(), name.c_str(), m_ptr);
} }
// Sets the current python exception to this exception object with the given message // Sets the current python exception to this exception object with the given message

View File

@ -36,6 +36,18 @@ OpTest2 operator+(const OpTest2 &, const OpTest1 &) {
return OpTest2(); return OpTest2();
} }
// #461
class Dupe1 {
public:
Dupe1(int v) : v_{v} {}
int get_value() const { return v_; }
private:
int v_;
};
class Dupe2 {};
class Dupe3 {};
class DupeException : public std::runtime_error {};
void init_issues(py::module &m) { void init_issues(py::module &m) {
py::module m2 = m.def_submodule("issues"); py::module m2 = m.def_submodule("issues");
@ -237,7 +249,49 @@ void init_issues(py::module &m) {
static std::vector<int> list = { 1, 2, 3 }; static std::vector<int> list = { 1, 2, 3 };
m2.def("make_iterator_1", []() { return py::make_iterator<py::return_value_policy::copy>(list); }); m2.def("make_iterator_1", []() { return py::make_iterator<py::return_value_policy::copy>(list); });
m2.def("make_iterator_2", []() { return py::make_iterator<py::return_value_policy::automatic>(list); }); m2.def("make_iterator_2", []() { return py::make_iterator<py::return_value_policy::automatic>(list); });
static std::vector<std::string> nothrows;
// Issue 461: registering two things with the same name:
py::class_<Dupe1>(m2, "Dupe1")
.def("get_value", &Dupe1::get_value)
;
m2.def("dupe1_factory", [](int v) { return new Dupe1(v); });
py::class_<Dupe2>(m2, "Dupe2");
py::exception<DupeException>(m2, "DupeException");
try {
m2.def("Dupe1", [](int v) { return new Dupe1(v); });
nothrows.emplace_back("Dupe1");
}
catch (std::runtime_error &) {}
try {
py::class_<Dupe3>(m2, "dupe1_factory");
nothrows.emplace_back("dupe1_factory");
}
catch (std::runtime_error &) {}
try {
py::exception<Dupe3>(m2, "Dupe2");
nothrows.emplace_back("Dupe2");
}
catch (std::runtime_error &) {}
try {
m2.def("DupeException", []() { return 30; });
nothrows.emplace_back("DupeException1");
}
catch (std::runtime_error &) {}
try {
py::class_<DupeException>(m2, "DupeException");
nothrows.emplace_back("DupeException2");
}
catch (std::runtime_error &) {}
m2.def("dupe_exception_failures", []() {
py::list l;
for (auto &e : nothrows) l.append(py::cast(e));
return l;
});
} }
// MSVC workaround: trying to use a lambda here crashes MSCV // MSVC workaround: trying to use a lambda here crashes MSCV
test_initializer issues(&init_issues); test_initializer issues(&init_issues);

View File

@ -182,3 +182,8 @@ def test_iterator_rvpolicy():
assert list(make_iterator_1()) == [1, 2, 3] assert list(make_iterator_1()) == [1, 2, 3]
assert list(make_iterator_2()) == [1, 2, 3] assert list(make_iterator_2()) == [1, 2, 3]
assert(type(make_iterator_1()) != type(make_iterator_2())) assert(type(make_iterator_1()) != type(make_iterator_2()))
def test_dupe_assignment():
""" Issue 461: overwriting a class with a function """
from pybind11_tests.issues import dupe_exception_failures
assert dupe_exception_failures() == []