diff --git a/include/pybind11/cast.h b/include/pybind11/cast.h index e3c8b9f65..ebbe6c364 100644 --- a/include/pybind11/cast.h +++ b/include/pybind11/cast.h @@ -1570,9 +1570,9 @@ class argument_loader { using indices = make_index_sequence; template - using argument_is_args = std::is_base_of>; + using argument_is_args = all_of>, negation>>; template - using argument_is_kwargs = std::is_base_of>; + using argument_is_kwargs = all_of>, negation>>; // Get kwargs argument position, or -1 if not present: static constexpr auto kwargs_pos = constexpr_last(); diff --git a/tests/test_kwargs_and_defaults.cpp b/tests/test_kwargs_and_defaults.cpp index 09036ccd5..a1a57ded6 100644 --- a/tests/test_kwargs_and_defaults.cpp +++ b/tests/test_kwargs_and_defaults.cpp @@ -21,6 +21,33 @@ class ArgsSubclass : public py::args { class KWArgsSubclass : public py::kwargs { using py::kwargs::kwargs; }; +class MoveOrCopyInt { +public: + MoveOrCopyInt() { print_default_created(this); } + explicit MoveOrCopyInt(int v) : value{v} { print_created(this, value); } + MoveOrCopyInt(MoveOrCopyInt &&m) noexcept { + print_move_created(this, m.value); + std::swap(value, m.value); + } + MoveOrCopyInt &operator=(MoveOrCopyInt &&m) noexcept { + print_move_assigned(this, m.value); + std::swap(value, m.value); + return *this; + } + MoveOrCopyInt(const MoveOrCopyInt &c) { + print_copy_created(this, c.value); + // NOLINTNEXTLINE(cppcoreguidelines-prefer-member-initializer) + value = c.value; + } + MoveOrCopyInt &operator=(const MoveOrCopyInt &c) { + print_copy_assigned(this, c.value); + value = c.value; + return *this; + } + ~MoveOrCopyInt() { print_destroyed(this); } + + int value; +}; namespace pybind11 { namespace detail { template <> @@ -31,6 +58,19 @@ template <> struct handle_type_name { static constexpr auto name = const_name("**KWArgs"); }; +template <> +struct type_caster { + PYBIND11_TYPE_CASTER(MoveOrCopyInt*, const_name("MoveOrCopyInt")); + bool load(handle src, bool) { + auto as_class = MoveOrCopyInt(src.cast()); + value = &as_class; + return true; + } + static handle cast(int v, return_value_policy r, handle p) { + auto as_class = MoveOrCopyInt(v); + return pybind11::handle(as_class, r, p); + } +}; } // namespace detail } // namespace pybind11 @@ -348,4 +388,10 @@ TEST_SUBMODULE(kwargs_and_defaults, m) { [](const ArgsSubclass &args, const KWArgsSubclass &kwargs) { return py::make_tuple(args, kwargs); }); + + // Test that support for args and kwargs subclasses skips checking arguments passed in as pointers + m.def("args_kwargs_subclass_function_with_pointer_arg", + [](MoveOrCopyInt* pointer, const ArgsSubclass &args, const KWArgsSubclass &kwargs) { + return py::make_tuple(pointer->value, args, kwargs); + }); } diff --git a/tests/test_kwargs_and_defaults.py b/tests/test_kwargs_and_defaults.py index e3f758165..8bc086c2c 100644 --- a/tests/test_kwargs_and_defaults.py +++ b/tests/test_kwargs_and_defaults.py @@ -21,6 +21,10 @@ def test_function_signatures(doc): doc(m.args_kwargs_subclass_function) == "args_kwargs_subclass_function(*Args, **KWArgs) -> tuple" ) + assert ( + doc(m.args_kwargs_subclass_function_with_pointer_arg) + == "args_kwargs_subclass_function_with_pointer_arg(arg0: NotArgsOrKWArgsClass, *Args, **KWArgs) -> tuple" + ) assert ( doc(m.KWClass.foo0) == "foo0(self: m.kwargs_and_defaults.KWClass, arg0: int, arg1: float) -> None" @@ -103,6 +107,7 @@ def test_arg_and_kwargs(): kwargs = {"arg3": "a3", "arg4": 4} assert m.args_kwargs_function(*args, **kwargs) == (args, kwargs) assert m.args_kwargs_subclass_function(*args, **kwargs) == (args, kwargs) + assert m.args_kwargs_subclass_function_with_pointer_arg(10, *args, **kwargs) == (10, args, kwargs) def test_mixed_args_and_kwargs(msg): @@ -424,6 +429,13 @@ def test_args_refcount(): ) assert refcount(myval) == expected + assert m.args_kwargs_subclass_function_with_pointer_arg(7, 8, myval, a=1, b=myval) == ( + 7, + (8, myval), + {"a": 1, "b": myval}, + ) + assert refcount(myval) == expected + exp3 = refcount(myval, myval, myval) assert m.args_refcount(myval, myval, myval) == (exp3, exp3, exp3) assert refcount(myval) == expected