Fix /= operator under Python 3

The Python method for /= was set as `__idiv__`, which should be
`__itruediv__` under Python 3.

This wasn't totally broken in that without it defined, Python constructs
a new object by calling __truediv__.  The operator tests, however,
didn't actually test the /= operator: when I added it, I saw an extra
construction, leading to the problem.  This commit also includes tests
for the previously untested *= operator, and adds some element-wise
vector multiplication and division operators.
This commit is contained in:
Jason Rhinelander 2017-05-20 20:19:26 -04:00
parent d2da33a34a
commit acad05cb13
3 changed files with 28 additions and 3 deletions

View File

@ -25,7 +25,7 @@ enum op_id : int {
op_int, op_long, op_float, op_str, op_cmp, op_gt, op_ge, op_lt, op_le, op_int, op_long, op_float, op_str, op_cmp, op_gt, op_ge, op_lt, op_le,
op_eq, op_ne, op_iadd, op_isub, op_imul, op_idiv, op_imod, op_ilshift, op_eq, op_ne, op_iadd, op_isub, op_imul, op_idiv, op_imod, op_ilshift,
op_irshift, op_iand, op_ixor, op_ior, op_complex, op_bool, op_nonzero, op_irshift, op_iand, op_ixor, op_ior, op_complex, op_bool, op_nonzero,
op_repr, op_truediv op_repr, op_truediv, op_itruediv
}; };
enum op_type : int { enum op_type : int {
@ -129,7 +129,11 @@ PYBIND11_BINARY_OPERATOR(le, ge, operator<=, l <= r)
PYBIND11_INPLACE_OPERATOR(iadd, operator+=, l += r) PYBIND11_INPLACE_OPERATOR(iadd, operator+=, l += r)
PYBIND11_INPLACE_OPERATOR(isub, operator-=, l -= r) PYBIND11_INPLACE_OPERATOR(isub, operator-=, l -= r)
PYBIND11_INPLACE_OPERATOR(imul, operator*=, l *= r) PYBIND11_INPLACE_OPERATOR(imul, operator*=, l *= r)
#if PY_MAJOR_VERSION >= 3
PYBIND11_INPLACE_OPERATOR(itruediv, operator/=, l /= r)
#else
PYBIND11_INPLACE_OPERATOR(idiv, operator/=, l /= r) PYBIND11_INPLACE_OPERATOR(idiv, operator/=, l /= r)
#endif
PYBIND11_INPLACE_OPERATOR(imod, operator%=, l %= r) PYBIND11_INPLACE_OPERATOR(imod, operator%=, l %= r)
PYBIND11_INPLACE_OPERATOR(ilshift, operator<<=, l <<= r) PYBIND11_INPLACE_OPERATOR(ilshift, operator<<=, l <<= r)
PYBIND11_INPLACE_OPERATOR(irshift, operator>>=, l >>= r) PYBIND11_INPLACE_OPERATOR(irshift, operator>>=, l >>= r)

View File

@ -39,10 +39,14 @@ public:
Vector2 operator+(float value) const { return Vector2(x + value, y + value); } Vector2 operator+(float value) const { return Vector2(x + value, y + value); }
Vector2 operator*(float value) const { return Vector2(x * value, y * value); } Vector2 operator*(float value) const { return Vector2(x * value, y * value); }
Vector2 operator/(float value) const { return Vector2(x / value, y / value); } Vector2 operator/(float value) const { return Vector2(x / value, y / value); }
Vector2 operator*(const Vector2 &v) const { return Vector2(x * v.x, y * v.y); }
Vector2 operator/(const Vector2 &v) const { return Vector2(x / v.x, y / v.y); }
Vector2& operator+=(const Vector2 &v) { x += v.x; y += v.y; return *this; } Vector2& operator+=(const Vector2 &v) { x += v.x; y += v.y; return *this; }
Vector2& operator-=(const Vector2 &v) { x -= v.x; y -= v.y; return *this; } Vector2& operator-=(const Vector2 &v) { x -= v.x; y -= v.y; return *this; }
Vector2& operator*=(float v) { x *= v; y *= v; return *this; } Vector2& operator*=(float v) { x *= v; y *= v; return *this; }
Vector2& operator/=(float v) { x /= v; y /= v; return *this; } Vector2& operator/=(float v) { x /= v; y /= v; return *this; }
Vector2& operator*=(const Vector2 &v) { x *= v.x; y *= v.y; return *this; }
Vector2& operator/=(const Vector2 &v) { x /= v.x; y /= v.y; return *this; }
friend Vector2 operator+(float f, const Vector2 &v) { return Vector2(f + v.x, f + v.y); } friend Vector2 operator+(float f, const Vector2 &v) { return Vector2(f + v.x, f + v.y); }
friend Vector2 operator-(float f, const Vector2 &v) { return Vector2(f - v.x, f - v.y); } friend Vector2 operator-(float f, const Vector2 &v) { return Vector2(f - v.x, f - v.y); }
@ -61,10 +65,14 @@ test_initializer operator_overloading([](py::module &m) {
.def(py::self - float()) .def(py::self - float())
.def(py::self * float()) .def(py::self * float())
.def(py::self / float()) .def(py::self / float())
.def(py::self * py::self)
.def(py::self / py::self)
.def(py::self += py::self) .def(py::self += py::self)
.def(py::self -= py::self) .def(py::self -= py::self)
.def(py::self *= float()) .def(py::self *= float())
.def(py::self /= float()) .def(py::self /= float())
.def(py::self *= py::self)
.def(py::self /= py::self)
.def(float() + py::self) .def(float() + py::self)
.def(float() - py::self) .def(float() - py::self)
.def(float() * py::self) .def(float() * py::self)

View File

@ -16,10 +16,21 @@ def test_operator_overloading():
assert str(8 + v1) == "[9.000000, 10.000000]" assert str(8 + v1) == "[9.000000, 10.000000]"
assert str(8 * v1) == "[8.000000, 16.000000]" assert str(8 * v1) == "[8.000000, 16.000000]"
assert str(8 / v1) == "[8.000000, 4.000000]" assert str(8 / v1) == "[8.000000, 4.000000]"
assert str(v1 * v2) == "[3.000000, -2.000000]"
assert str(v2 / v1) == "[3.000000, -0.500000]"
v1 += v2 v1 += 2 * v2
assert str(v1) == "[7.000000, 0.000000]"
v1 -= v2
assert str(v1) == "[4.000000, 1.000000]"
v1 *= 2 v1 *= 2
assert str(v1) == "[8.000000, 2.000000]" assert str(v1) == "[8.000000, 2.000000]"
v1 /= 16
assert str(v1) == "[0.500000, 0.125000]"
v1 *= v2
assert str(v1) == "[1.500000, -0.125000]"
v2 /= v1
assert str(v2) == "[2.000000, 8.000000]"
cstats = ConstructorStats.get(Vector2) cstats = ConstructorStats.get(Vector2)
assert cstats.alive() == 2 assert cstats.alive() == 2
@ -32,7 +43,9 @@ def test_operator_overloading():
'[-7.000000, -6.000000]', '[9.000000, 10.000000]', '[-7.000000, -6.000000]', '[9.000000, 10.000000]',
'[8.000000, 16.000000]', '[0.125000, 0.250000]', '[8.000000, 16.000000]', '[0.125000, 0.250000]',
'[7.000000, 6.000000]', '[9.000000, 10.000000]', '[7.000000, 6.000000]', '[9.000000, 10.000000]',
'[8.000000, 16.000000]', '[8.000000, 4.000000]'] '[8.000000, 16.000000]', '[8.000000, 4.000000]',
'[3.000000, -2.000000]', '[3.000000, -0.500000]',
'[6.000000, -2.000000]']
assert cstats.default_constructions == 0 assert cstats.default_constructions == 0
assert cstats.copy_constructions == 0 assert cstats.copy_constructions == 0
assert cstats.move_constructions >= 10 assert cstats.move_constructions >= 10