mirror of
https://github.com/pybind/pybind11.git
synced 2024-11-22 13:15:12 +00:00
Implicit conversions to bool + np.bool_ conversion (#925)
This adds support for implicit conversions to bool from Python types with `__bool__` (Python 3) or `__nonzero__` (Python 2) attributes, and adds direct (i.e. non-converting) support for numpy bools.
This commit is contained in:
parent
a03408c839
commit
e07f75839d
@ -1049,11 +1049,37 @@ template <> class type_caster<std::nullptr_t> : public void_caster<std::nullptr_
|
||||
|
||||
template <> class type_caster<bool> {
|
||||
public:
|
||||
bool load(handle src, bool) {
|
||||
bool load(handle src, bool convert) {
|
||||
if (!src) return false;
|
||||
else if (src.ptr() == Py_True) { value = true; return true; }
|
||||
else if (src.ptr() == Py_False) { value = false; return true; }
|
||||
else return false;
|
||||
else if (convert || !strcmp("numpy.bool_", Py_TYPE(src.ptr())->tp_name)) {
|
||||
// (allow non-implicit conversion for numpy booleans)
|
||||
|
||||
Py_ssize_t res = -1;
|
||||
if (src.is_none()) {
|
||||
res = 0; // None is implicitly converted to False
|
||||
}
|
||||
#if defined(PYPY_VERSION)
|
||||
// On PyPy, check that "__bool__" (or "__nonzero__" on Python 2.7) attr exists
|
||||
else if (hasattr(src, PYBIND11_BOOL_ATTR)) {
|
||||
res = PyObject_IsTrue(src.ptr());
|
||||
}
|
||||
#else
|
||||
// Alternate approach for CPython: this does the same as the above, but optimized
|
||||
// using the CPython API so as to avoid an unneeded attribute lookup.
|
||||
else if (auto tp_as_number = src.ptr()->ob_type->tp_as_number) {
|
||||
if (PYBIND11_NB_BOOL(tp_as_number)) {
|
||||
res = (*PYBIND11_NB_BOOL(tp_as_number))(src.ptr());
|
||||
}
|
||||
}
|
||||
#endif
|
||||
if (res == 0 || res == 1) {
|
||||
value = (bool) res;
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
static handle cast(bool src, return_value_policy /* policy */, handle /* parent */) {
|
||||
return handle(src ? Py_True : Py_False).inc_ref();
|
||||
|
@ -152,8 +152,11 @@
|
||||
#define PYBIND11_SLICE_OBJECT PyObject
|
||||
#define PYBIND11_FROM_STRING PyUnicode_FromString
|
||||
#define PYBIND11_STR_TYPE ::pybind11::str
|
||||
#define PYBIND11_BOOL_ATTR "__bool__"
|
||||
#define PYBIND11_NB_BOOL(ptr) ((ptr)->nb_bool)
|
||||
#define PYBIND11_PLUGIN_IMPL(name) \
|
||||
extern "C" PYBIND11_EXPORT PyObject *PyInit_##name()
|
||||
|
||||
#else
|
||||
#define PYBIND11_INSTANCE_METHOD_NEW(ptr, class_) PyMethod_New(ptr, nullptr, class_)
|
||||
#define PYBIND11_INSTANCE_METHOD_CHECK PyMethod_Check
|
||||
@ -171,6 +174,8 @@
|
||||
#define PYBIND11_SLICE_OBJECT PySliceObject
|
||||
#define PYBIND11_FROM_STRING PyString_FromString
|
||||
#define PYBIND11_STR_TYPE ::pybind11::bytes
|
||||
#define PYBIND11_BOOL_ATTR "__nonzero__"
|
||||
#define PYBIND11_NB_BOOL(ptr) ((ptr)->nb_nonzero)
|
||||
#define PYBIND11_PLUGIN_IMPL(name) \
|
||||
static PyObject *pybind11_init_wrapper(); \
|
||||
extern "C" PYBIND11_EXPORT void init##name() { \
|
||||
|
@ -116,6 +116,10 @@ TEST_SUBMODULE(builtin_casters, m) {
|
||||
m.def("load_nullptr_t", [](std::nullptr_t) {}); // not useful, but it should still compile
|
||||
m.def("cast_nullptr_t", []() { return std::nullptr_t{}; });
|
||||
|
||||
// test_bool_caster
|
||||
m.def("bool_passthrough", [](bool arg) { return arg; });
|
||||
m.def("bool_passthrough_noconvert", [](bool arg) { return arg; }, py::arg().noconvert());
|
||||
|
||||
// test_reference_wrapper
|
||||
m.def("refwrap_builtin", [](std::reference_wrapper<int> p) { return 10 * p.get(); });
|
||||
m.def("refwrap_usertype", [](std::reference_wrapper<UserType> p) { return p.get().value(); });
|
||||
|
@ -265,3 +265,58 @@ def test_complex_cast():
|
||||
"""std::complex casts"""
|
||||
assert m.complex_cast(1) == "1.0"
|
||||
assert m.complex_cast(2j) == "(0.0, 2.0)"
|
||||
|
||||
|
||||
def test_bool_caster():
|
||||
"""Test bool caster implicit conversions."""
|
||||
convert, noconvert = m.bool_passthrough, m.bool_passthrough_noconvert
|
||||
|
||||
def require_implicit(v):
|
||||
pytest.raises(TypeError, noconvert, v)
|
||||
|
||||
def cant_convert(v):
|
||||
pytest.raises(TypeError, convert, v)
|
||||
|
||||
# straight up bool
|
||||
assert convert(True) is True
|
||||
assert convert(False) is False
|
||||
assert noconvert(True) is True
|
||||
assert noconvert(False) is False
|
||||
|
||||
# None requires implicit conversion
|
||||
require_implicit(None)
|
||||
assert convert(None) is False
|
||||
|
||||
class A(object):
|
||||
def __init__(self, x):
|
||||
self.x = x
|
||||
|
||||
def __nonzero__(self):
|
||||
return self.x
|
||||
|
||||
def __bool__(self):
|
||||
return self.x
|
||||
|
||||
class B(object):
|
||||
pass
|
||||
|
||||
# Arbitrary objects are not accepted
|
||||
cant_convert(object())
|
||||
cant_convert(B())
|
||||
|
||||
# Objects with __nonzero__ / __bool__ defined can be converted
|
||||
require_implicit(A(True))
|
||||
assert convert(A(True)) is True
|
||||
assert convert(A(False)) is False
|
||||
|
||||
|
||||
@pytest.requires_numpy
|
||||
def test_numpy_bool():
|
||||
import numpy as np
|
||||
convert, noconvert = m.bool_passthrough, m.bool_passthrough_noconvert
|
||||
|
||||
# np.bool_ is not considered implicit
|
||||
assert convert(np.bool_(True)) is True
|
||||
assert convert(np.bool_(False)) is False
|
||||
assert noconvert(np.bool_(True)) is True
|
||||
assert noconvert(np.bool_(False)) is False
|
||||
|
Loading…
Reference in New Issue
Block a user