diff --git a/include/pybind11/detail/common.h b/include/pybind11/detail/common.h index 1d4939bed..ea09bb3fd 100644 --- a/include/pybind11/detail/common.h +++ b/include/pybind11/detail/common.h @@ -936,9 +936,11 @@ PYBIND11_RUNTIME_EXCEPTION(cast_error, PyExc_RuntimeError) /// Thrown when pybin PYBIND11_RUNTIME_EXCEPTION(reference_cast_error, PyExc_RuntimeError) /// Used internally [[noreturn]] PYBIND11_NOINLINE void pybind11_fail(const char *reason) { + assert(!PyErr_Occurred()); throw std::runtime_error(reason); } [[noreturn]] PYBIND11_NOINLINE void pybind11_fail(const std::string &reason) { + assert(!PyErr_Occurred()); throw std::runtime_error(reason); } diff --git a/include/pybind11/pytypes.h b/include/pybind11/pytypes.h index 29f70bd5f..dc753d32c 100644 --- a/include/pybind11/pytypes.h +++ b/include/pybind11/pytypes.h @@ -1509,6 +1509,9 @@ public: explicit weakref(handle obj, handle callback = {}) : object(PyWeakref_NewRef(obj.ptr(), callback.ptr()), stolen_t{}) { if (!m_ptr) { + if (PyErr_Occurred()) { + throw error_already_set(); + } pybind11_fail("Could not allocate weak reference!"); } } diff --git a/tests/test_pytypes.py b/tests/test_pytypes.py index 07ace0049..becd1cc8a 100644 --- a/tests/test_pytypes.py +++ b/tests/test_pytypes.py @@ -1,8 +1,9 @@ +import contextlib import sys import pytest -import env # noqa: F401 +import env from pybind11_tests import debug_enabled from pybind11_tests import pytypes as m @@ -583,6 +584,31 @@ def test_weakref(create_weakref, create_weakref_with_callback): assert callback_called +@pytest.mark.parametrize( + "create_weakref, has_callback", + [ + (m.weakref_from_handle, False), + (m.weakref_from_object, False), + (m.weakref_from_handle_and_function, True), + (m.weakref_from_object_and_function, True), + ], +) +def test_weakref_err(create_weakref, has_callback): + class C: + __slots__ = [] + + def callback(_): + pass + + ob = C() + # Should raise TypeError on CPython + with pytest.raises(TypeError) if not env.PYPY else contextlib.nullcontext(): + if has_callback: + _ = create_weakref(ob, callback) + else: + _ = create_weakref(ob) + + def test_cpp_iterators(): assert m.tuple_iterator() == 12 assert m.dict_iterator() == 305 + 711