Fix #3812 and fix const of inplace assignments (#4065)

* Fix #3812 and fix const of inplace assignments

* Fix missing tests

* Revert operator overloading changes

* calculate answer first for tests

* Simplify tests

* Add more tests

* Add a couple more tests

* Add test_inplace_lshift, test_inplace_rshift for completeness.

* Update tests

* Shortcircuit on self assigment and address reviewer comment

* broaden skip for self assignment

* One more reviewer comment

* Document opt behavior and make consistent

* Revert unnecessary change

* Clarify comment

Co-authored-by: Ralf W. Grosse-Kunstleve <rwgk@google.com>
This commit is contained in:
Aaron Gokaslan 2022-07-20 11:42:24 -04:00 committed by GitHub
parent ef7d971e03
commit f47f1edfe8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 156 additions and 24 deletions

View File

@ -155,23 +155,23 @@ public:
object operator-() const; object operator-() const;
object operator~() const; object operator~() const;
object operator+(object_api const &other) const; object operator+(object_api const &other) const;
object operator+=(object_api const &other) const; object operator+=(object_api const &other);
object operator-(object_api const &other) const; object operator-(object_api const &other) const;
object operator-=(object_api const &other) const; object operator-=(object_api const &other);
object operator*(object_api const &other) const; object operator*(object_api const &other) const;
object operator*=(object_api const &other) const; object operator*=(object_api const &other);
object operator/(object_api const &other) const; object operator/(object_api const &other) const;
object operator/=(object_api const &other) const; object operator/=(object_api const &other);
object operator|(object_api const &other) const; object operator|(object_api const &other) const;
object operator|=(object_api const &other) const; object operator|=(object_api const &other);
object operator&(object_api const &other) const; object operator&(object_api const &other) const;
object operator&=(object_api const &other) const; object operator&=(object_api const &other);
object operator^(object_api const &other) const; object operator^(object_api const &other) const;
object operator^=(object_api const &other) const; object operator^=(object_api const &other);
object operator<<(object_api const &other) const; object operator<<(object_api const &other) const;
object operator<<=(object_api const &other) const; object operator<<=(object_api const &other);
object operator>>(object_api const &other) const; object operator>>(object_api const &other) const;
object operator>>=(object_api const &other) const; object operator>>=(object_api const &other);
PYBIND11_DEPRECATED("Use py::str(obj) instead") PYBIND11_DEPRECATED("Use py::str(obj) instead")
pybind11::str str() const; pybind11::str str() const;
@ -334,12 +334,15 @@ public:
} }
object &operator=(const object &other) { object &operator=(const object &other) {
// Skip inc_ref and dec_ref if both objects are the same
if (!this->is(other)) {
other.inc_ref(); other.inc_ref();
// Use temporary variable to ensure `*this` remains valid while // Use temporary variable to ensure `*this` remains valid while
// `Py_XDECREF` executes, in case `*this` is accessible from Python. // `Py_XDECREF` executes, in case `*this` is accessible from Python.
handle temp(m_ptr); handle temp(m_ptr);
m_ptr = other.m_ptr; m_ptr = other.m_ptr;
temp.dec_ref(); temp.dec_ref();
}
return *this; return *this;
} }
@ -353,6 +356,20 @@ public:
return *this; return *this;
} }
#define PYBIND11_INPLACE_OP(iop) \
object iop(object_api const &other) { return operator=(handle::iop(other)); }
PYBIND11_INPLACE_OP(operator+=)
PYBIND11_INPLACE_OP(operator-=)
PYBIND11_INPLACE_OP(operator*=)
PYBIND11_INPLACE_OP(operator/=)
PYBIND11_INPLACE_OP(operator|=)
PYBIND11_INPLACE_OP(operator&=)
PYBIND11_INPLACE_OP(operator^=)
PYBIND11_INPLACE_OP(operator<<=)
PYBIND11_INPLACE_OP(operator>>=)
#undef PYBIND11_INPLACE_OP
// Calling cast() on an object lvalue just copies (via handle::cast) // Calling cast() on an object lvalue just copies (via handle::cast)
template <typename T> template <typename T>
T cast() const &; T cast() const &;
@ -2364,26 +2381,35 @@ bool object_api<D>::rich_compare(object_api const &other, int value) const {
return result; \ return result; \
} }
#define PYBIND11_MATH_OPERATOR_BINARY_INPLACE(iop, fn) \
template <typename D> \
object object_api<D>::iop(object_api const &other) { \
object result = reinterpret_steal<object>(fn(derived().ptr(), other.derived().ptr())); \
if (!result.ptr()) \
throw error_already_set(); \
return result; \
}
PYBIND11_MATH_OPERATOR_UNARY(operator~, PyNumber_Invert) PYBIND11_MATH_OPERATOR_UNARY(operator~, PyNumber_Invert)
PYBIND11_MATH_OPERATOR_UNARY(operator-, PyNumber_Negative) PYBIND11_MATH_OPERATOR_UNARY(operator-, PyNumber_Negative)
PYBIND11_MATH_OPERATOR_BINARY(operator+, PyNumber_Add) PYBIND11_MATH_OPERATOR_BINARY(operator+, PyNumber_Add)
PYBIND11_MATH_OPERATOR_BINARY(operator+=, PyNumber_InPlaceAdd) PYBIND11_MATH_OPERATOR_BINARY_INPLACE(operator+=, PyNumber_InPlaceAdd)
PYBIND11_MATH_OPERATOR_BINARY(operator-, PyNumber_Subtract) PYBIND11_MATH_OPERATOR_BINARY(operator-, PyNumber_Subtract)
PYBIND11_MATH_OPERATOR_BINARY(operator-=, PyNumber_InPlaceSubtract) PYBIND11_MATH_OPERATOR_BINARY_INPLACE(operator-=, PyNumber_InPlaceSubtract)
PYBIND11_MATH_OPERATOR_BINARY(operator*, PyNumber_Multiply) PYBIND11_MATH_OPERATOR_BINARY(operator*, PyNumber_Multiply)
PYBIND11_MATH_OPERATOR_BINARY(operator*=, PyNumber_InPlaceMultiply) PYBIND11_MATH_OPERATOR_BINARY_INPLACE(operator*=, PyNumber_InPlaceMultiply)
PYBIND11_MATH_OPERATOR_BINARY(operator/, PyNumber_TrueDivide) PYBIND11_MATH_OPERATOR_BINARY(operator/, PyNumber_TrueDivide)
PYBIND11_MATH_OPERATOR_BINARY(operator/=, PyNumber_InPlaceTrueDivide) PYBIND11_MATH_OPERATOR_BINARY_INPLACE(operator/=, PyNumber_InPlaceTrueDivide)
PYBIND11_MATH_OPERATOR_BINARY(operator|, PyNumber_Or) PYBIND11_MATH_OPERATOR_BINARY(operator|, PyNumber_Or)
PYBIND11_MATH_OPERATOR_BINARY(operator|=, PyNumber_InPlaceOr) PYBIND11_MATH_OPERATOR_BINARY_INPLACE(operator|=, PyNumber_InPlaceOr)
PYBIND11_MATH_OPERATOR_BINARY(operator&, PyNumber_And) PYBIND11_MATH_OPERATOR_BINARY(operator&, PyNumber_And)
PYBIND11_MATH_OPERATOR_BINARY(operator&=, PyNumber_InPlaceAnd) PYBIND11_MATH_OPERATOR_BINARY_INPLACE(operator&=, PyNumber_InPlaceAnd)
PYBIND11_MATH_OPERATOR_BINARY(operator^, PyNumber_Xor) PYBIND11_MATH_OPERATOR_BINARY(operator^, PyNumber_Xor)
PYBIND11_MATH_OPERATOR_BINARY(operator^=, PyNumber_InPlaceXor) PYBIND11_MATH_OPERATOR_BINARY_INPLACE(operator^=, PyNumber_InPlaceXor)
PYBIND11_MATH_OPERATOR_BINARY(operator<<, PyNumber_Lshift) PYBIND11_MATH_OPERATOR_BINARY(operator<<, PyNumber_Lshift)
PYBIND11_MATH_OPERATOR_BINARY(operator<<=, PyNumber_InPlaceLshift) PYBIND11_MATH_OPERATOR_BINARY_INPLACE(operator<<=, PyNumber_InPlaceLshift)
PYBIND11_MATH_OPERATOR_BINARY(operator>>, PyNumber_Rshift) PYBIND11_MATH_OPERATOR_BINARY(operator>>, PyNumber_Rshift)
PYBIND11_MATH_OPERATOR_BINARY(operator>>=, PyNumber_InPlaceRshift) PYBIND11_MATH_OPERATOR_BINARY_INPLACE(operator>>=, PyNumber_InPlaceRshift)
#undef PYBIND11_MATH_OPERATOR_UNARY #undef PYBIND11_MATH_OPERATOR_UNARY
#undef PYBIND11_MATH_OPERATOR_BINARY #undef PYBIND11_MATH_OPERATOR_BINARY

