From 8310aa46760a07d158035c5ab0566edd51a4819d Mon Sep 17 00:00:00 2001 From: Jason Rhinelander Date: Fri, 15 Dec 2017 13:22:15 -0400 Subject: [PATCH] Added py::args ref counting tests --- tests/test_kwargs_and_defaults.cpp | 29 ++++++++++++++++++++++ tests/test_kwargs_and_defaults.py | 40 ++++++++++++++++++++++++++++++ 2 files changed, 69 insertions(+) diff --git a/tests/test_kwargs_and_defaults.cpp b/tests/test_kwargs_and_defaults.cpp index 165f8017e..2263b6b7a 100644 --- a/tests/test_kwargs_and_defaults.cpp +++ b/tests/test_kwargs_and_defaults.cpp @@ -8,6 +8,7 @@ */ #include "pybind11_tests.h" +#include "constructor_stats.h" #include TEST_SUBMODULE(kwargs_and_defaults, m) { @@ -53,6 +54,34 @@ TEST_SUBMODULE(kwargs_and_defaults, m) { m.def("mixed_plus_args_kwargs_defaults", mixed_plus_both, py::arg("i") = 1, py::arg("j") = 3.14159); + // test_args_refcount + // PyPy needs a garbage collection to get the reference count values to match CPython's behaviour + #ifdef PYPY_VERSION + #define GC_IF_NEEDED ConstructorStats::gc() + #else + #define GC_IF_NEEDED + #endif + m.def("arg_refcount_h", [](py::handle h) { GC_IF_NEEDED; return h.ref_count(); }); + m.def("arg_refcount_h", [](py::handle h, py::handle, py::handle) { GC_IF_NEEDED; return h.ref_count(); }); + m.def("arg_refcount_o", [](py::object o) { GC_IF_NEEDED; return o.ref_count(); }); + m.def("args_refcount", [](py::args a) { + GC_IF_NEEDED; + py::tuple t(a.size()); + for (size_t i = 0; i < a.size(); i++) + // Use raw Python API here to avoid an extra, intermediate incref on the tuple item: + t[i] = (int) Py_REFCNT(PyTuple_GET_ITEM(a.ptr(), static_cast(i))); + return t; + }); + m.def("mixed_args_refcount", [](py::object o, py::args a) { + GC_IF_NEEDED; + py::tuple t(a.size() + 1); + t[0] = o.ref_count(); + for (size_t i = 0; i < a.size(); i++) + // Use raw Python API here to avoid an extra, intermediate incref on the tuple item: + t[i + 1] = (int) Py_REFCNT(PyTuple_GET_ITEM(a.ptr(), static_cast(i))); + return t; + }); + // pybind11 won't allow these to be bound: args and kwargs, if present, must be at the end. // Uncomment these to test that the static_assert is indeed working: // m.def("bad_args1", [](py::args, int) {}); diff --git a/tests/test_kwargs_and_defaults.py b/tests/test_kwargs_and_defaults.py index 733fe8593..269587656 100644 --- a/tests/test_kwargs_and_defaults.py +++ b/tests/test_kwargs_and_defaults.py @@ -105,3 +105,43 @@ def test_mixed_args_and_kwargs(msg): Invoked with: 1, 2; kwargs: j=1 """ # noqa: E501 line too long + + +def test_args_refcount(): + """Issue/PR #1216 - py::args elements get double-inc_ref()ed when combined with regular + arguments""" + refcount = m.arg_refcount_h + + myval = 54321 + expected = refcount(myval) + assert m.arg_refcount_h(myval) == expected + assert m.arg_refcount_o(myval) == expected + 1 + assert m.arg_refcount_h(myval) == expected + assert refcount(myval) == expected + + assert m.mixed_plus_args(1, 2.0, "a", myval) == (1, 2.0, ("a", myval)) + assert refcount(myval) == expected + + assert m.mixed_plus_kwargs(3, 4.0, a=1, b=myval) == (3, 4.0, {"a": 1, "b": myval}) + assert refcount(myval) == expected + + assert m.args_function(-1, myval) == (-1, myval) + assert refcount(myval) == expected + + assert m.mixed_plus_args_kwargs(5, 6.0, myval, a=myval) == (5, 6.0, (myval,), {"a": myval}) + assert refcount(myval) == expected + + assert m.args_kwargs_function(7, 8, myval, a=1, b=myval) == \ + ((7, 8, myval), {"a": 1, "b": myval}) + assert refcount(myval) == expected + + exp3 = refcount(myval, myval, myval) + assert m.args_refcount(myval, myval, myval) == (exp3, exp3, exp3) + assert refcount(myval) == expected + + # This function takes the first arg as a `py::object` and the rest as a `py::args`. Unlike the + # previous case, when we have both positional and `py::args` we need to construct a new tuple + # for the `py::args`; in the previous case, we could simply inc_ref and pass on Python's input + # tuple without having to inc_ref the individual elements, but here we can't, hence the extra + # refs. + assert m.mixed_args_refcount(myval, myval, myval) == (exp3 + 3, exp3 + 3, exp3 + 3)