mirror of
https://github.com/pybind/pybind11.git
synced 2024-11-22 05:05:11 +00:00
switch NumPy array to object API, avoid unnecessary copy operation in vectorize
This commit is contained in:
parent
87dfad6544
commit
87187afe91
@ -81,14 +81,13 @@ public:
|
||||
if (descr == nullptr)
|
||||
throw std::runtime_error("NumPy: unsupported buffer format!");
|
||||
Py_intptr_t shape = (Py_intptr_t) size;
|
||||
PyObject *tmp = api.PyArray_NewFromDescr_(
|
||||
api.PyArray_Type_, descr, 1, &shape, nullptr, (void *) ptr, 0, nullptr);
|
||||
if (tmp == nullptr)
|
||||
object tmp = object(api.PyArray_NewFromDescr_(
|
||||
api.PyArray_Type_, descr, 1, &shape, nullptr, (void *) ptr, 0, nullptr), false);
|
||||
if (ptr && tmp)
|
||||
tmp = object(api.PyArray_NewCopy_(tmp.ptr(), -1 /* any order */), false);
|
||||
if (!tmp)
|
||||
throw std::runtime_error("NumPy: unable to create array!");
|
||||
m_ptr = api.PyArray_NewCopy_(tmp, -1 /* any order */);
|
||||
Py_DECREF(tmp);
|
||||
if (m_ptr == nullptr)
|
||||
throw std::runtime_error("NumPy: unable to copy array!");
|
||||
m_ptr = tmp.release();
|
||||
}
|
||||
|
||||
array(const buffer_info &info) {
|
||||
@ -102,15 +101,14 @@ public:
|
||||
PyObject *descr = api.PyArray_DescrFromType_(fmt);
|
||||
if (descr == nullptr)
|
||||
throw std::runtime_error("NumPy: unsupported buffer format '" + info.format + "'!");
|
||||
PyObject *tmp = api.PyArray_NewFromDescr_(
|
||||
object tmp(api.PyArray_NewFromDescr_(
|
||||
api.PyArray_Type_, descr, info.ndim, (Py_intptr_t *) &info.shape[0],
|
||||
(Py_intptr_t *) &info.strides[0], info.ptr, 0, nullptr);
|
||||
if (tmp == nullptr)
|
||||
(Py_intptr_t *) &info.strides[0], info.ptr, 0, nullptr), false);
|
||||
if (info.ptr && tmp)
|
||||
tmp = object(api.PyArray_NewCopy_(tmp.ptr(), -1 /* any order */), false);
|
||||
if (!tmp)
|
||||
throw std::runtime_error("NumPy: unable to create array!");
|
||||
m_ptr = api.PyArray_NewCopy_(tmp, -1 /* any order */);
|
||||
Py_DECREF(tmp);
|
||||
if (m_ptr == nullptr)
|
||||
throw std::runtime_error("NumPy: unable to copy array!");
|
||||
m_ptr = tmp.release();
|
||||
}
|
||||
|
||||
protected:
|
||||
@ -184,25 +182,27 @@ struct vectorize_helper {
|
||||
}
|
||||
|
||||
/* Check if the parameters are actually compatible */
|
||||
for (size_t i=0; i<N; ++i) {
|
||||
for (size_t i=0; i<N; ++i)
|
||||
if (buffers[i].count != 1 && (buffers[i].ndim != ndim || buffers[i].shape != shape))
|
||||
throw std::runtime_error("pybind11::vectorize: incompatible size/dimension of inputs!");
|
||||
}
|
||||
|
||||
/* Call the function */
|
||||
std::vector<Return> result(count);
|
||||
for (size_t i=0; i<count; ++i)
|
||||
result[i] = f((buffers[Index].count == 1
|
||||
? *((Args *) buffers[Index].ptr)
|
||||
: ((Args *) buffers[Index].ptr)[i])...);
|
||||
|
||||
if (count == 1)
|
||||
return cast(result[0]);
|
||||
return cast(f(*((Args *) buffers[Index].ptr)...));
|
||||
|
||||
/* Return the result */
|
||||
return array(buffer_info(result.data(), sizeof(Return),
|
||||
array result(buffer_info(nullptr, sizeof(Return),
|
||||
format_descriptor<Return>::value(),
|
||||
ndim, shape, strides));
|
||||
|
||||
buffer_info buf = result.request();
|
||||
Return *output = (Return *) buf.ptr;
|
||||
|
||||
/* Call the function */
|
||||
for (size_t i=0; i<count; ++i)
|
||||
output[i] = f((buffers[Index].count == 1
|
||||
? *((Args *) buffers[Index].ptr)
|
||||
: ((Args *) buffers[Index].ptr)[i])...);
|
||||
|
||||
return result;
|
||||
}
|
||||
};
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user