mirror of
https://github.com/pybind/pybind11.git
synced 2024-11-11 08:03:55 +00:00
Compare commits
5 Commits
dadbee06a5
...
fae8bd9f8d
Author | SHA1 | Date | |
---|---|---|---|
|
fae8bd9f8d | ||
|
1f8b4a7f1a | ||
|
e0be5dbd48 | ||
|
d21cee39e8 | ||
|
0235533fda |
@ -14,6 +14,7 @@
|
|||||||
#include "pybind11.h"
|
#include "pybind11.h"
|
||||||
|
|
||||||
#include <functional>
|
#include <functional>
|
||||||
|
#include <iostream>
|
||||||
|
|
||||||
PYBIND11_NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
|
PYBIND11_NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
|
||||||
PYBIND11_NAMESPACE_BEGIN(detail)
|
PYBIND11_NAMESPACE_BEGIN(detail)
|
||||||
@ -101,8 +102,17 @@ public:
|
|||||||
if (detail::is_function_record_capsule(c)) {
|
if (detail::is_function_record_capsule(c)) {
|
||||||
rec = c.get_pointer<function_record>();
|
rec = c.get_pointer<function_record>();
|
||||||
}
|
}
|
||||||
|
|
||||||
while (rec != nullptr) {
|
while (rec != nullptr) {
|
||||||
|
const size_t self_offset = rec->is_method ? 1 : 0;
|
||||||
|
if (rec->nargs != sizeof...(Args) + self_offset) {
|
||||||
|
rec = rec->next;
|
||||||
|
// if the overload is not feasible in terms of number of arguments, we
|
||||||
|
// continue to the next one. If there is no next one, we return false.
|
||||||
|
if (rec == nullptr) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
continue;
|
||||||
|
}
|
||||||
if (rec->is_stateless
|
if (rec->is_stateless
|
||||||
&& same_type(typeid(function_type),
|
&& same_type(typeid(function_type),
|
||||||
*reinterpret_cast<const std::type_info *>(rec->data[1]))) {
|
*reinterpret_cast<const std::type_info *>(rec->data[1]))) {
|
||||||
@ -118,6 +128,38 @@ public:
|
|||||||
// PYPY segfaults here when passing builtin function like sum.
|
// PYPY segfaults here when passing builtin function like sum.
|
||||||
// Raising an fail exception here works to prevent the segfault, but only on gcc.
|
// Raising an fail exception here works to prevent the segfault, but only on gcc.
|
||||||
// See PR #1413 for full details
|
// See PR #1413 for full details
|
||||||
|
} else {
|
||||||
|
// Check number of arguments of Python function
|
||||||
|
auto get_argument_count = [](const handle &obj) -> size_t {
|
||||||
|
// Faster then `import inspect` and `inspect.signature(obj).parameters`
|
||||||
|
return obj.attr("co_argcount").cast<size_t>();
|
||||||
|
};
|
||||||
|
size_t argCount = 0;
|
||||||
|
|
||||||
|
handle empty;
|
||||||
|
object codeAttr = getattr(src, "__code__", empty);
|
||||||
|
|
||||||
|
if (codeAttr) {
|
||||||
|
argCount = get_argument_count(codeAttr);
|
||||||
|
} else {
|
||||||
|
object callAttr = getattr(src, "__call__", empty);
|
||||||
|
|
||||||
|
if (callAttr) {
|
||||||
|
object codeAttr2 = getattr(callAttr, "__code__");
|
||||||
|
argCount = get_argument_count(codeAttr2) - 1; // removing the self argument
|
||||||
|
} else {
|
||||||
|
// No __code__ or __call__ attribute, this is not a proper Python function
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// if we are a method, we have to correct the argument count since we are not counting
|
||||||
|
// the self argument
|
||||||
|
const size_t self_offset = static_cast<bool>(PyMethod_Check(src.ptr())) ? 1 : 0;
|
||||||
|
|
||||||
|
argCount -= self_offset;
|
||||||
|
if (argCount != sizeof...(Args)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
value = type_caster_std_function_specializations::func_wrapper<Return, Args...>(
|
value = type_caster_std_function_specializations::func_wrapper<Return, Args...>(
|
||||||
|
@ -170,6 +170,12 @@ TEST_SUBMODULE(callbacks, m) {
|
|||||||
return "argument does NOT match dummy_function. This should never happen!";
|
return "argument does NOT match dummy_function. This should never happen!";
|
||||||
});
|
});
|
||||||
|
|
||||||
|
// test_cpp_correct_overload_resolution
|
||||||
|
m.def("dummy_function_overloaded_std_func_arg",
|
||||||
|
[](const std::function<int(int)> &f) { return 3 * f(3); });
|
||||||
|
m.def("dummy_function_overloaded_std_func_arg",
|
||||||
|
[](const std::function<int(int, int)> &f) { return 2 * f(3, 4); });
|
||||||
|
|
||||||
class AbstractBase {
|
class AbstractBase {
|
||||||
public:
|
public:
|
||||||
// [workaround(intel)] = default does not work here
|
// [workaround(intel)] = default does not work here
|
||||||
|
@ -103,6 +103,31 @@ def test_cpp_callable_cleanup():
|
|||||||
assert alive_counts == [0, 1, 2, 1, 2, 1, 0]
|
assert alive_counts == [0, 1, 2, 1, 2, 1, 0]
|
||||||
|
|
||||||
|
|
||||||
|
def test_cpp_correct_overload_resolution():
|
||||||
|
def f(a):
|
||||||
|
return a
|
||||||
|
|
||||||
|
class A:
|
||||||
|
def __call__(self, a):
|
||||||
|
return a
|
||||||
|
|
||||||
|
assert m.dummy_function_overloaded_std_func_arg(f) == 9
|
||||||
|
a = A()
|
||||||
|
assert m.dummy_function_overloaded_std_func_arg(a) == 9
|
||||||
|
assert m.dummy_function_overloaded_std_func_arg(lambda i: i) == 9
|
||||||
|
|
||||||
|
def f2(a, b):
|
||||||
|
return a + b
|
||||||
|
|
||||||
|
class B:
|
||||||
|
def __call__(self, a, b):
|
||||||
|
return a + b
|
||||||
|
|
||||||
|
assert m.dummy_function_overloaded_std_func_arg(f2) == 14
|
||||||
|
assert m.dummy_function_overloaded_std_func_arg(B()) == 14
|
||||||
|
assert m.dummy_function_overloaded_std_func_arg(lambda i, j: i + j) == 14
|
||||||
|
|
||||||
|
|
||||||
def test_cpp_function_roundtrip():
|
def test_cpp_function_roundtrip():
|
||||||
"""Test if passing a function pointer from C++ -> Python -> C++ yields the original pointer"""
|
"""Test if passing a function pointer from C++ -> Python -> C++ yields the original pointer"""
|
||||||
|
|
||||||
@ -131,7 +156,10 @@ def test_cpp_function_roundtrip():
|
|||||||
m.test_dummy_function(lambda x, y: x + y)
|
m.test_dummy_function(lambda x, y: x + y)
|
||||||
assert any(
|
assert any(
|
||||||
s in str(excinfo.value)
|
s in str(excinfo.value)
|
||||||
for s in ("missing 1 required positional argument", "takes exactly 2 arguments")
|
for s in (
|
||||||
|
"incompatible function arguments. The following argument types are",
|
||||||
|
"function test_cpp_function_roundtrip.<locals>.<lambda>",
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
#include <pybind11/embed.h>
|
#include <pybind11/embed.h>
|
||||||
|
#include <pybind11/functional.h>
|
||||||
|
|
||||||
// Silence MSVC C++17 deprecation warning from Catch regarding std::uncaught_exceptions (up to
|
// Silence MSVC C++17 deprecation warning from Catch regarding std::uncaught_exceptions (up to
|
||||||
// catch 2.0.1; this should be fixed in the next catch release after 2.0.1).
|
// catch 2.0.1; this should be fixed in the next catch release after 2.0.1).
|
||||||
@ -78,6 +79,12 @@ PYBIND11_EMBEDDED_MODULE(throw_error_already_set, ) {
|
|||||||
d["missing"].cast<py::object>();
|
d["missing"].cast<py::object>();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
PYBIND11_EMBEDDED_MODULE(func_module, m) {
|
||||||
|
m.def("funcOverload", [](const std::function<int(int, int)> &f) {
|
||||||
|
return f(2, 3);
|
||||||
|
}).def("funcOverload", [](const std::function<int(int)> &f) { return f(2); });
|
||||||
|
}
|
||||||
|
|
||||||
TEST_CASE("PYTHONPATH is used to update sys.path") {
|
TEST_CASE("PYTHONPATH is used to update sys.path") {
|
||||||
// The setup for this TEST_CASE is in catch.cpp!
|
// The setup for this TEST_CASE is in catch.cpp!
|
||||||
auto sys_path = py::str(py::module_::import("sys").attr("path")).cast<std::string>();
|
auto sys_path = py::str(py::module_::import("sys").attr("path")).cast<std::string>();
|
||||||
@ -171,6 +178,15 @@ TEST_CASE("There can be only one interpreter") {
|
|||||||
py::initialize_interpreter();
|
py::initialize_interpreter();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_CASE("Check the overload resolution from cpp_function objects to std::function") {
|
||||||
|
auto m = py::module_::import("func_module");
|
||||||
|
auto f = std::function<int(int)>([](int x) { return 2 * x; });
|
||||||
|
REQUIRE(m.attr("funcOverload")(f).template cast<int>() == 4);
|
||||||
|
|
||||||
|
auto f2 = std::function<int(int, int)>([](int x, int y) { return 2 * x * y; });
|
||||||
|
REQUIRE(m.attr("funcOverload")(f2).template cast<int>() == 12);
|
||||||
|
}
|
||||||
|
|
||||||
#if PY_VERSION_HEX >= PYBIND11_PYCONFIG_SUPPORT_PY_VERSION_HEX
|
#if PY_VERSION_HEX >= PYBIND11_PYCONFIG_SUPPORT_PY_VERSION_HEX
|
||||||
TEST_CASE("Custom PyConfig") {
|
TEST_CASE("Custom PyConfig") {
|
||||||
py::finalize_interpreter();
|
py::finalize_interpreter();
|
||||||
|
@ -274,10 +274,6 @@ function(pybind11_add_module target_name)
|
|||||||
target_link_libraries(${target_name} PRIVATE pybind11::embed)
|
target_link_libraries(${target_name} PRIVATE pybind11::embed)
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
if(MSVC)
|
|
||||||
target_link_libraries(${target_name} PRIVATE pybind11::windows_extras)
|
|
||||||
endif()
|
|
||||||
|
|
||||||
# -fvisibility=hidden is required to allow multiple modules compiled against
|
# -fvisibility=hidden is required to allow multiple modules compiled against
|
||||||
# different pybind versions to work properly, and for some features (e.g.
|
# different pybind versions to work properly, and for some features (e.g.
|
||||||
# py::module_local). We force it on everything inside the `pybind11`
|
# py::module_local). We force it on everything inside the `pybind11`
|
||||||
|
Loading…
Reference in New Issue
Block a user