diff --git a/include/pybind11/cast.h b/include/pybind11/cast.h index fd8f9a666..1ff698f3b 100644 --- a/include/pybind11/cast.h +++ b/include/pybind11/cast.h @@ -1375,6 +1375,11 @@ unpacking_collector collect_arguments(Args &&...args) { template template object object_api::operator()(Args &&...args) const { +#if !defined(NDEBUG) && PY_VERSION_HEX >= 0x03060000 + if (!PyGILState_Check()) { + pybind11_fail("pybind11::object_api<>::operator() PyGILState_Check() failure."); + } +#endif return detail::collect_arguments(std::forward(args)...).call(derived().ptr()); } diff --git a/tests/test_callbacks.cpp b/tests/test_callbacks.cpp index dffe538fc..61bc3a8f0 100644 --- a/tests/test_callbacks.cpp +++ b/tests/test_callbacks.cpp @@ -172,4 +172,10 @@ TEST_SUBMODULE(callbacks, m) { for (auto i : work) start_f(py::cast(i)); }); + + m.def("callback_num_times", [](py::function f, std::size_t num) { + for (std::size_t i = 0; i < num; i++) { + f(); + } + }); } diff --git a/tests/test_callbacks.py b/tests/test_callbacks.py index 039b877ce..cec68bda5 100644 --- a/tests/test_callbacks.py +++ b/tests/test_callbacks.py @@ -2,6 +2,7 @@ import pytest from pybind11_tests import callbacks as m from threading import Thread +import time def test_callbacks(): @@ -146,3 +147,34 @@ def test_async_async_callbacks(): t = Thread(target=test_async_callbacks) t.start() t.join() + + +def test_callback_num_times(): + # Super-simple micro-benchmarking related to PR #2919. + # Example runtimes (Intel Xeon 2.2GHz, fully optimized): + # num_millions 1, repeats 2: 0.1 secs + # num_millions 20, repeats 10: 11.5 secs + one_million = 1000000 + num_millions = 1 # Try 20 for actual micro-benchmarking. + repeats = 2 # Try 10. + rates = [] + for rep in range(repeats): + t0 = time.time() + m.callback_num_times(lambda: None, num_millions * one_million) + td = time.time() - t0 + rate = num_millions / td if td else 0 + rates.append(rate) + if not rep: + print() + print( + "callback_num_times: {:d} million / {:.3f} seconds = {:.3f} million / second".format( + num_millions, td, rate + ) + ) + if len(rates) > 1: + print("Min Mean Max") + print( + "{:6.3f} {:6.3f} {:6.3f}".format( + min(rates), sum(rates) / len(rates), max(rates) + ) + )