diff --git a/docs/advanced/functions.rst b/docs/advanced/functions.rst index b11a2ab61..3b3eef92e 100644 --- a/docs/advanced/functions.rst +++ b/docs/advanced/functions.rst @@ -406,6 +406,53 @@ name, i.e. by specifying ``py::arg().noconvert()``. need to specify a ``py::arg()`` annotation for each argument with the no-convert argument modified to ``py::arg().noconvert()``. +Allow/Prohibiting None arguments +================================ + +When a C++ type registered with :class:`py::class_` is passed as an argument to +a function taking the instance as pointer or shared holder (e.g. ``shared_ptr`` +or a custom, copyable holder as described in :ref:`smart_pointers`), pybind +allows ``None`` to be passed from Python which results in calling the C++ +function with ``nullptr`` (or an empty holder) for the argument. + +To explicitly enable or disable this behaviour, using the +``.none`` method of the :class:`py::arg` object: + +.. code-block:: cpp + + py::class_(m, "Dog").def(py::init<>()); + py::class_(m, "Cat").def(py::init<>()); + m.def("bark", [](Dog *dog) -> std::string { + if (dog) return "woof!"; /* Called with a Dog instance */ + else return "(no dog)"; /* Called with None, d == nullptr */ + }, py::arg("dog").none(true)); + m.def("meow", [](Cat *cat) -> std::string { + // Can't be called with None argument + return "meow"; + }, py::arg("cat").none(false)); + +With the above, the Python call ``bark(None)`` will return the string ``"(no +dog)"``, while attempting to call ``meow(None)`` will throw a :exc:`TypeError`: + +.. code-block:: pycon + + >>> from animals import Dog, Cat, bark, meow + >>> bark(Dog()) + 'woof!' + >>> meow(Cat()) + 'meow' + >>> bark(None) + '(no dog)' + >>> meow(None) + Traceback (most recent call last): + File "", line 1, in + TypeError: meow(): incompatible function arguments. The following argument types are supported: + 1. (cat: animals.Cat) -> str + + Invoked with: None + +The default behaviour when the tag is unspecified is to allow ``None``. + Overload resolution order ========================= diff --git a/include/pybind11/attr.h b/include/pybind11/attr.h index 84d2835c7..9f858d034 100644 --- a/include/pybind11/attr.h +++ b/include/pybind11/attr.h @@ -123,9 +123,10 @@ struct argument_record { const char *descr; ///< Human-readable version of the argument value handle value; ///< Associated Python object bool convert : 1; ///< True if the argument is allowed to convert when loading + bool none : 1; ///< True if None is allowed when loading - argument_record(const char *name, const char *descr, handle value, bool convert) - : name(name), descr(descr), value(value), convert(convert) { } + argument_record(const char *name, const char *descr, handle value, bool convert, bool none) + : name(name), descr(descr), value(value), convert(convert), none(none) { } }; /// Internal data structure which holds metadata about a bound function (signature, overloads, etc.) @@ -338,8 +339,8 @@ template <> struct process_attribute : process_attribute_default struct process_attribute : process_attribute_default { static void init(const arg &a, function_record *r) { if (r->is_method && r->args.empty()) - r->args.emplace_back("self", nullptr, handle(), true /*convert*/); - r->args.emplace_back(a.name, nullptr, handle(), !a.flag_noconvert); + r->args.emplace_back("self", nullptr, handle(), true /*convert*/, false /*none not allowed*/); + r->args.emplace_back(a.name, nullptr, handle(), !a.flag_noconvert, a.flag_none); } }; @@ -347,7 +348,7 @@ template <> struct process_attribute : process_attribute_default { template <> struct process_attribute : process_attribute_default { static void init(const arg_v &a, function_record *r) { if (r->is_method && r->args.empty()) - r->args.emplace_back("self", nullptr /*descr*/, handle() /*parent*/, true /*convert*/); + r->args.emplace_back("self", nullptr /*descr*/, handle() /*parent*/, true /*convert*/, false /*none not allowed*/); if (!a.value) { #if !defined(NDEBUG) @@ -370,7 +371,7 @@ template <> struct process_attribute : process_attribute_default { "Compile in debug mode for more information."); #endif } - r->args.emplace_back(a.name, a.descr, a.value.inc_ref(), !a.flag_noconvert); + r->args.emplace_back(a.name, a.descr, a.value.inc_ref(), !a.flag_noconvert, a.flag_none); } }; diff --git a/include/pybind11/cast.h b/include/pybind11/cast.h index 4c8512535..c997b90c0 100644 --- a/include/pybind11/cast.h +++ b/include/pybind11/cast.h @@ -1381,14 +1381,17 @@ template arg_v operator=(T &&value) const; /// Indicate that the type should not be converted in the type caster arg &noconvert(bool flag = true) { flag_noconvert = flag; return *this; } + /// Indicates that the argument should/shouldn't allow None (e.g. for nullable pointer args) + arg &none(bool flag = true) { flag_none = flag; return *this; } const char *name; ///< If non-null, this is a named kwargs argument bool flag_noconvert : 1; ///< If set, do not allow conversion (requires a supporting type caster!) + bool flag_none : 1; ///< If set (the default), allow None to be passed to this argument }; /// \ingroup annotations @@ -1421,6 +1424,9 @@ public: /// Same as `arg::noconvert()`, but returns *this as arg_v&, not arg& arg_v &noconvert(bool flag = true) { arg::noconvert(flag); return *this; } + /// Same as `arg::nonone()`, but returns *this as arg_v&, not arg& + arg_v &none(bool flag = true) { arg::none(flag); return *this; } + /// The default value object value; /// The (optional) description of the default value diff --git a/include/pybind11/pybind11.h b/include/pybind11/pybind11.h index 6ac8edc2b..c4e6131e4 100644 --- a/include/pybind11/pybind11.h +++ b/include/pybind11/pybind11.h @@ -466,18 +466,23 @@ protected: size_t args_copied = 0; // 1. Copy any position arguments given. - bool bad_kwarg = false; + bool bad_arg = false; for (; args_copied < args_to_copy; ++args_copied) { - if (kwargs_in && args_copied < func.args.size() && func.args[args_copied].name - && PyDict_GetItemString(kwargs_in, func.args[args_copied].name)) { - bad_kwarg = true; + argument_record *arg_rec = args_copied < func.args.size() ? &func.args[args_copied] : nullptr; + if (kwargs_in && arg_rec && arg_rec->name && PyDict_GetItemString(kwargs_in, arg_rec->name)) { + bad_arg = true; break; } - call.args.push_back(PyTuple_GET_ITEM(args_in, args_copied)); - call.args_convert.push_back(args_copied < func.args.size() ? func.args[args_copied].convert : true); + handle arg(PyTuple_GET_ITEM(args_in, args_copied)); + if (arg_rec && !arg_rec->none && arg.is_none()) { + bad_arg = true; + break; + } + call.args.push_back(arg); + call.args_convert.push_back(arg_rec ? arg_rec->convert : true); } - if (bad_kwarg) + if (bad_arg) continue; // Maybe it was meant for another overload (issue #688) // We'll need to copy this if we steal some kwargs for defaults diff --git a/tests/test_methods_and_attributes.cpp b/tests/test_methods_and_attributes.cpp index dc98feeba..81b665b4d 100644 --- a/tests/test_methods_and_attributes.cpp +++ b/tests/test_methods_and_attributes.cpp @@ -162,6 +162,14 @@ public: /// Issue/PR #648: bad arg default debugging output class NotRegistered {}; +// Test None-allowed py::arg argument policy +class NoneTester { public: int answer = 42; }; +int none1(const NoneTester &obj) { return obj.answer; } +int none2(NoneTester *obj) { return obj ? obj->answer : -1; } +int none3(std::shared_ptr &obj) { return obj ? obj->answer : -1; } +int none4(std::shared_ptr *obj) { return obj && *obj ? (*obj)->answer : -1; } +int none5(std::shared_ptr obj) { return obj ? obj->answer : -1; } + test_initializer methods_and_attributes([](py::module &m) { py::class_ emna(m, "ExampleMandA"); emna.def(py::init<>()) @@ -322,4 +330,18 @@ test_initializer methods_and_attributes([](py::module &m) { auto m = py::module::import("pybind11_tests"); m.def("should_fail", [](int, NotRegistered) {}, py::arg(), py::arg() = NotRegistered()); }); + + py::class_>(m, "NoneTester") + .def(py::init<>()); + m.def("no_none1", &none1, py::arg().none(false)); + m.def("no_none2", &none2, py::arg().none(false)); + m.def("no_none3", &none3, py::arg().none(false)); + m.def("no_none4", &none4, py::arg().none(false)); + m.def("no_none5", &none5, py::arg().none(false)); + m.def("ok_none1", &none1); + m.def("ok_none2", &none2, py::arg().none(true)); + m.def("ok_none3", &none3); + m.def("ok_none4", &none4, py::arg().none(true)); + m.def("ok_none5", &none5); + }); diff --git a/tests/test_methods_and_attributes.py b/tests/test_methods_and_attributes.py index 8139decf0..3ec3eb76b 100644 --- a/tests/test_methods_and_attributes.py +++ b/tests/test_methods_and_attributes.py @@ -369,3 +369,47 @@ def test_bad_arg_default(msg): "arg(): could not convert default argument into a Python object (type not registered " "yet?). Compile in debug mode for more information." ) + + +def test_accepts_none(): + from pybind11_tests import (NoneTester, + no_none1, no_none2, no_none3, no_none4, no_none5, + ok_none1, ok_none2, ok_none3, ok_none4, ok_none5) + + a = NoneTester() + assert no_none1(a) == 42 + assert no_none2(a) == 42 + assert no_none3(a) == 42 + assert no_none4(a) == 42 + assert no_none5(a) == 42 + assert ok_none1(a) == 42 + assert ok_none2(a) == 42 + assert ok_none3(a) == 42 + assert ok_none4(a) == 42 + assert ok_none5(a) == 42 + + with pytest.raises(TypeError) as excinfo: + no_none1(None) + assert "incompatible function arguments" in str(excinfo.value) + with pytest.raises(TypeError) as excinfo: + no_none2(None) + assert "incompatible function arguments" in str(excinfo.value) + with pytest.raises(TypeError) as excinfo: + no_none3(None) + assert "incompatible function arguments" in str(excinfo.value) + with pytest.raises(TypeError) as excinfo: + no_none4(None) + assert "incompatible function arguments" in str(excinfo.value) + with pytest.raises(TypeError) as excinfo: + no_none5(None) + assert "incompatible function arguments" in str(excinfo.value) + + # The first one still raises because you can't pass None as a lvalue reference arg: + with pytest.raises(TypeError) as excinfo: + assert ok_none1(None) == -1 + assert "incompatible function arguments" in str(excinfo.value) + # The rest take the argument as pointer or holder, and accept None: + assert ok_none2(None) == -1 + assert ok_none3(None) == -1 + assert ok_none4(None) == -1 + assert ok_none5(None) == -1