mirror of
https://github.com/pybind/pybind11.git
synced 2024-11-23 05:35:13 +00:00
Add Add py::numpy_scalar test.
This commit is contained in:
parent
74b027e586
commit
00a85a6969
@ -120,6 +120,7 @@ set(PYBIND11_TEST_FILES
|
||||
test_multiple_inheritance.cpp
|
||||
test_numpy_array.cpp
|
||||
test_numpy_dtypes.cpp
|
||||
test_numpy_scalars.cpp
|
||||
test_numpy_vectorize.cpp
|
||||
test_opaque_types.cpp
|
||||
test_operator_overloading.cpp
|
||||
|
49
tests/test_numpy_scalars.cpp
Normal file
49
tests/test_numpy_scalars.cpp
Normal file
@ -0,0 +1,49 @@
|
||||
/*
|
||||
tests/test_numpy_scalars.cpp -- strict NumPy scalars
|
||||
|
||||
Copyright (c) 2021 Steve R. Sun
|
||||
|
||||
All rights reserved. Use of this source code is governed by a
|
||||
BSD-style license that can be found in the LICENSE file.
|
||||
*/
|
||||
|
||||
#include <complex>
|
||||
#include <cstdint>
|
||||
|
||||
#include "pybind11_tests.h"
|
||||
#include <pybind11/numpy.h>
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
template<typename T>
|
||||
struct add {
|
||||
T x;
|
||||
add(T x) : x(x) {}
|
||||
T operator()(T y) const { return static_cast<T>(x + y); }
|
||||
};
|
||||
|
||||
template<typename T, typename F>
|
||||
void register_test(py::module& m, const char *name, F&& func) {
|
||||
m.def((std::string("test_") + name).c_str(), [=](py::numpy_scalar<T> v) {
|
||||
return std::make_tuple(name, py::make_scalar(static_cast<T>(func(v.value))));
|
||||
}, py::arg("x"));
|
||||
}
|
||||
|
||||
TEST_SUBMODULE(numpy_scalars, m) {
|
||||
using cfloat = std::complex<float>;
|
||||
using cdouble = std::complex<double>;
|
||||
|
||||
register_test<bool>(m, "bool", [](bool x) { return !x; });
|
||||
register_test<int8_t>(m, "int8", add<int8_t>(-8));
|
||||
register_test<int16_t>(m, "int16", add<int16_t>(-16));
|
||||
register_test<int32_t>(m, "int32", add<int32_t>(-32));
|
||||
register_test<int64_t>(m, "int64", add<int64_t>(-64));
|
||||
register_test<uint8_t>(m, "uint8", add<uint8_t>(8));
|
||||
register_test<uint16_t>(m, "uint16", add<uint16_t>(16));
|
||||
register_test<uint32_t>(m, "uint32", add<uint32_t>(32));
|
||||
register_test<uint64_t>(m, "uint64", add<uint64_t>(64));
|
||||
register_test<float>(m, "float32", add<float>(0.125f));
|
||||
register_test<double>(m, "float64", add<double>(0.25f));
|
||||
register_test<cfloat>(m, "complex64", add<cfloat>({0, -0.125f}));
|
||||
register_test<cdouble>(m, "complex128", add<cdouble>({0, -0.25f}));
|
||||
}
|
61
tests/test_numpy_scalars.py
Normal file
61
tests/test_numpy_scalars.py
Normal file
@ -0,0 +1,61 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
import sys
|
||||
|
||||
import pytest
|
||||
|
||||
from pybind11_tests import numpy_scalars as m
|
||||
|
||||
np = pytest.importorskip("numpy")
|
||||
|
||||
SCALAR_TYPES = dict([
|
||||
(np.bool_, False),
|
||||
(np.int8, -7),
|
||||
(np.int16, -15),
|
||||
(np.int32, -31),
|
||||
(np.int64, -63),
|
||||
(np.uint8, 9),
|
||||
(np.uint16, 17),
|
||||
(np.uint32, 33),
|
||||
(np.uint64, 65),
|
||||
(np.single, 1.125),
|
||||
(np.double, 1.25),
|
||||
(np.complex64, 1 - 0.125j),
|
||||
(np.complex128, 1 - 0.25j),
|
||||
])
|
||||
ALL_TYPES = [int, bool, float, bytes, str] + list(SCALAR_TYPES)
|
||||
|
||||
|
||||
def type_name(tp):
|
||||
try:
|
||||
return tp.__name__.rstrip('_')
|
||||
except BaseException:
|
||||
# no numpy
|
||||
return str(tp)
|
||||
|
||||
|
||||
@pytest.fixture(scope='module', params=list(SCALAR_TYPES), ids=type_name)
|
||||
def scalar_type(request):
|
||||
return request.param
|
||||
|
||||
|
||||
def expected_signature(tp):
|
||||
s = 'str' if sys.version_info[0] >= 3 else 'unicode'
|
||||
t = type_name(tp)
|
||||
return 'test_{t}(x: {t}) -> Tuple[{s}, {t}]\n'.format(s=s, t=t)
|
||||
|
||||
|
||||
def test_numpy_scalars(scalar_type):
|
||||
expected = SCALAR_TYPES[scalar_type]
|
||||
name = type_name(scalar_type)
|
||||
func = getattr(m, 'test_' + name)
|
||||
assert func.__doc__ == expected_signature(scalar_type)
|
||||
for tp in ALL_TYPES:
|
||||
value = tp(1)
|
||||
if tp is scalar_type:
|
||||
result = func(value)
|
||||
assert result[0] == name
|
||||
assert isinstance(result[1], tp)
|
||||
assert result[1] == tp(expected)
|
||||
else:
|
||||
with pytest.raises(TypeError):
|
||||
func(value)
|
Loading…
Reference in New Issue
Block a user