mirror of
https://github.com/pybind/pybind11.git
synced 2024-11-29 16:37:13 +00:00
switch NumPy array to object API, avoid unnecessary copy operation in vectorize
This commit is contained in:
parent
87dfad6544
commit
87187afe91
@ -81,14 +81,13 @@ public:
|
|||||||
if (descr == nullptr)
|
if (descr == nullptr)
|
||||||
throw std::runtime_error("NumPy: unsupported buffer format!");
|
throw std::runtime_error("NumPy: unsupported buffer format!");
|
||||||
Py_intptr_t shape = (Py_intptr_t) size;
|
Py_intptr_t shape = (Py_intptr_t) size;
|
||||||
PyObject *tmp = api.PyArray_NewFromDescr_(
|
object tmp = object(api.PyArray_NewFromDescr_(
|
||||||
api.PyArray_Type_, descr, 1, &shape, nullptr, (void *) ptr, 0, nullptr);
|
api.PyArray_Type_, descr, 1, &shape, nullptr, (void *) ptr, 0, nullptr), false);
|
||||||
if (tmp == nullptr)
|
if (ptr && tmp)
|
||||||
|
tmp = object(api.PyArray_NewCopy_(tmp.ptr(), -1 /* any order */), false);
|
||||||
|
if (!tmp)
|
||||||
throw std::runtime_error("NumPy: unable to create array!");
|
throw std::runtime_error("NumPy: unable to create array!");
|
||||||
m_ptr = api.PyArray_NewCopy_(tmp, -1 /* any order */);
|
m_ptr = tmp.release();
|
||||||
Py_DECREF(tmp);
|
|
||||||
if (m_ptr == nullptr)
|
|
||||||
throw std::runtime_error("NumPy: unable to copy array!");
|
|
||||||
}
|
}
|
||||||
|
|
||||||
array(const buffer_info &info) {
|
array(const buffer_info &info) {
|
||||||
@ -102,15 +101,14 @@ public:
|
|||||||
PyObject *descr = api.PyArray_DescrFromType_(fmt);
|
PyObject *descr = api.PyArray_DescrFromType_(fmt);
|
||||||
if (descr == nullptr)
|
if (descr == nullptr)
|
||||||
throw std::runtime_error("NumPy: unsupported buffer format '" + info.format + "'!");
|
throw std::runtime_error("NumPy: unsupported buffer format '" + info.format + "'!");
|
||||||
PyObject *tmp = api.PyArray_NewFromDescr_(
|
object tmp(api.PyArray_NewFromDescr_(
|
||||||
api.PyArray_Type_, descr, info.ndim, (Py_intptr_t *) &info.shape[0],
|
api.PyArray_Type_, descr, info.ndim, (Py_intptr_t *) &info.shape[0],
|
||||||
(Py_intptr_t *) &info.strides[0], info.ptr, 0, nullptr);
|
(Py_intptr_t *) &info.strides[0], info.ptr, 0, nullptr), false);
|
||||||
if (tmp == nullptr)
|
if (info.ptr && tmp)
|
||||||
|
tmp = object(api.PyArray_NewCopy_(tmp.ptr(), -1 /* any order */), false);
|
||||||
|
if (!tmp)
|
||||||
throw std::runtime_error("NumPy: unable to create array!");
|
throw std::runtime_error("NumPy: unable to create array!");
|
||||||
m_ptr = api.PyArray_NewCopy_(tmp, -1 /* any order */);
|
m_ptr = tmp.release();
|
||||||
Py_DECREF(tmp);
|
|
||||||
if (m_ptr == nullptr)
|
|
||||||
throw std::runtime_error("NumPy: unable to copy array!");
|
|
||||||
}
|
}
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
@ -184,25 +182,27 @@ struct vectorize_helper {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/* Check if the parameters are actually compatible */
|
/* Check if the parameters are actually compatible */
|
||||||
for (size_t i=0; i<N; ++i) {
|
for (size_t i=0; i<N; ++i)
|
||||||
if (buffers[i].count != 1 && (buffers[i].ndim != ndim || buffers[i].shape != shape))
|
if (buffers[i].count != 1 && (buffers[i].ndim != ndim || buffers[i].shape != shape))
|
||||||
throw std::runtime_error("pybind11::vectorize: incompatible size/dimension of inputs!");
|
throw std::runtime_error("pybind11::vectorize: incompatible size/dimension of inputs!");
|
||||||
}
|
|
||||||
|
|
||||||
/* Call the function */
|
|
||||||
std::vector<Return> result(count);
|
|
||||||
for (size_t i=0; i<count; ++i)
|
|
||||||
result[i] = f((buffers[Index].count == 1
|
|
||||||
? *((Args *) buffers[Index].ptr)
|
|
||||||
: ((Args *) buffers[Index].ptr)[i])...);
|
|
||||||
|
|
||||||
if (count == 1)
|
if (count == 1)
|
||||||
return cast(result[0]);
|
return cast(f(*((Args *) buffers[Index].ptr)...));
|
||||||
|
|
||||||
/* Return the result */
|
array result(buffer_info(nullptr, sizeof(Return),
|
||||||
return array(buffer_info(result.data(), sizeof(Return),
|
|
||||||
format_descriptor<Return>::value(),
|
format_descriptor<Return>::value(),
|
||||||
ndim, shape, strides));
|
ndim, shape, strides));
|
||||||
|
|
||||||
|
buffer_info buf = result.request();
|
||||||
|
Return *output = (Return *) buf.ptr;
|
||||||
|
|
||||||
|
/* Call the function */
|
||||||
|
for (size_t i=0; i<count; ++i)
|
||||||
|
output[i] = f((buffers[Index].count == 1
|
||||||
|
? *((Args *) buffers[Index].ptr)
|
||||||
|
: ((Args *) buffers[Index].ptr)[i])...);
|
||||||
|
|
||||||
|
return result;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user