From 83a8a977a74203fedaedbf6969ebcc036e86d17c Mon Sep 17 00:00:00 2001 From: Roman Miroshnychenko Date: Sun, 2 Apr 2017 23:38:50 +0300 Subject: [PATCH] Add a method to check Python exception types (#772) This commit adds `error_already_set::matches()` convenience method to check if the exception trapped by `error_already_set` matches a given Python exception type. This will address #700 by providing a less verbose way to check exceptions. --- include/pybind11/common.h | 3 +++ tests/test_exceptions.cpp | 15 +++++++++++++++ tests/test_exceptions.py | 5 +++++ 3 files changed, 23 insertions(+) diff --git a/include/pybind11/common.h b/include/pybind11/common.h index e420f73e9..3d78ae1b4 100644 --- a/include/pybind11/common.h +++ b/include/pybind11/common.h @@ -597,6 +597,9 @@ public: /// Clear the held Python error state (the C++ `what()` message remains intact) void clear() { restore(); PyErr_Clear(); } + /// Check if the trapped exception matches a given Python exception class + bool matches(PyObject *ex) const { return PyErr_GivenExceptionMatches(ex, type); } + private: PyObject *type, *value, *trace; }; diff --git a/tests/test_exceptions.cpp b/tests/test_exceptions.cpp index 706b500f2..ea6bdb9f8 100644 --- a/tests/test_exceptions.cpp +++ b/tests/test_exceptions.cpp @@ -86,6 +86,20 @@ void throws_logic_error() { throw std::logic_error("this error should fall through to the standard handler"); } +// Test error_already_set::matches() method +void exception_matches() { + py::dict foo; + try { + foo["bar"]; + } + catch (py::error_already_set& ex) { + if (ex.matches(PyExc_KeyError)) + ex.clear(); + else + throw; + } +} + struct PythonCallInDestructor { PythonCallInDestructor(const py::dict &d) : d(d) {} ~PythonCallInDestructor() { d["good"] = true; } @@ -140,6 +154,7 @@ test_initializer custom_exceptions([](py::module &m) { m.def("throws5", &throws5); m.def("throws5_1", &throws5_1); m.def("throws_logic_error", &throws_logic_error); + m.def("exception_matches", &exception_matches); m.def("throw_already_set", [](bool err) { if (err) diff --git a/tests/test_exceptions.py b/tests/test_exceptions.py index 0025e4eb6..887ba644e 100644 --- a/tests/test_exceptions.py +++ b/tests/test_exceptions.py @@ -21,6 +21,11 @@ def test_python_call_in_catch(): assert d["good"] is True +def test_exception_matches(): + from pybind11_tests import exception_matches + exception_matches() + + def test_custom(msg): from pybind11_tests import (MyException, MyException5, MyException5_1, throws1, throws2, throws3, throws4, throws5, throws5_1,