mirror of
https://github.com/pybind/pybind11.git
synced 2024-11-22 13:15:12 +00:00
numpy: Add test for explicit dtype checks. At present, int64 + uint64 do not exactly match dtype(...).num
This commit is contained in:
parent
c6b699d9c2
commit
e9ca89f453
@ -14,6 +14,67 @@
|
|||||||
|
|
||||||
#include <cstdint>
|
#include <cstdint>
|
||||||
|
|
||||||
|
// Size / dtype checks.
|
||||||
|
struct DtypeCheck {
|
||||||
|
py::dtype numpy{};
|
||||||
|
py::dtype pybind11{};
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
DtypeCheck get_dtype_check(const char* name) {
|
||||||
|
py::module np = py::module::import("numpy");
|
||||||
|
DtypeCheck check{};
|
||||||
|
check.numpy = np.attr("dtype")(np.attr(name));
|
||||||
|
check.pybind11 = py::dtype::of<T>();
|
||||||
|
return check;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<DtypeCheck> get_concrete_dtype_checks() {
|
||||||
|
return {
|
||||||
|
// Normalization
|
||||||
|
get_dtype_check<std::int8_t>("int8"),
|
||||||
|
get_dtype_check<std::uint8_t>("uint8"),
|
||||||
|
get_dtype_check<std::int16_t>("int16"),
|
||||||
|
get_dtype_check<std::uint16_t>("uint16"),
|
||||||
|
get_dtype_check<std::int32_t>("int32"),
|
||||||
|
get_dtype_check<std::uint32_t>("uint32"),
|
||||||
|
get_dtype_check<std::int64_t>("int64"),
|
||||||
|
get_dtype_check<std::uint64_t>("uint64")
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
struct DtypeSizeCheck {
|
||||||
|
std::string name{};
|
||||||
|
int size_cpp{};
|
||||||
|
int size_numpy{};
|
||||||
|
// For debugging.
|
||||||
|
py::dtype dtype{};
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
DtypeSizeCheck get_dtype_size_check() {
|
||||||
|
DtypeSizeCheck check{};
|
||||||
|
check.name = py::type_id<T>();
|
||||||
|
check.size_cpp = sizeof(T);
|
||||||
|
check.dtype = py::dtype::of<T>();
|
||||||
|
check.size_numpy = check.dtype.attr("itemsize").template cast<int>();
|
||||||
|
return check;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<DtypeSizeCheck> get_platform_dtype_size_checks() {
|
||||||
|
return {
|
||||||
|
get_dtype_size_check<short>(),
|
||||||
|
get_dtype_size_check<unsigned short>(),
|
||||||
|
get_dtype_size_check<int>(),
|
||||||
|
get_dtype_size_check<unsigned int>(),
|
||||||
|
get_dtype_size_check<long>(),
|
||||||
|
get_dtype_size_check<unsigned long>(),
|
||||||
|
get_dtype_size_check<long long>(),
|
||||||
|
get_dtype_size_check<unsigned long long>(),
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
// Arrays.
|
||||||
using arr = py::array;
|
using arr = py::array;
|
||||||
using arr_t = py::array_t<uint16_t, 0>;
|
using arr_t = py::array_t<uint16_t, 0>;
|
||||||
static_assert(std::is_same<arr_t::value_type, uint16_t>::value, "");
|
static_assert(std::is_same<arr_t::value_type, uint16_t>::value, "");
|
||||||
@ -75,6 +136,26 @@ TEST_SUBMODULE(numpy_array, sm) {
|
|||||||
try { py::module::import("numpy"); }
|
try { py::module::import("numpy"); }
|
||||||
catch (...) { return; }
|
catch (...) { return; }
|
||||||
|
|
||||||
|
// test_dtypes
|
||||||
|
py::class_<DtypeCheck>(sm, "DtypeCheck")
|
||||||
|
.def_readonly("numpy", &DtypeCheck::numpy)
|
||||||
|
.def_readonly("pybind11", &DtypeCheck::pybind11)
|
||||||
|
.def("__repr__", [](const DtypeCheck& self) {
|
||||||
|
return py::str("<DtypeCheck numpy={} pybind11={}>").format(
|
||||||
|
self.numpy, self.pybind11);
|
||||||
|
});
|
||||||
|
sm.def("get_concrete_dtype_checks", &get_concrete_dtype_checks);
|
||||||
|
|
||||||
|
py::class_<DtypeSizeCheck>(sm, "DtypeSizeCheck")
|
||||||
|
.def_readonly("name", &DtypeSizeCheck::name)
|
||||||
|
.def_readonly("size_cpp", &DtypeSizeCheck::size_cpp)
|
||||||
|
.def_readonly("size_numpy", &DtypeSizeCheck::size_numpy)
|
||||||
|
.def("__repr__", [](const DtypeSizeCheck& self) {
|
||||||
|
return py::str("<DtypeSizeCheck name='{}' size_cpp={} size_numpy={} dtype={}>").format(
|
||||||
|
self.name, self.size_cpp, self.size_numpy, self.dtype);
|
||||||
|
});
|
||||||
|
sm.def("get_platform_dtype_size_checks", &get_platform_dtype_size_checks);
|
||||||
|
|
||||||
// test_array_attributes
|
// test_array_attributes
|
||||||
sm.def("ndim", [](const arr& a) { return a.ndim(); });
|
sm.def("ndim", [](const arr& a) { return a.ndim(); });
|
||||||
sm.def("shape", [](const arr& a) { return arr(a.ndim(), a.shape()); });
|
sm.def("shape", [](const arr& a) { return arr(a.ndim(), a.shape()); });
|
||||||
|
@ -7,6 +7,21 @@ with pytest.suppress(ImportError):
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
def test_dtypes():
|
||||||
|
# See issue #1328.
|
||||||
|
# - Platform-dependent sizes.
|
||||||
|
for size_check in m.get_platform_dtype_size_checks():
|
||||||
|
print(size_check)
|
||||||
|
assert size_check.size_cpp == size_check.size_numpy, size_check
|
||||||
|
# - Concrete sizes.
|
||||||
|
for check in m.get_concrete_dtype_checks():
|
||||||
|
print(check)
|
||||||
|
assert check.numpy == check.pybind11, check
|
||||||
|
if check.numpy.num != check.pybind11.num:
|
||||||
|
print("NOTE: typenum mismatch for {}: {} != {}".format(
|
||||||
|
check, check.numpy.num, check.pybind11.num))
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope='function')
|
@pytest.fixture(scope='function')
|
||||||
def arr():
|
def arr():
|
||||||
return np.array([[1, 2, 3], [4, 5, 6]], '=u2')
|
return np.array([[1, 2, 3], [4, 5, 6]], '=u2')
|
||||||
|
Loading…
Reference in New Issue
Block a user