mirror of
https://github.com/pybind/pybind11.git
synced 2024-11-25 14:45:12 +00:00
Add type_caster_std_function_specializations
feature. (#4597)
* Allow specializations based on callback function return values. * clang-tidy auto fix * Add a test case for function specialization. * Add test for callback function that raises Python exception. * Fix test failures. * style: pre-commit fixes * Add `#define PYBIND11_HAS_TYPE_CASTER_STD_FUNCTION_SPECIALIZATIONS` --------- Co-authored-by: Ralf W. Grosse-Kunstleve <rwgk@google.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
20551ab3d8
commit
898794488a
@ -9,12 +9,55 @@
|
|||||||
|
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
|
#define PYBIND11_HAS_TYPE_CASTER_STD_FUNCTION_SPECIALIZATIONS
|
||||||
|
|
||||||
#include "pybind11.h"
|
#include "pybind11.h"
|
||||||
|
|
||||||
#include <functional>
|
#include <functional>
|
||||||
|
|
||||||
PYBIND11_NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
|
PYBIND11_NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
|
||||||
PYBIND11_NAMESPACE_BEGIN(detail)
|
PYBIND11_NAMESPACE_BEGIN(detail)
|
||||||
|
PYBIND11_NAMESPACE_BEGIN(type_caster_std_function_specializations)
|
||||||
|
|
||||||
|
// ensure GIL is held during functor destruction
|
||||||
|
struct func_handle {
|
||||||
|
function f;
|
||||||
|
#if !(defined(_MSC_VER) && _MSC_VER == 1916 && defined(PYBIND11_CPP17))
|
||||||
|
// This triggers a syntax error under very special conditions (very weird indeed).
|
||||||
|
explicit
|
||||||
|
#endif
|
||||||
|
func_handle(function &&f_) noexcept
|
||||||
|
: f(std::move(f_)) {
|
||||||
|
}
|
||||||
|
func_handle(const func_handle &f_) { operator=(f_); }
|
||||||
|
func_handle &operator=(const func_handle &f_) {
|
||||||
|
gil_scoped_acquire acq;
|
||||||
|
f = f_.f;
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
~func_handle() {
|
||||||
|
gil_scoped_acquire acq;
|
||||||
|
function kill_f(std::move(f));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// to emulate 'move initialization capture' in C++11
|
||||||
|
struct func_wrapper_base {
|
||||||
|
func_handle hfunc;
|
||||||
|
explicit func_wrapper_base(func_handle &&hf) noexcept : hfunc(hf) {}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename Return, typename... Args>
|
||||||
|
struct func_wrapper : func_wrapper_base {
|
||||||
|
using func_wrapper_base::func_wrapper_base;
|
||||||
|
Return operator()(Args... args) const {
|
||||||
|
gil_scoped_acquire acq;
|
||||||
|
// casts the returned object as a rvalue to the return type
|
||||||
|
return hfunc.f(std::forward<Args>(args)...).template cast<Return>();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
PYBIND11_NAMESPACE_END(type_caster_std_function_specializations)
|
||||||
|
|
||||||
template <typename Return, typename... Args>
|
template <typename Return, typename... Args>
|
||||||
struct type_caster<std::function<Return(Args...)>> {
|
struct type_caster<std::function<Return(Args...)>> {
|
||||||
@ -77,40 +120,8 @@ public:
|
|||||||
// See PR #1413 for full details
|
// See PR #1413 for full details
|
||||||
}
|
}
|
||||||
|
|
||||||
// ensure GIL is held during functor destruction
|
value = type_caster_std_function_specializations::func_wrapper<Return, Args...>(
|
||||||
struct func_handle {
|
type_caster_std_function_specializations::func_handle(std::move(func)));
|
||||||
function f;
|
|
||||||
#if !(defined(_MSC_VER) && _MSC_VER == 1916 && defined(PYBIND11_CPP17))
|
|
||||||
// This triggers a syntax error under very special conditions (very weird indeed).
|
|
||||||
explicit
|
|
||||||
#endif
|
|
||||||
func_handle(function &&f_) noexcept
|
|
||||||
: f(std::move(f_)) {
|
|
||||||
}
|
|
||||||
func_handle(const func_handle &f_) { operator=(f_); }
|
|
||||||
func_handle &operator=(const func_handle &f_) {
|
|
||||||
gil_scoped_acquire acq;
|
|
||||||
f = f_.f;
|
|
||||||
return *this;
|
|
||||||
}
|
|
||||||
~func_handle() {
|
|
||||||
gil_scoped_acquire acq;
|
|
||||||
function kill_f(std::move(f));
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
// to emulate 'move initialization capture' in C++11
|
|
||||||
struct func_wrapper {
|
|
||||||
func_handle hfunc;
|
|
||||||
explicit func_wrapper(func_handle &&hf) noexcept : hfunc(std::move(hf)) {}
|
|
||||||
Return operator()(Args... args) const {
|
|
||||||
gil_scoped_acquire acq;
|
|
||||||
// casts the returned object as a rvalue to the return type
|
|
||||||
return hfunc.f(std::forward<Args>(args)...).template cast<Return>();
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
value = func_wrapper(func_handle(std::move(func)));
|
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -158,6 +158,7 @@ set(PYBIND11_TEST_FILES
|
|||||||
test_tagbased_polymorphic
|
test_tagbased_polymorphic
|
||||||
test_thread
|
test_thread
|
||||||
test_type_caster_pyobject_ptr
|
test_type_caster_pyobject_ptr
|
||||||
|
test_type_caster_std_function_specializations
|
||||||
test_union
|
test_union
|
||||||
test_unnamed_namespace_a
|
test_unnamed_namespace_a
|
||||||
test_unnamed_namespace_b
|
test_unnamed_namespace_b
|
||||||
|
46
tests/test_type_caster_std_function_specializations.cpp
Normal file
46
tests/test_type_caster_std_function_specializations.cpp
Normal file
@ -0,0 +1,46 @@
|
|||||||
|
#include <pybind11/functional.h>
|
||||||
|
#include <pybind11/pybind11.h>
|
||||||
|
|
||||||
|
#include "pybind11_tests.h"
|
||||||
|
|
||||||
|
namespace py = pybind11;
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
struct SpecialReturn {
|
||||||
|
int value = 99;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
namespace pybind11 {
|
||||||
|
namespace detail {
|
||||||
|
namespace type_caster_std_function_specializations {
|
||||||
|
|
||||||
|
template <typename... Args>
|
||||||
|
struct func_wrapper<SpecialReturn, Args...> : func_wrapper_base {
|
||||||
|
using func_wrapper_base::func_wrapper_base;
|
||||||
|
SpecialReturn operator()(Args... args) const {
|
||||||
|
gil_scoped_acquire acq;
|
||||||
|
SpecialReturn result;
|
||||||
|
try {
|
||||||
|
result = hfunc.f(std::forward<Args>(args)...).template cast<SpecialReturn>();
|
||||||
|
} catch (error_already_set &) {
|
||||||
|
result.value += 1;
|
||||||
|
}
|
||||||
|
result.value += 100;
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace type_caster_std_function_specializations
|
||||||
|
} // namespace detail
|
||||||
|
} // namespace pybind11
|
||||||
|
|
||||||
|
TEST_SUBMODULE(type_caster_std_function_specializations, m) {
|
||||||
|
py::class_<SpecialReturn>(m, "SpecialReturn")
|
||||||
|
.def(py::init<>())
|
||||||
|
.def_readwrite("value", &SpecialReturn::value);
|
||||||
|
m.def("call_callback_with_special_return",
|
||||||
|
[](const std::function<SpecialReturn()> &func) { return func(); });
|
||||||
|
}
|
15
tests/test_type_caster_std_function_specializations.py
Normal file
15
tests/test_type_caster_std_function_specializations.py
Normal file
@ -0,0 +1,15 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from pybind11_tests import type_caster_std_function_specializations as m
|
||||||
|
|
||||||
|
|
||||||
|
def test_callback_with_special_return():
|
||||||
|
def return_special():
|
||||||
|
return m.SpecialReturn()
|
||||||
|
|
||||||
|
def raise_exception():
|
||||||
|
raise ValueError("called raise_exception.")
|
||||||
|
|
||||||
|
assert return_special().value == 99
|
||||||
|
assert m.call_callback_with_special_return(return_special).value == 199
|
||||||
|
assert m.call_callback_with_special_return(raise_exception).value == 200
|
Loading…
Reference in New Issue
Block a user