mirror of
https://github.com/pybind/pybind11.git
synced 2025-02-16 21:57:55 +00:00
Improve py::array_t scalar type information (#724)
* Add value_type member alias to py::array_t (resolve #632) * Use numpy scalar name in py::array_t function signatures (e.g. float32/64 instead of just float)
This commit is contained in:
parent
dc5ce5930f
commit
16afbcef46
@ -577,6 +577,8 @@ protected:
|
|||||||
|
|
||||||
template <typename T, int ExtraFlags = array::forcecast> class array_t : public array {
|
template <typename T, int ExtraFlags = array::forcecast> class array_t : public array {
|
||||||
public:
|
public:
|
||||||
|
using value_type = T;
|
||||||
|
|
||||||
array_t() : array(0, static_cast<const T *>(nullptr)) {}
|
array_t() : array(0, static_cast<const T *>(nullptr)) {}
|
||||||
array_t(handle h, borrowed_t) : array(h, borrowed) { }
|
array_t(handle h, borrowed_t) : array(h, borrowed) { }
|
||||||
array_t(handle h, stolen_t) : array(h, stolen) { }
|
array_t(handle h, stolen_t) : array(h, stolen) { }
|
||||||
@ -822,7 +824,7 @@ inline PYBIND11_NOINLINE void register_structured_dtype(
|
|||||||
template <typename T, typename SFINAE> struct npy_format_descriptor {
|
template <typename T, typename SFINAE> struct npy_format_descriptor {
|
||||||
static_assert(is_pod_struct<T>::value, "Attempt to use a non-POD or unimplemented POD type as a numpy dtype");
|
static_assert(is_pod_struct<T>::value, "Attempt to use a non-POD or unimplemented POD type as a numpy dtype");
|
||||||
|
|
||||||
static PYBIND11_DESCR name() { return _("struct"); }
|
static PYBIND11_DESCR name() { return make_caster<T>::name(); }
|
||||||
|
|
||||||
static pybind11::dtype dtype() {
|
static pybind11::dtype dtype() {
|
||||||
return reinterpret_borrow<pybind11::dtype>(dtype_ptr());
|
return reinterpret_borrow<pybind11::dtype>(dtype_ptr());
|
||||||
@ -1140,7 +1142,9 @@ struct vectorize_helper {
|
|||||||
};
|
};
|
||||||
|
|
||||||
template <typename T, int Flags> struct handle_type_name<array_t<T, Flags>> {
|
template <typename T, int Flags> struct handle_type_name<array_t<T, Flags>> {
|
||||||
static PYBIND11_DESCR name() { return _("numpy.ndarray[") + make_caster<T>::name() + _("]"); }
|
static PYBIND11_DESCR name() {
|
||||||
|
return _("numpy.ndarray[") + npy_format_descriptor<T>::name() + _("]");
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
NAMESPACE_END(detail)
|
NAMESPACE_END(detail)
|
||||||
|
@ -17,6 +17,7 @@
|
|||||||
|
|
||||||
using arr = py::array;
|
using arr = py::array;
|
||||||
using arr_t = py::array_t<uint16_t, 0>;
|
using arr_t = py::array_t<uint16_t, 0>;
|
||||||
|
static_assert(std::is_same<arr_t::value_type, uint16_t>::value, "");
|
||||||
|
|
||||||
template<typename... Ix> arr data(const arr& a, Ix... index) {
|
template<typename... Ix> arr data(const arr& a, Ix... index) {
|
||||||
return arr(a.nbytes() - a.offset_at(index...), (const uint8_t *) a.data(index...));
|
return arr(a.nbytes() - a.offset_at(index...), (const uint8_t *) a.data(index...));
|
||||||
|
@ -279,6 +279,21 @@ def test_overload_resolution(msg):
|
|||||||
# No exact match, should call first convertible version:
|
# No exact match, should call first convertible version:
|
||||||
assert overloaded(np.array([1], dtype='uint8')) == 'double'
|
assert overloaded(np.array([1], dtype='uint8')) == 'double'
|
||||||
|
|
||||||
|
with pytest.raises(TypeError) as excinfo:
|
||||||
|
overloaded("not an array")
|
||||||
|
assert msg(excinfo.value) == """
|
||||||
|
overloaded(): incompatible function arguments. The following argument types are supported:
|
||||||
|
1. (arg0: numpy.ndarray[float64]) -> str
|
||||||
|
2. (arg0: numpy.ndarray[float32]) -> str
|
||||||
|
3. (arg0: numpy.ndarray[int32]) -> str
|
||||||
|
4. (arg0: numpy.ndarray[uint16]) -> str
|
||||||
|
5. (arg0: numpy.ndarray[int64]) -> str
|
||||||
|
6. (arg0: numpy.ndarray[complex128]) -> str
|
||||||
|
7. (arg0: numpy.ndarray[complex64]) -> str
|
||||||
|
|
||||||
|
Invoked with: 'not an array'
|
||||||
|
"""
|
||||||
|
|
||||||
assert overloaded2(np.array([1], dtype='float64')) == 'double'
|
assert overloaded2(np.array([1], dtype='float64')) == 'double'
|
||||||
assert overloaded2(np.array([1], dtype='float32')) == 'float'
|
assert overloaded2(np.array([1], dtype='float32')) == 'float'
|
||||||
assert overloaded2(np.array([1], dtype='complex64')) == 'float complex'
|
assert overloaded2(np.array([1], dtype='complex64')) == 'float complex'
|
||||||
@ -289,8 +304,8 @@ def test_overload_resolution(msg):
|
|||||||
assert overloaded3(np.array([1], dtype='intc')) == 'int'
|
assert overloaded3(np.array([1], dtype='intc')) == 'int'
|
||||||
expected_exc = """
|
expected_exc = """
|
||||||
overloaded3(): incompatible function arguments. The following argument types are supported:
|
overloaded3(): incompatible function arguments. The following argument types are supported:
|
||||||
1. (arg0: numpy.ndarray[int]) -> str
|
1. (arg0: numpy.ndarray[int32]) -> str
|
||||||
2. (arg0: numpy.ndarray[float]) -> str
|
2. (arg0: numpy.ndarray[float64]) -> str
|
||||||
|
|
||||||
Invoked with:"""
|
Invoked with:"""
|
||||||
|
|
||||||
|
@ -71,5 +71,5 @@ def test_docs(doc):
|
|||||||
from pybind11_tests import vectorized_func
|
from pybind11_tests import vectorized_func
|
||||||
|
|
||||||
assert doc(vectorized_func) == """
|
assert doc(vectorized_func) == """
|
||||||
vectorized_func(arg0: numpy.ndarray[int], arg1: numpy.ndarray[float], arg2: numpy.ndarray[float]) -> object
|
vectorized_func(arg0: numpy.ndarray[int32], arg1: numpy.ndarray[float32], arg2: numpy.ndarray[float64]) -> object
|
||||||
""" # noqa: E501 line too long
|
""" # noqa: E501 line too long
|
||||||
|
Loading…
Reference in New Issue
Block a user