diff --git a/docs/advanced.rst b/docs/advanced.rst index 99af03976..c98692161 100644 --- a/docs/advanced.rst +++ b/docs/advanced.rst @@ -1100,10 +1100,12 @@ or ``py::array::f_style``. .. code-block:: cpp - void f(py::array_t array); + void f(py::array_t array); -As before, the implementation will attempt to convert non-conforming arguments -into an array satisfying the specified requirements. +The ``py::array::forcecast`` argument is the default value of the second +template paramenter, and it ensures that non-conforming arguments are converted +into an array satisfying the specified requirements instead of trying the next +function overload. Vectorizing functions ===================== diff --git a/example/example10.cpp b/example/example10.cpp index 09769fee8..cbe737e72 100644 --- a/example/example10.cpp +++ b/example/example10.cpp @@ -33,4 +33,9 @@ void init_ex10(py::module &m) { // Vectorize a complex-valued function m.def("vectorized_func3", py::vectorize(my_func3)); + + /// Numpy function which only accepts specific data types + m.def("selective_func", [](py::array_t) { std::cout << "Int branch taken. "<< std::endl; }); + m.def("selective_func", [](py::array_t) { std::cout << "Float branch taken. "<< std::endl; }); + m.def("selective_func", [](py::array_t, py::array::c_style>) { std::cout << "Complex float branch taken. "<< std::endl; }); } diff --git a/example/example10.py b/example/example10.py index 0d49fcaa7..b18e729a6 100755 --- a/example/example10.py +++ b/example/example10.py @@ -27,3 +27,8 @@ for f in [vectorized_func, vectorized_func2]: print(f(np.array([[1, 2, 3], [4, 5, 6]]), np.array([[2], [3]]), 2)) print(np.array([[1, 2, 3], [4, 5, 6]])* np.array([[2], [3]])* 2) +from example import selective_func + +selective_func(np.array([1], dtype=np.int32)) +selective_func(np.array([1.0], dtype=np.float32)) +selective_func(np.array([1.0j], dtype=np.complex64)) diff --git a/example/example10.ref b/example/example10.ref index 9d48d7cfd..4885fc1ca 100644 --- a/example/example10.ref +++ b/example/example10.ref @@ -73,3 +73,6 @@ my_func(x:int=6, y:float=3, z:float=2) [ 24. 30. 36.]] [[ 4 8 12] [24 30 36]] +Int branch taken. +Float branch taken. +Complex float branch taken. diff --git a/include/pybind11/numpy.h b/include/pybind11/numpy.h index 0b0e0eef6..f97c790ad 100644 --- a/include/pybind11/numpy.h +++ b/include/pybind11/numpy.h @@ -78,7 +78,8 @@ public: enum { c_style = API::NPY_C_CONTIGUOUS_, - f_style = API::NPY_F_CONTIGUOUS_ + f_style = API::NPY_F_CONTIGUOUS_, + forcecast = API::NPY_ARRAY_FORCECAST_ }; template array(size_t size, const Type *ptr) { @@ -124,7 +125,7 @@ protected: } }; -template class array_t : public array { +template class array_t : public array { public: PYBIND11_OBJECT_CVT(array_t, array, is_non_null, m_ptr = ensure(m_ptr)); array_t() : array() { } @@ -135,10 +136,9 @@ public: return nullptr; API &api = lookup_api(); PyObject *descr = api.PyArray_DescrFromType_(detail::npy_format_descriptor::value); - PyObject *result = api.PyArray_FromAny_( - ptr, descr, 0, 0, - API::NPY_ENSURE_ARRAY_ | API::NPY_ARRAY_FORCECAST_ | ExtraFlags, - nullptr); + PyObject *result = api.PyArray_FromAny_(ptr, descr, 0, 0, API::NPY_ENSURE_ARRAY_ | ExtraFlags, nullptr); + if (!result) + PyErr_Clear(); Py_DECREF(ptr); return result; } @@ -318,11 +318,11 @@ struct vectorize_helper { template vectorize_helper(T&&f) : f(std::forward(f)) { } - object operator()(array_t... args) { + object operator()(array_t... args) { return run(args..., typename make_index_sequence::type()); } - template object run(array_t&... args, index_sequence index) { + template object run(array_t&... args, index_sequence index) { /* Request buffers from all parameters */ const size_t N = sizeof...(Args); @@ -332,7 +332,7 @@ struct vectorize_helper { int ndim = 0; std::vector shape(0); bool trivial_broadcast = broadcast(buffers, ndim, shape); - + size_t size = 1; std::vector strides(ndim); if (ndim > 0) { @@ -384,7 +384,7 @@ struct vectorize_helper { } }; -template struct handle_type_name> { +template struct handle_type_name> { static PYBIND11_DESCR name() { return _("numpy.ndarray[dtype=") + type_caster::name() + _("]"); } };