View File

@ -756,4 +756,38 @@ TEST_SUBMODULE(pytypes, m) {
} }
return o; return o;
}); });
// testing immutable object augmented assignment: #issue 3812
m.def("inplace_append", [](py::object &a, const py::object &b) {
a += b;
return a;
});
m.def("inplace_subtract", [](py::object &a, const py::object &b) {
a -= b;
return a;
});
m.def("inplace_multiply", [](py::object &a, const py::object &b) {
a *= b;
return a;
});
m.def("inplace_divide", [](py::object &a, const py::object &b) {
a /= b;
return a;
});
m.def("inplace_or", [](py::object &a, const py::object &b) {
a |= b;
return a;
});
m.def("inplace_and", [](py::object &a, const py::object &b) {
a &= b;
return a;
});
m.def("inplace_lshift", [](py::object &a, const py::object &b) {
a <<= b;
return a;
});
m.def("inplace_rshift", [](py::object &a, const py::object &b) {
a >>= b;
return a;
});
} }

View File

@ -739,3 +739,75 @@ def test_populate_obj_str_attrs():
new_attrs = {k: v for k, v in new_o.__dict__.items() if not k.startswith("_")} new_attrs = {k: v for k, v in new_o.__dict__.items() if not k.startswith("_")}
assert all(isinstance(v, str) for v in new_attrs.values()) assert all(isinstance(v, str) for v in new_attrs.values())
assert len(new_attrs) == pop assert len(new_attrs) == pop
@pytest.mark.parametrize(
"a,b", [("foo", "bar"), (1, 2), (1.0, 2.0), (list(range(3)), list(range(3, 6)))]
)
def test_inplace_append(a, b):
expected = a + b
assert m.inplace_append(a, b) == expected
@pytest.mark.parametrize("a,b", [(3, 2), (3.0, 2.0), (set(range(3)), set(range(2)))])
def test_inplace_subtract(a, b):
expected = a - b
assert m.inplace_subtract(a, b) == expected
@pytest.mark.parametrize("a,b", [(3, 2), (3.0, 2.0), ([1], 3)])
def test_inplace_multiply(a, b):
expected = a * b
assert m.inplace_multiply(a, b) == expected
@pytest.mark.parametrize("a,b", [(6, 3), (6.0, 3.0)])
def test_inplace_divide(a, b):
expected = a / b
assert m.inplace_divide(a, b) == expected
@pytest.mark.parametrize(
"a,b",
[
(False, True),
(
set(),
{
1,
},
),
],
)
def test_inplace_or(a, b):
expected = a | b
assert m.inplace_or(a, b) == expected
@pytest.mark.parametrize(
"a,b",
[
(True, False),
(
{1, 2, 3},
{
1,
},
),
],
)
def test_inplace_and(a, b):
expected = a & b
assert m.inplace_and(a, b) == expected
@pytest.mark.parametrize("a,b", [(8, 1), (-3, 2)])
def test_inplace_lshift(a, b):
expected = a << b
assert m.inplace_lshift(a, b) == expected
@pytest.mark.parametrize("a,b", [(8, 1), (-2, 2)])
def test_inplace_rshift(a, b):
expected = a >> b
assert m.inplace_rshift(a, b) == expected