diff --git a/include/pybind11/pytypes.h b/include/pybind11/pytypes.h index 8c8a4620e..2de9bcd20 100644 --- a/include/pybind11/pytypes.h +++ b/include/pybind11/pytypes.h @@ -344,7 +344,7 @@ public: /// Check if the currently trapped error type matches the given Python exception class (or a /// subclass thereof). May also be passed a tuple to search for any exception class matches in /// the given tuple. - bool matches(handle ex) const { return PyErr_GivenExceptionMatches(ex.ptr(), m_type.ptr()); } + bool matches(handle exc) const { return PyErr_GivenExceptionMatches(m_type.ptr(), exc.ptr()); } const object& type() const { return m_type; } const object& value() const { return m_value; } diff --git a/tests/test_exceptions.cpp b/tests/test_exceptions.cpp index cf202143d..d30139037 100644 --- a/tests/test_exceptions.cpp +++ b/tests/test_exceptions.cpp @@ -118,10 +118,38 @@ TEST_SUBMODULE(exceptions, m) { m.def("throws_logic_error", []() { throw std::logic_error("this error should fall through to the standard handler"); }); m.def("exception_matches", []() { py::dict foo; - try { foo["bar"]; } + try { + // Assign to a py::object to force read access of nonexistent dict entry + py::object o = foo["bar"]; + } catch (py::error_already_set& ex) { if (!ex.matches(PyExc_KeyError)) throw; + return true; } + return false; + }); + m.def("exception_matches_base", []() { + py::dict foo; + try { + // Assign to a py::object to force read access of nonexistent dict entry + py::object o = foo["bar"]; + } + catch (py::error_already_set &ex) { + if (!ex.matches(PyExc_Exception)) throw; + return true; + } + return false; + }); + m.def("modulenotfound_exception_matches_base", []() { + try { + // On Python >= 3.6, this raises a ModuleNotFoundError, a subclass of ImportError + py::module::import("nonexistent"); + } + catch (py::error_already_set &ex) { + if (!ex.matches(PyExc_ImportError)) throw; + return true; + } + return false; }); m.def("throw_already_set", [](bool err) { diff --git a/tests/test_exceptions.py b/tests/test_exceptions.py index 8d37c09b8..6edff9fe4 100644 --- a/tests/test_exceptions.py +++ b/tests/test_exceptions.py @@ -48,7 +48,9 @@ def test_python_call_in_catch(): def test_exception_matches(): - m.exception_matches() + assert m.exception_matches() + assert m.exception_matches_base() + assert m.modulenotfound_exception_matches_base() def test_custom(msg):