From e07f75839ddba5e41a7c22cf1dddc73ddc1ce5ad Mon Sep 17 00:00:00 2001 From: Ivan Smirnov Date: Sun, 23 Jul 2017 16:02:43 +0100 Subject: [PATCH] Implicit conversions to bool + np.bool_ conversion (#925) This adds support for implicit conversions to bool from Python types with `__bool__` (Python 3) or `__nonzero__` (Python 2) attributes, and adds direct (i.e. non-converting) support for numpy bools. --- include/pybind11/cast.h | 30 +++++++++++++++++-- include/pybind11/common.h | 5 ++++ tests/test_builtin_casters.cpp | 4 +++ tests/test_builtin_casters.py | 55 ++++++++++++++++++++++++++++++++++ 4 files changed, 92 insertions(+), 2 deletions(-) diff --git a/include/pybind11/cast.h b/include/pybind11/cast.h index f3a05f00b..0b41f942b 100644 --- a/include/pybind11/cast.h +++ b/include/pybind11/cast.h @@ -1049,11 +1049,37 @@ template <> class type_caster : public void_caster class type_caster { public: - bool load(handle src, bool) { + bool load(handle src, bool convert) { if (!src) return false; else if (src.ptr() == Py_True) { value = true; return true; } else if (src.ptr() == Py_False) { value = false; return true; } - else return false; + else if (convert || !strcmp("numpy.bool_", Py_TYPE(src.ptr())->tp_name)) { + // (allow non-implicit conversion for numpy booleans) + + Py_ssize_t res = -1; + if (src.is_none()) { + res = 0; // None is implicitly converted to False + } + #if defined(PYPY_VERSION) + // On PyPy, check that "__bool__" (or "__nonzero__" on Python 2.7) attr exists + else if (hasattr(src, PYBIND11_BOOL_ATTR)) { + res = PyObject_IsTrue(src.ptr()); + } + #else + // Alternate approach for CPython: this does the same as the above, but optimized + // using the CPython API so as to avoid an unneeded attribute lookup. + else if (auto tp_as_number = src.ptr()->ob_type->tp_as_number) { + if (PYBIND11_NB_BOOL(tp_as_number)) { + res = (*PYBIND11_NB_BOOL(tp_as_number))(src.ptr()); + } + } + #endif + if (res == 0 || res == 1) { + value = (bool) res; + return true; + } + } + return false; } static handle cast(bool src, return_value_policy /* policy */, handle /* parent */) { return handle(src ? Py_True : Py_False).inc_ref(); diff --git a/include/pybind11/common.h b/include/pybind11/common.h index ed1bb610b..b205bb9d8 100644 --- a/include/pybind11/common.h +++ b/include/pybind11/common.h @@ -152,8 +152,11 @@ #define PYBIND11_SLICE_OBJECT PyObject #define PYBIND11_FROM_STRING PyUnicode_FromString #define PYBIND11_STR_TYPE ::pybind11::str +#define PYBIND11_BOOL_ATTR "__bool__" +#define PYBIND11_NB_BOOL(ptr) ((ptr)->nb_bool) #define PYBIND11_PLUGIN_IMPL(name) \ extern "C" PYBIND11_EXPORT PyObject *PyInit_##name() + #else #define PYBIND11_INSTANCE_METHOD_NEW(ptr, class_) PyMethod_New(ptr, nullptr, class_) #define PYBIND11_INSTANCE_METHOD_CHECK PyMethod_Check @@ -171,6 +174,8 @@ #define PYBIND11_SLICE_OBJECT PySliceObject #define PYBIND11_FROM_STRING PyString_FromString #define PYBIND11_STR_TYPE ::pybind11::bytes +#define PYBIND11_BOOL_ATTR "__nonzero__" +#define PYBIND11_NB_BOOL(ptr) ((ptr)->nb_nonzero) #define PYBIND11_PLUGIN_IMPL(name) \ static PyObject *pybind11_init_wrapper(); \ extern "C" PYBIND11_EXPORT void init##name() { \ diff --git a/tests/test_builtin_casters.cpp b/tests/test_builtin_casters.cpp index bf972e19e..b73e96ea5 100644 --- a/tests/test_builtin_casters.cpp +++ b/tests/test_builtin_casters.cpp @@ -116,6 +116,10 @@ TEST_SUBMODULE(builtin_casters, m) { m.def("load_nullptr_t", [](std::nullptr_t) {}); // not useful, but it should still compile m.def("cast_nullptr_t", []() { return std::nullptr_t{}; }); + // test_bool_caster + m.def("bool_passthrough", [](bool arg) { return arg; }); + m.def("bool_passthrough_noconvert", [](bool arg) { return arg; }, py::arg().noconvert()); + // test_reference_wrapper m.def("refwrap_builtin", [](std::reference_wrapper p) { return 10 * p.get(); }); m.def("refwrap_usertype", [](std::reference_wrapper p) { return p.get().value(); }); diff --git a/tests/test_builtin_casters.py b/tests/test_builtin_casters.py index d7d49b699..64d2f1945 100644 --- a/tests/test_builtin_casters.py +++ b/tests/test_builtin_casters.py @@ -265,3 +265,58 @@ def test_complex_cast(): """std::complex casts""" assert m.complex_cast(1) == "1.0" assert m.complex_cast(2j) == "(0.0, 2.0)" + + +def test_bool_caster(): + """Test bool caster implicit conversions.""" + convert, noconvert = m.bool_passthrough, m.bool_passthrough_noconvert + + def require_implicit(v): + pytest.raises(TypeError, noconvert, v) + + def cant_convert(v): + pytest.raises(TypeError, convert, v) + + # straight up bool + assert convert(True) is True + assert convert(False) is False + assert noconvert(True) is True + assert noconvert(False) is False + + # None requires implicit conversion + require_implicit(None) + assert convert(None) is False + + class A(object): + def __init__(self, x): + self.x = x + + def __nonzero__(self): + return self.x + + def __bool__(self): + return self.x + + class B(object): + pass + + # Arbitrary objects are not accepted + cant_convert(object()) + cant_convert(B()) + + # Objects with __nonzero__ / __bool__ defined can be converted + require_implicit(A(True)) + assert convert(A(True)) is True + assert convert(A(False)) is False + + +@pytest.requires_numpy +def test_numpy_bool(): + import numpy as np + convert, noconvert = m.bool_passthrough, m.bool_passthrough_noconvert + + # np.bool_ is not considered implicit + assert convert(np.bool_(True)) is True + assert convert(np.bool_(False)) is False + assert noconvert(np.bool_(True)) is True + assert noconvert(np.bool_(False)) is False