diff --git a/include/pybind11/class_support.h b/include/pybind11/class_support.h index 55e623563..fb73390ae 100644 --- a/include/pybind11/class_support.h +++ b/include/pybind11/class_support.h @@ -447,11 +447,17 @@ inline void enable_dynamic_attributes(PyHeapTypeObject *heap_type) { /// buffer_protocol: Fill in the view as specified by flags. extern "C" inline int pybind11_getbuffer(PyObject *obj, Py_buffer *view, int flags) { - auto tinfo = get_type_info(Py_TYPE(obj)); + // Look for a `get_buffer` implementation in this type's info or any bases (following MRO). + type_info *tinfo = nullptr; + for (auto type : reinterpret_borrow(Py_TYPE(obj)->tp_mro)) { + tinfo = get_type_info((PyTypeObject *) type.ptr()); + if (tinfo && tinfo->get_buffer) + break; + } if (view == nullptr || obj == nullptr || !tinfo || !tinfo->get_buffer) { if (view) view->obj = nullptr; - PyErr_SetString(PyExc_BufferError, "generic_type::getbuffer(): Internal error"); + PyErr_SetString(PyExc_BufferError, "pybind11_getbuffer(): Internal error"); return -1; } memset(view, 0, sizeof(Py_buffer)); diff --git a/tests/test_buffers.cpp b/tests/test_buffers.cpp index 1a18f7949..cbd234ffa 100644 --- a/tests/test_buffers.cpp +++ b/tests/test_buffers.cpp @@ -74,6 +74,11 @@ private: float *m_data; }; +class SquareMatrix : public Matrix { +public: + SquareMatrix(ssize_t n) : Matrix(n, n) { } +}; + struct PTMFBuffer { int32_t value = 0; @@ -141,6 +146,10 @@ test_initializer buffers([](py::module &m) { }) ; + // Derived classes inherit the buffer protocol and the buffer access function + py::class_(m, "SquareMatrix") + .def(py::init()); + py::class_(m, "PTMFBuffer", py::buffer_protocol()) .def(py::init<>()) .def_readwrite("value", &PTMFBuffer::value) diff --git a/tests/test_buffers.py b/tests/test_buffers.py index 66a9909bc..a9374119e 100644 --- a/tests/test_buffers.py +++ b/tests/test_buffers.py @@ -36,6 +36,7 @@ def test_from_python(): @pytest.unsupported_on_pypy def test_to_python(): m = Matrix(5, 5) + assert memoryview(m).shape == (5, 5) assert m[2, 3] == 0 m[2, 3] = 4 @@ -63,6 +64,16 @@ def test_to_python(): assert cstats.move_assignments == 0 +@pytest.unsupported_on_pypy +def test_inherited_protocol(): + """SquareMatrix is derived from Matrix and inherits the buffer protocol""" + from pybind11_tests import SquareMatrix + + matrix = SquareMatrix(5) + assert memoryview(matrix).shape == (5, 5) + assert np.asarray(matrix).shape == (5, 5) + + @pytest.unsupported_on_pypy def test_ptmf(): for cls in [PTMFBuffer, ConstPTMFBuffer, DerivedPTMFBuffer]: