Allow subclasses of py::args and py::kwargs (#5381)

* Allow subclasses of py::args and py::kwargs

The current implementation does not allow subclasses of args or kwargs.
This change allows subclasses to be used.

* Added test case

* style: pre-commit fixes

* Added missing semi-colons

* style: pre-commit fixes

* Added handle_type_name

* Moved classes outside of function

* Added namespaces

* style: pre-commit fixes

* Refactored tests

Added more tests and moved tests to more appropriate locations.

* style: pre-commit fixes

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
gentlegiantJGC 2024-09-24 18:28:22 +01:00 committed by GitHub
parent 1f8b4a7f1a
commit 7e418f4924
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 39 additions and 2 deletions

View File

@ -1570,9 +1570,9 @@ class argument_loader {
using indices = make_index_sequence<sizeof...(Args)>; using indices = make_index_sequence<sizeof...(Args)>;
template <typename Arg> template <typename Arg>
using argument_is_args = std::is_same<intrinsic_t<Arg>, args>; using argument_is_args = std::is_base_of<args, intrinsic_t<Arg>>;
template <typename Arg> template <typename Arg>
using argument_is_kwargs = std::is_same<intrinsic_t<Arg>, kwargs>; using argument_is_kwargs = std::is_base_of<kwargs, intrinsic_t<Arg>>;
// Get kwargs argument position, or -1 if not present: // Get kwargs argument position, or -1 if not present:
static constexpr auto kwargs_pos = constexpr_last<argument_is_kwargs, Args...>(); static constexpr auto kwargs_pos = constexpr_last<argument_is_kwargs, Args...>();

View File

@ -14,6 +14,26 @@
#include <utility> #include <utility>
// Classes needed for subclass test.
class ArgsSubclass : public py::args {
using py::args::args;
};
class KWArgsSubclass : public py::kwargs {
using py::kwargs::kwargs;
};
namespace pybind11 {
namespace detail {
template <>
struct handle_type_name<ArgsSubclass> {
static constexpr auto name = const_name("*Args");
};
template <>
struct handle_type_name<KWArgsSubclass> {
static constexpr auto name = const_name("**KWArgs");
};
} // namespace detail
} // namespace pybind11
TEST_SUBMODULE(kwargs_and_defaults, m) { TEST_SUBMODULE(kwargs_and_defaults, m) {
auto kw_func auto kw_func
= [](int x, int y) { return "x=" + std::to_string(x) + ", y=" + std::to_string(y); }; = [](int x, int y) { return "x=" + std::to_string(x) + ", y=" + std::to_string(y); };
@ -322,4 +342,10 @@ TEST_SUBMODULE(kwargs_and_defaults, m) {
py::pos_only{}, py::pos_only{},
py::arg("i"), py::arg("i"),
py::arg("j")); py::arg("j"));
// Test support for args and kwargs subclasses
m.def("args_kwargs_subclass_function",
[](const ArgsSubclass &args, const KWArgsSubclass &kwargs) {
return py::make_tuple(args, kwargs);
});
} }

View File

@ -17,6 +17,10 @@ def test_function_signatures(doc):
assert ( assert (
doc(m.args_kwargs_function) == "args_kwargs_function(*args, **kwargs) -> tuple" doc(m.args_kwargs_function) == "args_kwargs_function(*args, **kwargs) -> tuple"
) )
assert (
doc(m.args_kwargs_subclass_function)
== "args_kwargs_subclass_function(*Args, **KWArgs) -> tuple"
)
assert ( assert (
doc(m.KWClass.foo0) doc(m.KWClass.foo0)
== "foo0(self: m.kwargs_and_defaults.KWClass, arg0: int, arg1: float) -> None" == "foo0(self: m.kwargs_and_defaults.KWClass, arg0: int, arg1: float) -> None"
@ -98,6 +102,7 @@ def test_arg_and_kwargs():
args = "a1", "a2" args = "a1", "a2"
kwargs = {"arg3": "a3", "arg4": 4} kwargs = {"arg3": "a3", "arg4": 4}
assert m.args_kwargs_function(*args, **kwargs) == (args, kwargs) assert m.args_kwargs_function(*args, **kwargs) == (args, kwargs)
assert m.args_kwargs_subclass_function(*args, **kwargs) == (args, kwargs)
def test_mixed_args_and_kwargs(msg): def test_mixed_args_and_kwargs(msg):
@ -413,6 +418,12 @@ def test_args_refcount():
) )
assert refcount(myval) == expected assert refcount(myval) == expected
assert m.args_kwargs_subclass_function(7, 8, myval, a=1, b=myval) == (
(7, 8, myval),
{"a": 1, "b": myval},
)
assert refcount(myval) == expected
exp3 = refcount(myval, myval, myval) exp3 = refcount(myval, myval, myval)
assert m.args_refcount(myval, myval, myval) == (exp3, exp3, exp3) assert m.args_refcount(myval, myval, myval) == (exp3, exp3, exp3)
assert refcount(myval) == expected assert refcount(myval) == expected