Add `type_caster_std_function_specializations` feature. (#4597)

* Allow specializations based on callback function return values.

* clang-tidy auto fix

* Add a test case for function specialization.

* Add test for callback function that raises Python exception.

* Fix test failures.

* style: pre-commit fixes

* Add `#define PYBIND11_HAS_TYPE_CASTER_STD_FUNCTION_SPECIALIZATIONS`

---------

Co-authored-by: Ralf W. Grosse-Kunstleve <rwgk@google.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
Xiaofei Wang 2024-08-10 04:28:12 +08:00 committed by GitHub
parent 20551ab3d8
commit 898794488a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 107 additions and 34 deletions

View File

@ -9,12 +9,55 @@
#pragma once #pragma once
#define PYBIND11_HAS_TYPE_CASTER_STD_FUNCTION_SPECIALIZATIONS
#include "pybind11.h" #include "pybind11.h"
#include <functional> #include <functional>
PYBIND11_NAMESPACE_BEGIN(PYBIND11_NAMESPACE) PYBIND11_NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
PYBIND11_NAMESPACE_BEGIN(detail) PYBIND11_NAMESPACE_BEGIN(detail)
PYBIND11_NAMESPACE_BEGIN(type_caster_std_function_specializations)
// ensure GIL is held during functor destruction
struct func_handle {
function f;
#if !(defined(_MSC_VER) && _MSC_VER == 1916 && defined(PYBIND11_CPP17))
// This triggers a syntax error under very special conditions (very weird indeed).
explicit
#endif
func_handle(function &&f_) noexcept
: f(std::move(f_)) {
}
func_handle(const func_handle &f_) { operator=(f_); }
func_handle &operator=(const func_handle &f_) {
gil_scoped_acquire acq;
f = f_.f;
return *this;
}
~func_handle() {
gil_scoped_acquire acq;
function kill_f(std::move(f));
}
};
// to emulate 'move initialization capture' in C++11
struct func_wrapper_base {
func_handle hfunc;
explicit func_wrapper_base(func_handle &&hf) noexcept : hfunc(hf) {}
};
template <typename Return, typename... Args>
struct func_wrapper : func_wrapper_base {
using func_wrapper_base::func_wrapper_base;
Return operator()(Args... args) const {
gil_scoped_acquire acq;
// casts the returned object as a rvalue to the return type
return hfunc.f(std::forward<Args>(args)...).template cast<Return>();
}
};
PYBIND11_NAMESPACE_END(type_caster_std_function_specializations)
template <typename Return, typename... Args> template <typename Return, typename... Args>
struct type_caster<std::function<Return(Args...)>> { struct type_caster<std::function<Return(Args...)>> {
@ -77,40 +120,8 @@ public:
// See PR #1413 for full details // See PR #1413 for full details
} }
// ensure GIL is held during functor destruction value = type_caster_std_function_specializations::func_wrapper<Return, Args...>(
struct func_handle { type_caster_std_function_specializations::func_handle(std::move(func)));
function f;
#if !(defined(_MSC_VER) && _MSC_VER == 1916 && defined(PYBIND11_CPP17))
// This triggers a syntax error under very special conditions (very weird indeed).
explicit
#endif
func_handle(function &&f_) noexcept
: f(std::move(f_)) {
}
func_handle(const func_handle &f_) { operator=(f_); }
func_handle &operator=(const func_handle &f_) {
gil_scoped_acquire acq;
f = f_.f;
return *this;
}
~func_handle() {
gil_scoped_acquire acq;
function kill_f(std::move(f));
}
};
// to emulate 'move initialization capture' in C++11
struct func_wrapper {
func_handle hfunc;
explicit func_wrapper(func_handle &&hf) noexcept : hfunc(std::move(hf)) {}
Return operator()(Args... args) const {
gil_scoped_acquire acq;
// casts the returned object as a rvalue to the return type
return hfunc.f(std::forward<Args>(args)...).template cast<Return>();
}
};
value = func_wrapper(func_handle(std::move(func)));
return true; return true;
} }

View File

@ -158,6 +158,7 @@ set(PYBIND11_TEST_FILES
test_tagbased_polymorphic test_tagbased_polymorphic
test_thread test_thread
test_type_caster_pyobject_ptr test_type_caster_pyobject_ptr
test_type_caster_std_function_specializations
test_union test_union
test_unnamed_namespace_a test_unnamed_namespace_a
test_unnamed_namespace_b test_unnamed_namespace_b

View File

@ -0,0 +1,46 @@
#include <pybind11/functional.h>
#include <pybind11/pybind11.h>
#include "pybind11_tests.h"
namespace py = pybind11;
namespace {
struct SpecialReturn {
int value = 99;
};
} // namespace
namespace pybind11 {
namespace detail {
namespace type_caster_std_function_specializations {
template <typename... Args>
struct func_wrapper<SpecialReturn, Args...> : func_wrapper_base {
using func_wrapper_base::func_wrapper_base;
SpecialReturn operator()(Args... args) const {
gil_scoped_acquire acq;
SpecialReturn result;
try {
result = hfunc.f(std::forward<Args>(args)...).template cast<SpecialReturn>();
} catch (error_already_set &) {
result.value += 1;
}
result.value += 100;
return result;
}
};
} // namespace type_caster_std_function_specializations
} // namespace detail
} // namespace pybind11
TEST_SUBMODULE(type_caster_std_function_specializations, m) {
py::class_<SpecialReturn>(m, "SpecialReturn")
.def(py::init<>())
.def_readwrite("value", &SpecialReturn::value);
m.def("call_callback_with_special_return",
[](const std::function<SpecialReturn()> &func) { return func(); });
}

View File

@ -0,0 +1,15 @@
from __future__ import annotations
from pybind11_tests import type_caster_std_function_specializations as m
def test_callback_with_special_return():
def return_special():
return m.SpecialReturn()
def raise_exception():
raise ValueError("called raise_exception.")
assert return_special().value == 99
assert m.call_callback_with_special_return(return_special).value == 199
assert m.call_callback_with_special_return(raise_exception).value == 200