mirror of
https://github.com/pybind/pybind11.git
synced 2024-11-22 05:05:11 +00:00
ability to prevent force casts in numpy arguments
This commit is contained in:
parent
93a317eca1
commit
b47a9de035
@ -1100,10 +1100,12 @@ or ``py::array::f_style``.
|
||||
|
||||
.. code-block:: cpp
|
||||
|
||||
void f(py::array_t<double, py::array::c_style> array);
|
||||
void f(py::array_t<double, py::array::c_style | py::array::forcecast> array);
|
||||
|
||||
As before, the implementation will attempt to convert non-conforming arguments
|
||||
into an array satisfying the specified requirements.
|
||||
The ``py::array::forcecast`` argument is the default value of the second
|
||||
template paramenter, and it ensures that non-conforming arguments are converted
|
||||
into an array satisfying the specified requirements instead of trying the next
|
||||
function overload.
|
||||
|
||||
Vectorizing functions
|
||||
=====================
|
||||
|
@ -33,4 +33,9 @@ void init_ex10(py::module &m) {
|
||||
|
||||
// Vectorize a complex-valued function
|
||||
m.def("vectorized_func3", py::vectorize(my_func3));
|
||||
|
||||
/// Numpy function which only accepts specific data types
|
||||
m.def("selective_func", [](py::array_t<int, py::array::c_style>) { std::cout << "Int branch taken. "<< std::endl; });
|
||||
m.def("selective_func", [](py::array_t<float, py::array::c_style>) { std::cout << "Float branch taken. "<< std::endl; });
|
||||
m.def("selective_func", [](py::array_t<std::complex<float>, py::array::c_style>) { std::cout << "Complex float branch taken. "<< std::endl; });
|
||||
}
|
||||
|
@ -27,3 +27,8 @@ for f in [vectorized_func, vectorized_func2]:
|
||||
print(f(np.array([[1, 2, 3], [4, 5, 6]]), np.array([[2], [3]]), 2))
|
||||
print(np.array([[1, 2, 3], [4, 5, 6]])* np.array([[2], [3]])* 2)
|
||||
|
||||
from example import selective_func
|
||||
|
||||
selective_func(np.array([1], dtype=np.int32))
|
||||
selective_func(np.array([1.0], dtype=np.float32))
|
||||
selective_func(np.array([1.0j], dtype=np.complex64))
|
||||
|
@ -73,3 +73,6 @@ my_func(x:int=6, y:float=3, z:float=2)
|
||||
[ 24. 30. 36.]]
|
||||
[[ 4 8 12]
|
||||
[24 30 36]]
|
||||
Int branch taken.
|
||||
Float branch taken.
|
||||
Complex float branch taken.
|
||||
|
@ -78,7 +78,8 @@ public:
|
||||
|
||||
enum {
|
||||
c_style = API::NPY_C_CONTIGUOUS_,
|
||||
f_style = API::NPY_F_CONTIGUOUS_
|
||||
f_style = API::NPY_F_CONTIGUOUS_,
|
||||
forcecast = API::NPY_ARRAY_FORCECAST_
|
||||
};
|
||||
|
||||
template <typename Type> array(size_t size, const Type *ptr) {
|
||||
@ -124,7 +125,7 @@ protected:
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T, int ExtraFlags = 0> class array_t : public array {
|
||||
template <typename T, int ExtraFlags = array::forcecast> class array_t : public array {
|
||||
public:
|
||||
PYBIND11_OBJECT_CVT(array_t, array, is_non_null, m_ptr = ensure(m_ptr));
|
||||
array_t() : array() { }
|
||||
@ -135,10 +136,9 @@ public:
|
||||
return nullptr;
|
||||
API &api = lookup_api();
|
||||
PyObject *descr = api.PyArray_DescrFromType_(detail::npy_format_descriptor<T>::value);
|
||||
PyObject *result = api.PyArray_FromAny_(
|
||||
ptr, descr, 0, 0,
|
||||
API::NPY_ENSURE_ARRAY_ | API::NPY_ARRAY_FORCECAST_ | ExtraFlags,
|
||||
nullptr);
|
||||
PyObject *result = api.PyArray_FromAny_(ptr, descr, 0, 0, API::NPY_ENSURE_ARRAY_ | ExtraFlags, nullptr);
|
||||
if (!result)
|
||||
PyErr_Clear();
|
||||
Py_DECREF(ptr);
|
||||
return result;
|
||||
}
|
||||
@ -318,11 +318,11 @@ struct vectorize_helper {
|
||||
template <typename T>
|
||||
vectorize_helper(T&&f) : f(std::forward<T>(f)) { }
|
||||
|
||||
object operator()(array_t<Args>... args) {
|
||||
object operator()(array_t<Args, array::c_style | array::forcecast>... args) {
|
||||
return run(args..., typename make_index_sequence<sizeof...(Args)>::type());
|
||||
}
|
||||
|
||||
template <size_t ... Index> object run(array_t<Args>&... args, index_sequence<Index...> index) {
|
||||
template <size_t ... Index> object run(array_t<Args, array::c_style | array::forcecast>&... args, index_sequence<Index...> index) {
|
||||
/* Request buffers from all parameters */
|
||||
const size_t N = sizeof...(Args);
|
||||
|
||||
@ -332,7 +332,7 @@ struct vectorize_helper {
|
||||
int ndim = 0;
|
||||
std::vector<size_t> shape(0);
|
||||
bool trivial_broadcast = broadcast(buffers, ndim, shape);
|
||||
|
||||
|
||||
size_t size = 1;
|
||||
std::vector<size_t> strides(ndim);
|
||||
if (ndim > 0) {
|
||||
@ -384,7 +384,7 @@ struct vectorize_helper {
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T> struct handle_type_name<array_t<T>> {
|
||||
template <typename T, int Flags> struct handle_type_name<array_t<T, Flags>> {
|
||||
static PYBIND11_DESCR name() { return _("numpy.ndarray[dtype=") + type_caster<T>::name() + _("]"); }
|
||||
};
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user