mirror of
https://github.com/pybind/pybind11.git
synced 2024-11-26 07:02:11 +00:00
Implicit conversions to bool + np.bool_ conversion (#925)
This adds support for implicit conversions to bool from Python types with `__bool__` (Python 3) or `__nonzero__` (Python 2) attributes, and adds direct (i.e. non-converting) support for numpy bools.
This commit is contained in:
parent
a03408c839
commit
e07f75839d
@ -1049,11 +1049,37 @@ template <> class type_caster<std::nullptr_t> : public void_caster<std::nullptr_
|
|||||||
|
|
||||||
template <> class type_caster<bool> {
|
template <> class type_caster<bool> {
|
||||||
public:
|
public:
|
||||||
bool load(handle src, bool) {
|
bool load(handle src, bool convert) {
|
||||||
if (!src) return false;
|
if (!src) return false;
|
||||||
else if (src.ptr() == Py_True) { value = true; return true; }
|
else if (src.ptr() == Py_True) { value = true; return true; }
|
||||||
else if (src.ptr() == Py_False) { value = false; return true; }
|
else if (src.ptr() == Py_False) { value = false; return true; }
|
||||||
else return false;
|
else if (convert || !strcmp("numpy.bool_", Py_TYPE(src.ptr())->tp_name)) {
|
||||||
|
// (allow non-implicit conversion for numpy booleans)
|
||||||
|
|
||||||
|
Py_ssize_t res = -1;
|
||||||
|
if (src.is_none()) {
|
||||||
|
res = 0; // None is implicitly converted to False
|
||||||
|
}
|
||||||
|
#if defined(PYPY_VERSION)
|
||||||
|
// On PyPy, check that "__bool__" (or "__nonzero__" on Python 2.7) attr exists
|
||||||
|
else if (hasattr(src, PYBIND11_BOOL_ATTR)) {
|
||||||
|
res = PyObject_IsTrue(src.ptr());
|
||||||
|
}
|
||||||
|
#else
|
||||||
|
// Alternate approach for CPython: this does the same as the above, but optimized
|
||||||
|
// using the CPython API so as to avoid an unneeded attribute lookup.
|
||||||
|
else if (auto tp_as_number = src.ptr()->ob_type->tp_as_number) {
|
||||||
|
if (PYBIND11_NB_BOOL(tp_as_number)) {
|
||||||
|
res = (*PYBIND11_NB_BOOL(tp_as_number))(src.ptr());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
if (res == 0 || res == 1) {
|
||||||
|
value = (bool) res;
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false;
|
||||||
}
|
}
|
||||||
static handle cast(bool src, return_value_policy /* policy */, handle /* parent */) {
|
static handle cast(bool src, return_value_policy /* policy */, handle /* parent */) {
|
||||||
return handle(src ? Py_True : Py_False).inc_ref();
|
return handle(src ? Py_True : Py_False).inc_ref();
|
||||||
|
@ -152,8 +152,11 @@
|
|||||||
#define PYBIND11_SLICE_OBJECT PyObject
|
#define PYBIND11_SLICE_OBJECT PyObject
|
||||||
#define PYBIND11_FROM_STRING PyUnicode_FromString
|
#define PYBIND11_FROM_STRING PyUnicode_FromString
|
||||||
#define PYBIND11_STR_TYPE ::pybind11::str
|
#define PYBIND11_STR_TYPE ::pybind11::str
|
||||||
|
#define PYBIND11_BOOL_ATTR "__bool__"
|
||||||
|
#define PYBIND11_NB_BOOL(ptr) ((ptr)->nb_bool)
|
||||||
#define PYBIND11_PLUGIN_IMPL(name) \
|
#define PYBIND11_PLUGIN_IMPL(name) \
|
||||||
extern "C" PYBIND11_EXPORT PyObject *PyInit_##name()
|
extern "C" PYBIND11_EXPORT PyObject *PyInit_##name()
|
||||||
|
|
||||||
#else
|
#else
|
||||||
#define PYBIND11_INSTANCE_METHOD_NEW(ptr, class_) PyMethod_New(ptr, nullptr, class_)
|
#define PYBIND11_INSTANCE_METHOD_NEW(ptr, class_) PyMethod_New(ptr, nullptr, class_)
|
||||||
#define PYBIND11_INSTANCE_METHOD_CHECK PyMethod_Check
|
#define PYBIND11_INSTANCE_METHOD_CHECK PyMethod_Check
|
||||||
@ -171,6 +174,8 @@
|
|||||||
#define PYBIND11_SLICE_OBJECT PySliceObject
|
#define PYBIND11_SLICE_OBJECT PySliceObject
|
||||||
#define PYBIND11_FROM_STRING PyString_FromString
|
#define PYBIND11_FROM_STRING PyString_FromString
|
||||||
#define PYBIND11_STR_TYPE ::pybind11::bytes
|
#define PYBIND11_STR_TYPE ::pybind11::bytes
|
||||||
|
#define PYBIND11_BOOL_ATTR "__nonzero__"
|
||||||
|
#define PYBIND11_NB_BOOL(ptr) ((ptr)->nb_nonzero)
|
||||||
#define PYBIND11_PLUGIN_IMPL(name) \
|
#define PYBIND11_PLUGIN_IMPL(name) \
|
||||||
static PyObject *pybind11_init_wrapper(); \
|
static PyObject *pybind11_init_wrapper(); \
|
||||||
extern "C" PYBIND11_EXPORT void init##name() { \
|
extern "C" PYBIND11_EXPORT void init##name() { \
|
||||||
|
@ -116,6 +116,10 @@ TEST_SUBMODULE(builtin_casters, m) {
|
|||||||
m.def("load_nullptr_t", [](std::nullptr_t) {}); // not useful, but it should still compile
|
m.def("load_nullptr_t", [](std::nullptr_t) {}); // not useful, but it should still compile
|
||||||
m.def("cast_nullptr_t", []() { return std::nullptr_t{}; });
|
m.def("cast_nullptr_t", []() { return std::nullptr_t{}; });
|
||||||
|
|
||||||
|
// test_bool_caster
|
||||||
|
m.def("bool_passthrough", [](bool arg) { return arg; });
|
||||||
|
m.def("bool_passthrough_noconvert", [](bool arg) { return arg; }, py::arg().noconvert());
|
||||||
|
|
||||||
// test_reference_wrapper
|
// test_reference_wrapper
|
||||||
m.def("refwrap_builtin", [](std::reference_wrapper<int> p) { return 10 * p.get(); });
|
m.def("refwrap_builtin", [](std::reference_wrapper<int> p) { return 10 * p.get(); });
|
||||||
m.def("refwrap_usertype", [](std::reference_wrapper<UserType> p) { return p.get().value(); });
|
m.def("refwrap_usertype", [](std::reference_wrapper<UserType> p) { return p.get().value(); });
|
||||||
|
@ -265,3 +265,58 @@ def test_complex_cast():
|
|||||||
"""std::complex casts"""
|
"""std::complex casts"""
|
||||||
assert m.complex_cast(1) == "1.0"
|
assert m.complex_cast(1) == "1.0"
|
||||||
assert m.complex_cast(2j) == "(0.0, 2.0)"
|
assert m.complex_cast(2j) == "(0.0, 2.0)"
|
||||||
|
|
||||||
|
|
||||||
|
def test_bool_caster():
|
||||||
|
"""Test bool caster implicit conversions."""
|
||||||
|
convert, noconvert = m.bool_passthrough, m.bool_passthrough_noconvert
|
||||||
|
|
||||||
|
def require_implicit(v):
|
||||||
|
pytest.raises(TypeError, noconvert, v)
|
||||||
|
|
||||||
|
def cant_convert(v):
|
||||||
|
pytest.raises(TypeError, convert, v)
|
||||||
|
|
||||||
|
# straight up bool
|
||||||
|
assert convert(True) is True
|
||||||
|
assert convert(False) is False
|
||||||
|
assert noconvert(True) is True
|
||||||
|
assert noconvert(False) is False
|
||||||
|
|
||||||
|
# None requires implicit conversion
|
||||||
|
require_implicit(None)
|
||||||
|
assert convert(None) is False
|
||||||
|
|
||||||
|
class A(object):
|
||||||
|
def __init__(self, x):
|
||||||
|
self.x = x
|
||||||
|
|
||||||
|
def __nonzero__(self):
|
||||||
|
return self.x
|
||||||
|
|
||||||
|
def __bool__(self):
|
||||||
|
return self.x
|
||||||
|
|
||||||
|
class B(object):
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Arbitrary objects are not accepted
|
||||||
|
cant_convert(object())
|
||||||
|
cant_convert(B())
|
||||||
|
|
||||||
|
# Objects with __nonzero__ / __bool__ defined can be converted
|
||||||
|
require_implicit(A(True))
|
||||||
|
assert convert(A(True)) is True
|
||||||
|
assert convert(A(False)) is False
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.requires_numpy
|
||||||
|
def test_numpy_bool():
|
||||||
|
import numpy as np
|
||||||
|
convert, noconvert = m.bool_passthrough, m.bool_passthrough_noconvert
|
||||||
|
|
||||||
|
# np.bool_ is not considered implicit
|
||||||
|
assert convert(np.bool_(True)) is True
|
||||||
|
assert convert(np.bool_(False)) is False
|
||||||
|
assert noconvert(np.bool_(True)) is True
|
||||||
|
assert noconvert(np.bool_(False)) is False
|
||||||
|
Loading…
Reference in New Issue
Block a user