Fix buffer protocol inheritance

Fixes #878.
This commit is contained in:
Dean Moldovan 2017-05-28 16:35:02 +02:00
parent 6d2411f1ac
commit 427e4afc69
3 changed files with 28 additions and 2 deletions

View File

@ -447,11 +447,17 @@ inline void enable_dynamic_attributes(PyHeapTypeObject *heap_type) {
/// buffer_protocol: Fill in the view as specified by flags.
extern "C" inline int pybind11_getbuffer(PyObject *obj, Py_buffer *view, int flags) {
auto tinfo = get_type_info(Py_TYPE(obj));
// Look for a `get_buffer` implementation in this type's info or any bases (following MRO).
type_info *tinfo = nullptr;
for (auto type : reinterpret_borrow<tuple>(Py_TYPE(obj)->tp_mro)) {
tinfo = get_type_info((PyTypeObject *) type.ptr());
if (tinfo && tinfo->get_buffer)
break;
}
if (view == nullptr || obj == nullptr || !tinfo || !tinfo->get_buffer) {
if (view)
view->obj = nullptr;
PyErr_SetString(PyExc_BufferError, "generic_type::getbuffer(): Internal error");
PyErr_SetString(PyExc_BufferError, "pybind11_getbuffer(): Internal error");
return -1;
}
memset(view, 0, sizeof(Py_buffer));

View File

@ -74,6 +74,11 @@ private:
float *m_data;
};
class SquareMatrix : public Matrix {
public:
SquareMatrix(ssize_t n) : Matrix(n, n) { }
};
struct PTMFBuffer {
int32_t value = 0;
@ -141,6 +146,10 @@ test_initializer buffers([](py::module &m) {
})
;
// Derived classes inherit the buffer protocol and the buffer access function
py::class_<SquareMatrix, Matrix>(m, "SquareMatrix")
.def(py::init<ssize_t>());
py::class_<PTMFBuffer>(m, "PTMFBuffer", py::buffer_protocol())
.def(py::init<>())
.def_readwrite("value", &PTMFBuffer::value)

View File

@ -36,6 +36,7 @@ def test_from_python():
@pytest.unsupported_on_pypy
def test_to_python():
m = Matrix(5, 5)
assert memoryview(m).shape == (5, 5)
assert m[2, 3] == 0
m[2, 3] = 4
@ -63,6 +64,16 @@ def test_to_python():
assert cstats.move_assignments == 0
@pytest.unsupported_on_pypy
def test_inherited_protocol():
"""SquareMatrix is derived from Matrix and inherits the buffer protocol"""
from pybind11_tests import SquareMatrix
matrix = SquareMatrix(5)
assert memoryview(matrix).shape == (5, 5)
assert np.asarray(matrix).shape == (5, 5)
@pytest.unsupported_on_pypy
def test_ptmf():
for cls in [PTMFBuffer, ConstPTMFBuffer, DerivedPTMFBuffer]: