Implicit conversions to bool + np.bool_ conversion (#925)

This adds support for implicit conversions to bool from Python types
with `__bool__` (Python 3) or `__nonzero__` (Python 2) attributes, and
adds direct (i.e. non-converting) support for numpy bools.
This commit is contained in:
Ivan Smirnov 2017-07-23 16:02:43 +01:00 committed by Jason Rhinelander
parent a03408c839
commit e07f75839d
4 changed files with 92 additions and 2 deletions

View File

@ -1049,11 +1049,37 @@ template <> class type_caster<std::nullptr_t> : public void_caster<std::nullptr_
template <> class type_caster<bool> { template <> class type_caster<bool> {
public: public:
bool load(handle src, bool) { bool load(handle src, bool convert) {
if (!src) return false; if (!src) return false;
else if (src.ptr() == Py_True) { value = true; return true; } else if (src.ptr() == Py_True) { value = true; return true; }
else if (src.ptr() == Py_False) { value = false; return true; } else if (src.ptr() == Py_False) { value = false; return true; }
else return false; else if (convert || !strcmp("numpy.bool_", Py_TYPE(src.ptr())->tp_name)) {
// (allow non-implicit conversion for numpy booleans)
Py_ssize_t res = -1;
if (src.is_none()) {
res = 0; // None is implicitly converted to False
}
#if defined(PYPY_VERSION)
// On PyPy, check that "__bool__" (or "__nonzero__" on Python 2.7) attr exists
else if (hasattr(src, PYBIND11_BOOL_ATTR)) {
res = PyObject_IsTrue(src.ptr());
}
#else
// Alternate approach for CPython: this does the same as the above, but optimized
// using the CPython API so as to avoid an unneeded attribute lookup.
else if (auto tp_as_number = src.ptr()->ob_type->tp_as_number) {
if (PYBIND11_NB_BOOL(tp_as_number)) {
res = (*PYBIND11_NB_BOOL(tp_as_number))(src.ptr());
}
}
#endif
if (res == 0 || res == 1) {
value = (bool) res;
return true;
}
}
return false;
} }
static handle cast(bool src, return_value_policy /* policy */, handle /* parent */) { static handle cast(bool src, return_value_policy /* policy */, handle /* parent */) {
return handle(src ? Py_True : Py_False).inc_ref(); return handle(src ? Py_True : Py_False).inc_ref();

View File

@ -152,8 +152,11 @@
#define PYBIND11_SLICE_OBJECT PyObject #define PYBIND11_SLICE_OBJECT PyObject
#define PYBIND11_FROM_STRING PyUnicode_FromString #define PYBIND11_FROM_STRING PyUnicode_FromString
#define PYBIND11_STR_TYPE ::pybind11::str #define PYBIND11_STR_TYPE ::pybind11::str
#define PYBIND11_BOOL_ATTR "__bool__"
#define PYBIND11_NB_BOOL(ptr) ((ptr)->nb_bool)
#define PYBIND11_PLUGIN_IMPL(name) \ #define PYBIND11_PLUGIN_IMPL(name) \
extern "C" PYBIND11_EXPORT PyObject *PyInit_##name() extern "C" PYBIND11_EXPORT PyObject *PyInit_##name()
#else #else
#define PYBIND11_INSTANCE_METHOD_NEW(ptr, class_) PyMethod_New(ptr, nullptr, class_) #define PYBIND11_INSTANCE_METHOD_NEW(ptr, class_) PyMethod_New(ptr, nullptr, class_)
#define PYBIND11_INSTANCE_METHOD_CHECK PyMethod_Check #define PYBIND11_INSTANCE_METHOD_CHECK PyMethod_Check
@ -171,6 +174,8 @@
#define PYBIND11_SLICE_OBJECT PySliceObject #define PYBIND11_SLICE_OBJECT PySliceObject
#define PYBIND11_FROM_STRING PyString_FromString #define PYBIND11_FROM_STRING PyString_FromString
#define PYBIND11_STR_TYPE ::pybind11::bytes #define PYBIND11_STR_TYPE ::pybind11::bytes
#define PYBIND11_BOOL_ATTR "__nonzero__"
#define PYBIND11_NB_BOOL(ptr) ((ptr)->nb_nonzero)
#define PYBIND11_PLUGIN_IMPL(name) \ #define PYBIND11_PLUGIN_IMPL(name) \
static PyObject *pybind11_init_wrapper(); \ static PyObject *pybind11_init_wrapper(); \
extern "C" PYBIND11_EXPORT void init##name() { \ extern "C" PYBIND11_EXPORT void init##name() { \

View File

@ -116,6 +116,10 @@ TEST_SUBMODULE(builtin_casters, m) {
m.def("load_nullptr_t", [](std::nullptr_t) {}); // not useful, but it should still compile m.def("load_nullptr_t", [](std::nullptr_t) {}); // not useful, but it should still compile
m.def("cast_nullptr_t", []() { return std::nullptr_t{}; }); m.def("cast_nullptr_t", []() { return std::nullptr_t{}; });
// test_bool_caster
m.def("bool_passthrough", [](bool arg) { return arg; });
m.def("bool_passthrough_noconvert", [](bool arg) { return arg; }, py::arg().noconvert());
// test_reference_wrapper // test_reference_wrapper
m.def("refwrap_builtin", [](std::reference_wrapper<int> p) { return 10 * p.get(); }); m.def("refwrap_builtin", [](std::reference_wrapper<int> p) { return 10 * p.get(); });
m.def("refwrap_usertype", [](std::reference_wrapper<UserType> p) { return p.get().value(); }); m.def("refwrap_usertype", [](std::reference_wrapper<UserType> p) { return p.get().value(); });

View File

@ -265,3 +265,58 @@ def test_complex_cast():
"""std::complex casts""" """std::complex casts"""
assert m.complex_cast(1) == "1.0" assert m.complex_cast(1) == "1.0"
assert m.complex_cast(2j) == "(0.0, 2.0)" assert m.complex_cast(2j) == "(0.0, 2.0)"
def test_bool_caster():
"""Test bool caster implicit conversions."""
convert, noconvert = m.bool_passthrough, m.bool_passthrough_noconvert
def require_implicit(v):
pytest.raises(TypeError, noconvert, v)
def cant_convert(v):
pytest.raises(TypeError, convert, v)
# straight up bool
assert convert(True) is True
assert convert(False) is False
assert noconvert(True) is True
assert noconvert(False) is False
# None requires implicit conversion
require_implicit(None)
assert convert(None) is False
class A(object):
def __init__(self, x):
self.x = x
def __nonzero__(self):
return self.x
def __bool__(self):
return self.x
class B(object):
pass
# Arbitrary objects are not accepted
cant_convert(object())
cant_convert(B())
# Objects with __nonzero__ / __bool__ defined can be converted
require_implicit(A(True))
assert convert(A(True)) is True
assert convert(A(False)) is False
@pytest.requires_numpy
def test_numpy_bool():
import numpy as np
convert, noconvert = m.bool_passthrough, m.bool_passthrough_noconvert
# np.bool_ is not considered implicit
assert convert(np.bool_(True)) is True
assert convert(np.bool_(False)) is False
assert noconvert(np.bool_(True)) is True
assert noconvert(np.bool_(False)) is False