ability to prevent force casts in numpy arguments

This commit is contained in:
Wenzel Jakob 2016-05-19 16:02:09 +02:00
parent 93a317eca1
commit b47a9de035
5 changed files with 28 additions and 13 deletions

View File

@ -1100,10 +1100,12 @@ or ``py::array::f_style``.
.. code-block:: cpp
void f(py::array_t<double, py::array::c_style> array);
void f(py::array_t<double, py::array::c_style | py::array::forcecast> array);
As before, the implementation will attempt to convert non-conforming arguments
into an array satisfying the specified requirements.
The ``py::array::forcecast`` argument is the default value of the second
template paramenter, and it ensures that non-conforming arguments are converted
into an array satisfying the specified requirements instead of trying the next
function overload.
Vectorizing functions
=====================

View File

@ -33,4 +33,9 @@ void init_ex10(py::module &m) {
// Vectorize a complex-valued function
m.def("vectorized_func3", py::vectorize(my_func3));
/// Numpy function which only accepts specific data types
m.def("selective_func", [](py::array_t<int, py::array::c_style>) { std::cout << "Int branch taken. "<< std::endl; });
m.def("selective_func", [](py::array_t<float, py::array::c_style>) { std::cout << "Float branch taken. "<< std::endl; });
m.def("selective_func", [](py::array_t<std::complex<float>, py::array::c_style>) { std::cout << "Complex float branch taken. "<< std::endl; });
}

View File

@ -27,3 +27,8 @@ for f in [vectorized_func, vectorized_func2]:
print(f(np.array([[1, 2, 3], [4, 5, 6]]), np.array([[2], [3]]), 2))
print(np.array([[1, 2, 3], [4, 5, 6]])* np.array([[2], [3]])* 2)
from example import selective_func
selective_func(np.array([1], dtype=np.int32))
selective_func(np.array([1.0], dtype=np.float32))
selective_func(np.array([1.0j], dtype=np.complex64))

View File

@ -73,3 +73,6 @@ my_func(x:int=6, y:float=3, z:float=2)
[ 24. 30. 36.]]
[[ 4 8 12]
[24 30 36]]
Int branch taken.
Float branch taken.
Complex float branch taken.

View File

@ -78,7 +78,8 @@ public:
enum {
c_style = API::NPY_C_CONTIGUOUS_,
f_style = API::NPY_F_CONTIGUOUS_
f_style = API::NPY_F_CONTIGUOUS_,
forcecast = API::NPY_ARRAY_FORCECAST_
};
template <typename Type> array(size_t size, const Type *ptr) {
@ -124,7 +125,7 @@ protected:
}
};
template <typename T, int ExtraFlags = 0> class array_t : public array {
template <typename T, int ExtraFlags = array::forcecast> class array_t : public array {
public:
PYBIND11_OBJECT_CVT(array_t, array, is_non_null, m_ptr = ensure(m_ptr));
array_t() : array() { }
@ -135,10 +136,9 @@ public:
return nullptr;
API &api = lookup_api();
PyObject *descr = api.PyArray_DescrFromType_(detail::npy_format_descriptor<T>::value);
PyObject *result = api.PyArray_FromAny_(
ptr, descr, 0, 0,
API::NPY_ENSURE_ARRAY_ | API::NPY_ARRAY_FORCECAST_ | ExtraFlags,
nullptr);
PyObject *result = api.PyArray_FromAny_(ptr, descr, 0, 0, API::NPY_ENSURE_ARRAY_ | ExtraFlags, nullptr);
if (!result)
PyErr_Clear();
Py_DECREF(ptr);
return result;
}
@ -318,11 +318,11 @@ struct vectorize_helper {
template <typename T>
vectorize_helper(T&&f) : f(std::forward<T>(f)) { }
object operator()(array_t<Args>... args) {
object operator()(array_t<Args, array::c_style | array::forcecast>... args) {
return run(args..., typename make_index_sequence<sizeof...(Args)>::type());
}
template <size_t ... Index> object run(array_t<Args>&... args, index_sequence<Index...> index) {
template <size_t ... Index> object run(array_t<Args, array::c_style | array::forcecast>&... args, index_sequence<Index...> index) {
/* Request buffers from all parameters */
const size_t N = sizeof...(Args);
@ -332,7 +332,7 @@ struct vectorize_helper {
int ndim = 0;
std::vector<size_t> shape(0);
bool trivial_broadcast = broadcast(buffers, ndim, shape);
size_t size = 1;
std::vector<size_t> strides(ndim);
if (ndim > 0) {
@ -384,7 +384,7 @@ struct vectorize_helper {
}
};
template <typename T> struct handle_type_name<array_t<T>> {
template <typename T, int Flags> struct handle_type_name<array_t<T, Flags>> {
static PYBIND11_DESCR name() { return _("numpy.ndarray[dtype=") + type_caster<T>::name() + _("]"); }
};