object_api: support the number protocol

This commit revamps the object_api class so that it maps most C++
operators to their Python analogs. This makes it possible to, e.g.
perform arithmetic using a py::int_ or py::array.
This commit is contained in:
Wenzel Jakob 2018-09-01 01:09:16 +02:00
parent 5c8746ff13
commit 067100201f
3 changed files with 110 additions and 0 deletions

View File

@ -114,6 +114,35 @@ public:
bool is(object_api const& other) const { return derived().ptr() == other.derived().ptr(); }
/// Equivalent to ``obj is None`` in Python.
bool is_none() const { return derived().ptr() == Py_None; }
/// Equivalent to obj == other in Python
bool equal(object_api const &other) const { return rich_compare(other, Py_EQ); }
bool not_equal(object_api const &other) const { return rich_compare(other, Py_NE); }
bool operator<(object_api const &other) const { return rich_compare(other, Py_LT); }
bool operator<=(object_api const &other) const { return rich_compare(other, Py_LE); }
bool operator>(object_api const &other) const { return rich_compare(other, Py_GT); }
bool operator>=(object_api const &other) const { return rich_compare(other, Py_GE); }
object operator-() const;
object operator~() const;
object operator+(object_api const &other) const;
object operator+=(object_api const &other) const;
object operator-(object_api const &other) const;
object operator-=(object_api const &other) const;
object operator*(object_api const &other) const;
object operator*=(object_api const &other) const;
object operator/(object_api const &other) const;
object operator/=(object_api const &other) const;
object operator|(object_api const &other) const;
object operator|=(object_api const &other) const;
object operator&(object_api const &other) const;
object operator&=(object_api const &other) const;
object operator^(object_api const &other) const;
object operator^=(object_api const &other) const;
object operator<<(object_api const &other) const;
object operator<<=(object_api const &other) const;
object operator>>(object_api const &other) const;
object operator>>=(object_api const &other) const;
PYBIND11_DEPRECATED("Use py::str(obj) instead")
pybind11::str str() const;
@ -124,6 +153,9 @@ public:
int ref_count() const { return static_cast<int>(Py_REFCNT(derived().ptr())); }
/// Return a handle to the Python type object underlying the instance
handle get_type() const;
private:
bool rich_compare(object_api const &other, int value) const;
};
NAMESPACE_END(detail)
@ -1342,5 +1374,55 @@ str_attr_accessor object_api<D>::doc() const { return attr("__doc__"); }
template <typename D>
handle object_api<D>::get_type() const { return (PyObject *) Py_TYPE(derived().ptr()); }
template <typename D>
bool object_api<D>::rich_compare(object_api const &other, int value) const {
int rv = PyObject_RichCompareBool(derived().ptr(), other.derived().ptr(), value);
if (rv == -1)
throw error_already_set();
return rv == 1;
}
#define PYBIND11_MATH_OPERATOR_UNARY(op, fn) \
template <typename D> object object_api<D>::op() const { \
object result = reinterpret_steal<object>(fn(derived().ptr())); \
if (!result.ptr()) \
throw error_already_set(); \
return result; \
}
#define PYBIND11_MATH_OPERATOR_BINARY(op, fn) \
template <typename D> \
object object_api<D>::op(object_api const &other) const { \
object result = reinterpret_steal<object>( \
fn(derived().ptr(), other.derived().ptr())); \
if (!result.ptr()) \
throw error_already_set(); \
return result; \
}
PYBIND11_MATH_OPERATOR_UNARY (operator~, PyNumber_Invert)
PYBIND11_MATH_OPERATOR_UNARY (operator-, PyNumber_Negative)
PYBIND11_MATH_OPERATOR_BINARY(operator+, PyNumber_Add)
PYBIND11_MATH_OPERATOR_BINARY(operator+=, PyNumber_InPlaceAdd)
PYBIND11_MATH_OPERATOR_BINARY(operator-, PyNumber_Subtract)
PYBIND11_MATH_OPERATOR_BINARY(operator-=, PyNumber_InPlaceSubtract)
PYBIND11_MATH_OPERATOR_BINARY(operator*, PyNumber_Multiply)
PYBIND11_MATH_OPERATOR_BINARY(operator*=, PyNumber_InPlaceMultiply)
PYBIND11_MATH_OPERATOR_BINARY(operator/, PyNumber_TrueDivide)
PYBIND11_MATH_OPERATOR_BINARY(operator/=, PyNumber_InPlaceTrueDivide)
PYBIND11_MATH_OPERATOR_BINARY(operator|, PyNumber_Or)
PYBIND11_MATH_OPERATOR_BINARY(operator|=, PyNumber_InPlaceOr)
PYBIND11_MATH_OPERATOR_BINARY(operator&, PyNumber_And)
PYBIND11_MATH_OPERATOR_BINARY(operator&=, PyNumber_InPlaceAnd)
PYBIND11_MATH_OPERATOR_BINARY(operator^, PyNumber_Xor)
PYBIND11_MATH_OPERATOR_BINARY(operator^=, PyNumber_InPlaceXor)
PYBIND11_MATH_OPERATOR_BINARY(operator<<, PyNumber_Lshift)
PYBIND11_MATH_OPERATOR_BINARY(operator<<=, PyNumber_InPlaceLshift)
PYBIND11_MATH_OPERATOR_BINARY(operator>>, PyNumber_Rshift)
PYBIND11_MATH_OPERATOR_BINARY(operator>>=, PyNumber_InPlaceRshift)
#undef PYBIND11_MATH_OPERATOR_UNARY
#undef PYBIND11_MATH_OPERATOR_BINARY
NAMESPACE_END(detail)
NAMESPACE_END(PYBIND11_NAMESPACE)

View File

@ -269,4 +269,24 @@ TEST_SUBMODULE(pytypes, m) {
m.def("print_failure", []() { py::print(42, UnregisteredType()); });
m.def("hash_function", [](py::object obj) { return py::hash(obj); });
m.def("test_number_protocol", [](py::object a, py::object b) {
py::list l;
l.append(a.equal(b));
l.append(a.not_equal(b));
l.append(a < b);
l.append(a <= b);
l.append(a > b);
l.append(a >= b);
l.append(a + b);
l.append(a - b);
l.append(a * b);
l.append(a / b);
l.append(a | b);
l.append(a & b);
l.append(a ^ b);
l.append(a >> b);
l.append(a << b);
return l;
});
}

View File

@ -1,3 +1,4 @@
from __future__ import division
import pytest
import sys
@ -238,3 +239,10 @@ def test_hash():
assert m.hash_function(Hashable(42)) == 42
with pytest.raises(TypeError):
m.hash_function(Unhashable())
def test_number_protocol():
for a, b in [(1, 1), (3, 5)]:
li = [a == b, a != b, a < b, a <= b, a > b, a >= b, a + b,
a - b, a * b, a / b, a | b, a & b, a ^ b, a >> b, a << b]
assert m.test_number_protocol(a, b) == li