switch NumPy array to object API, avoid unnecessary copy operation in vectorize

This commit is contained in:
Wenzel Jakob 2016-01-17 22:36:39 +01:00
parent 87dfad6544
commit 87187afe91

View File

@ -81,14 +81,13 @@ public:
if (descr == nullptr) if (descr == nullptr)
throw std::runtime_error("NumPy: unsupported buffer format!"); throw std::runtime_error("NumPy: unsupported buffer format!");
Py_intptr_t shape = (Py_intptr_t) size; Py_intptr_t shape = (Py_intptr_t) size;
PyObject *tmp = api.PyArray_NewFromDescr_( object tmp = object(api.PyArray_NewFromDescr_(
api.PyArray_Type_, descr, 1, &shape, nullptr, (void *) ptr, 0, nullptr); api.PyArray_Type_, descr, 1, &shape, nullptr, (void *) ptr, 0, nullptr), false);
if (tmp == nullptr) if (ptr && tmp)
tmp = object(api.PyArray_NewCopy_(tmp.ptr(), -1 /* any order */), false);
if (!tmp)
throw std::runtime_error("NumPy: unable to create array!"); throw std::runtime_error("NumPy: unable to create array!");
m_ptr = api.PyArray_NewCopy_(tmp, -1 /* any order */); m_ptr = tmp.release();
Py_DECREF(tmp);
if (m_ptr == nullptr)
throw std::runtime_error("NumPy: unable to copy array!");
} }
array(const buffer_info &info) { array(const buffer_info &info) {
@ -102,15 +101,14 @@ public:
PyObject *descr = api.PyArray_DescrFromType_(fmt); PyObject *descr = api.PyArray_DescrFromType_(fmt);
if (descr == nullptr) if (descr == nullptr)
throw std::runtime_error("NumPy: unsupported buffer format '" + info.format + "'!"); throw std::runtime_error("NumPy: unsupported buffer format '" + info.format + "'!");
PyObject *tmp = api.PyArray_NewFromDescr_( object tmp(api.PyArray_NewFromDescr_(
api.PyArray_Type_, descr, info.ndim, (Py_intptr_t *) &info.shape[0], api.PyArray_Type_, descr, info.ndim, (Py_intptr_t *) &info.shape[0],
(Py_intptr_t *) &info.strides[0], info.ptr, 0, nullptr); (Py_intptr_t *) &info.strides[0], info.ptr, 0, nullptr), false);
if (tmp == nullptr) if (info.ptr && tmp)
tmp = object(api.PyArray_NewCopy_(tmp.ptr(), -1 /* any order */), false);
if (!tmp)
throw std::runtime_error("NumPy: unable to create array!"); throw std::runtime_error("NumPy: unable to create array!");
m_ptr = api.PyArray_NewCopy_(tmp, -1 /* any order */); m_ptr = tmp.release();
Py_DECREF(tmp);
if (m_ptr == nullptr)
throw std::runtime_error("NumPy: unable to copy array!");
} }
protected: protected:
@ -184,25 +182,27 @@ struct vectorize_helper {
} }
/* Check if the parameters are actually compatible */ /* Check if the parameters are actually compatible */
for (size_t i=0; i<N; ++i) { for (size_t i=0; i<N; ++i)
if (buffers[i].count != 1 && (buffers[i].ndim != ndim || buffers[i].shape != shape)) if (buffers[i].count != 1 && (buffers[i].ndim != ndim || buffers[i].shape != shape))
throw std::runtime_error("pybind11::vectorize: incompatible size/dimension of inputs!"); throw std::runtime_error("pybind11::vectorize: incompatible size/dimension of inputs!");
}
/* Call the function */
std::vector<Return> result(count);
for (size_t i=0; i<count; ++i)
result[i] = f((buffers[Index].count == 1
? *((Args *) buffers[Index].ptr)
: ((Args *) buffers[Index].ptr)[i])...);
if (count == 1) if (count == 1)
return cast(result[0]); return cast(f(*((Args *) buffers[Index].ptr)...));
/* Return the result */ array result(buffer_info(nullptr, sizeof(Return),
return array(buffer_info(result.data(), sizeof(Return),
format_descriptor<Return>::value(), format_descriptor<Return>::value(),
ndim, shape, strides)); ndim, shape, strides));
buffer_info buf = result.request();
Return *output = (Return *) buf.ptr;
/* Call the function */
for (size_t i=0; i<count; ++i)
output[i] = f((buffers[Index].count == 1
? *((Args *) buffers[Index].ptr)
: ((Args *) buffers[Index].ptr)[i])...);
return result;
} }
}; };