mirror of
https://github.com/pybind/pybind11.git
synced 2024-11-22 05:05:11 +00:00
object_api: support the number protocol
This commit revamps the object_api class so that it maps most C++ operators to their Python analogs. This makes it possible to, e.g. perform arithmetic using a py::int_ or py::array.
This commit is contained in:
parent
5c8746ff13
commit
067100201f
@ -114,6 +114,35 @@ public:
|
||||
bool is(object_api const& other) const { return derived().ptr() == other.derived().ptr(); }
|
||||
/// Equivalent to ``obj is None`` in Python.
|
||||
bool is_none() const { return derived().ptr() == Py_None; }
|
||||
/// Equivalent to obj == other in Python
|
||||
bool equal(object_api const &other) const { return rich_compare(other, Py_EQ); }
|
||||
bool not_equal(object_api const &other) const { return rich_compare(other, Py_NE); }
|
||||
bool operator<(object_api const &other) const { return rich_compare(other, Py_LT); }
|
||||
bool operator<=(object_api const &other) const { return rich_compare(other, Py_LE); }
|
||||
bool operator>(object_api const &other) const { return rich_compare(other, Py_GT); }
|
||||
bool operator>=(object_api const &other) const { return rich_compare(other, Py_GE); }
|
||||
|
||||
object operator-() const;
|
||||
object operator~() const;
|
||||
object operator+(object_api const &other) const;
|
||||
object operator+=(object_api const &other) const;
|
||||
object operator-(object_api const &other) const;
|
||||
object operator-=(object_api const &other) const;
|
||||
object operator*(object_api const &other) const;
|
||||
object operator*=(object_api const &other) const;
|
||||
object operator/(object_api const &other) const;
|
||||
object operator/=(object_api const &other) const;
|
||||
object operator|(object_api const &other) const;
|
||||
object operator|=(object_api const &other) const;
|
||||
object operator&(object_api const &other) const;
|
||||
object operator&=(object_api const &other) const;
|
||||
object operator^(object_api const &other) const;
|
||||
object operator^=(object_api const &other) const;
|
||||
object operator<<(object_api const &other) const;
|
||||
object operator<<=(object_api const &other) const;
|
||||
object operator>>(object_api const &other) const;
|
||||
object operator>>=(object_api const &other) const;
|
||||
|
||||
PYBIND11_DEPRECATED("Use py::str(obj) instead")
|
||||
pybind11::str str() const;
|
||||
|
||||
@ -124,6 +153,9 @@ public:
|
||||
int ref_count() const { return static_cast<int>(Py_REFCNT(derived().ptr())); }
|
||||
/// Return a handle to the Python type object underlying the instance
|
||||
handle get_type() const;
|
||||
|
||||
private:
|
||||
bool rich_compare(object_api const &other, int value) const;
|
||||
};
|
||||
|
||||
NAMESPACE_END(detail)
|
||||
@ -1342,5 +1374,55 @@ str_attr_accessor object_api<D>::doc() const { return attr("__doc__"); }
|
||||
template <typename D>
|
||||
handle object_api<D>::get_type() const { return (PyObject *) Py_TYPE(derived().ptr()); }
|
||||
|
||||
template <typename D>
|
||||
bool object_api<D>::rich_compare(object_api const &other, int value) const {
|
||||
int rv = PyObject_RichCompareBool(derived().ptr(), other.derived().ptr(), value);
|
||||
if (rv == -1)
|
||||
throw error_already_set();
|
||||
return rv == 1;
|
||||
}
|
||||
|
||||
#define PYBIND11_MATH_OPERATOR_UNARY(op, fn) \
|
||||
template <typename D> object object_api<D>::op() const { \
|
||||
object result = reinterpret_steal<object>(fn(derived().ptr())); \
|
||||
if (!result.ptr()) \
|
||||
throw error_already_set(); \
|
||||
return result; \
|
||||
}
|
||||
|
||||
#define PYBIND11_MATH_OPERATOR_BINARY(op, fn) \
|
||||
template <typename D> \
|
||||
object object_api<D>::op(object_api const &other) const { \
|
||||
object result = reinterpret_steal<object>( \
|
||||
fn(derived().ptr(), other.derived().ptr())); \
|
||||
if (!result.ptr()) \
|
||||
throw error_already_set(); \
|
||||
return result; \
|
||||
}
|
||||
|
||||
PYBIND11_MATH_OPERATOR_UNARY (operator~, PyNumber_Invert)
|
||||
PYBIND11_MATH_OPERATOR_UNARY (operator-, PyNumber_Negative)
|
||||
PYBIND11_MATH_OPERATOR_BINARY(operator+, PyNumber_Add)
|
||||
PYBIND11_MATH_OPERATOR_BINARY(operator+=, PyNumber_InPlaceAdd)
|
||||
PYBIND11_MATH_OPERATOR_BINARY(operator-, PyNumber_Subtract)
|
||||
PYBIND11_MATH_OPERATOR_BINARY(operator-=, PyNumber_InPlaceSubtract)
|
||||
PYBIND11_MATH_OPERATOR_BINARY(operator*, PyNumber_Multiply)
|
||||
PYBIND11_MATH_OPERATOR_BINARY(operator*=, PyNumber_InPlaceMultiply)
|
||||
PYBIND11_MATH_OPERATOR_BINARY(operator/, PyNumber_TrueDivide)
|
||||
PYBIND11_MATH_OPERATOR_BINARY(operator/=, PyNumber_InPlaceTrueDivide)
|
||||
PYBIND11_MATH_OPERATOR_BINARY(operator|, PyNumber_Or)
|
||||
PYBIND11_MATH_OPERATOR_BINARY(operator|=, PyNumber_InPlaceOr)
|
||||
PYBIND11_MATH_OPERATOR_BINARY(operator&, PyNumber_And)
|
||||
PYBIND11_MATH_OPERATOR_BINARY(operator&=, PyNumber_InPlaceAnd)
|
||||
PYBIND11_MATH_OPERATOR_BINARY(operator^, PyNumber_Xor)
|
||||
PYBIND11_MATH_OPERATOR_BINARY(operator^=, PyNumber_InPlaceXor)
|
||||
PYBIND11_MATH_OPERATOR_BINARY(operator<<, PyNumber_Lshift)
|
||||
PYBIND11_MATH_OPERATOR_BINARY(operator<<=, PyNumber_InPlaceLshift)
|
||||
PYBIND11_MATH_OPERATOR_BINARY(operator>>, PyNumber_Rshift)
|
||||
PYBIND11_MATH_OPERATOR_BINARY(operator>>=, PyNumber_InPlaceRshift)
|
||||
|
||||
#undef PYBIND11_MATH_OPERATOR_UNARY
|
||||
#undef PYBIND11_MATH_OPERATOR_BINARY
|
||||
|
||||
NAMESPACE_END(detail)
|
||||
NAMESPACE_END(PYBIND11_NAMESPACE)
|
||||
|
@ -269,4 +269,24 @@ TEST_SUBMODULE(pytypes, m) {
|
||||
m.def("print_failure", []() { py::print(42, UnregisteredType()); });
|
||||
|
||||
m.def("hash_function", [](py::object obj) { return py::hash(obj); });
|
||||
|
||||
m.def("test_number_protocol", [](py::object a, py::object b) {
|
||||
py::list l;
|
||||
l.append(a.equal(b));
|
||||
l.append(a.not_equal(b));
|
||||
l.append(a < b);
|
||||
l.append(a <= b);
|
||||
l.append(a > b);
|
||||
l.append(a >= b);
|
||||
l.append(a + b);
|
||||
l.append(a - b);
|
||||
l.append(a * b);
|
||||
l.append(a / b);
|
||||
l.append(a | b);
|
||||
l.append(a & b);
|
||||
l.append(a ^ b);
|
||||
l.append(a >> b);
|
||||
l.append(a << b);
|
||||
return l;
|
||||
});
|
||||
}
|
||||
|
@ -1,3 +1,4 @@
|
||||
from __future__ import division
|
||||
import pytest
|
||||
import sys
|
||||
|
||||
@ -238,3 +239,10 @@ def test_hash():
|
||||
assert m.hash_function(Hashable(42)) == 42
|
||||
with pytest.raises(TypeError):
|
||||
m.hash_function(Unhashable())
|
||||
|
||||
|
||||
def test_number_protocol():
|
||||
for a, b in [(1, 1), (3, 5)]:
|
||||
li = [a == b, a != b, a < b, a <= b, a > b, a >= b, a + b,
|
||||
a - b, a * b, a / b, a | b, a & b, a ^ b, a >> b, a << b]
|
||||
assert m.test_number_protocol(a, b) == li
|
||||
|
Loading…
Reference in New Issue
Block a user