From 00a85a6969a8ff028f7e12fdb623bef7fff8ec80 Mon Sep 17 00:00:00 2001 From: sun1638650145 <1638650145@qq.com> Date: Fri, 10 Dec 2021 15:58:52 +0800 Subject: [PATCH] Add Add py::numpy_scalar test. --- tests/CMakeLists.txt | 1 + tests/test_numpy_scalars.cpp | 49 +++++++++++++++++++++++++++++ tests/test_numpy_scalars.py | 61 ++++++++++++++++++++++++++++++++++++ 3 files changed, 111 insertions(+) create mode 100644 tests/test_numpy_scalars.cpp create mode 100644 tests/test_numpy_scalars.py diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 8b19b1967..8b6d0a9ac 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -120,6 +120,7 @@ set(PYBIND11_TEST_FILES test_multiple_inheritance.cpp test_numpy_array.cpp test_numpy_dtypes.cpp + test_numpy_scalars.cpp test_numpy_vectorize.cpp test_opaque_types.cpp test_operator_overloading.cpp diff --git a/tests/test_numpy_scalars.cpp b/tests/test_numpy_scalars.cpp new file mode 100644 index 000000000..c8f8fcafe --- /dev/null +++ b/tests/test_numpy_scalars.cpp @@ -0,0 +1,49 @@ +/* + tests/test_numpy_scalars.cpp -- strict NumPy scalars + + Copyright (c) 2021 Steve R. Sun + + All rights reserved. Use of this source code is governed by a + BSD-style license that can be found in the LICENSE file. +*/ + +#include +#include + +#include "pybind11_tests.h" +#include + +namespace py = pybind11; + +template +struct add { + T x; + add(T x) : x(x) {} + T operator()(T y) const { return static_cast(x + y); } +}; + +template +void register_test(py::module& m, const char *name, F&& func) { + m.def((std::string("test_") + name).c_str(), [=](py::numpy_scalar v) { + return std::make_tuple(name, py::make_scalar(static_cast(func(v.value)))); + }, py::arg("x")); +} + +TEST_SUBMODULE(numpy_scalars, m) { + using cfloat = std::complex; + using cdouble = std::complex; + + register_test(m, "bool", [](bool x) { return !x; }); + register_test(m, "int8", add(-8)); + register_test(m, "int16", add(-16)); + register_test(m, "int32", add(-32)); + register_test(m, "int64", add(-64)); + register_test(m, "uint8", add(8)); + register_test(m, "uint16", add(16)); + register_test(m, "uint32", add(32)); + register_test(m, "uint64", add(64)); + register_test(m, "float32", add(0.125f)); + register_test(m, "float64", add(0.25f)); + register_test(m, "complex64", add({0, -0.125f})); + register_test(m, "complex128", add({0, -0.25f})); +} diff --git a/tests/test_numpy_scalars.py b/tests/test_numpy_scalars.py new file mode 100644 index 000000000..a63093dad --- /dev/null +++ b/tests/test_numpy_scalars.py @@ -0,0 +1,61 @@ +# -*- coding: utf-8 -*- +import sys + +import pytest + +from pybind11_tests import numpy_scalars as m + +np = pytest.importorskip("numpy") + +SCALAR_TYPES = dict([ + (np.bool_, False), + (np.int8, -7), + (np.int16, -15), + (np.int32, -31), + (np.int64, -63), + (np.uint8, 9), + (np.uint16, 17), + (np.uint32, 33), + (np.uint64, 65), + (np.single, 1.125), + (np.double, 1.25), + (np.complex64, 1 - 0.125j), + (np.complex128, 1 - 0.25j), +]) +ALL_TYPES = [int, bool, float, bytes, str] + list(SCALAR_TYPES) + + +def type_name(tp): + try: + return tp.__name__.rstrip('_') + except BaseException: + # no numpy + return str(tp) + + +@pytest.fixture(scope='module', params=list(SCALAR_TYPES), ids=type_name) +def scalar_type(request): + return request.param + + +def expected_signature(tp): + s = 'str' if sys.version_info[0] >= 3 else 'unicode' + t = type_name(tp) + return 'test_{t}(x: {t}) -> Tuple[{s}, {t}]\n'.format(s=s, t=t) + + +def test_numpy_scalars(scalar_type): + expected = SCALAR_TYPES[scalar_type] + name = type_name(scalar_type) + func = getattr(m, 'test_' + name) + assert func.__doc__ == expected_signature(scalar_type) + for tp in ALL_TYPES: + value = tp(1) + if tp is scalar_type: + result = func(value) + assert result[0] == name + assert isinstance(result[1], tp) + assert result[1] == tp(expected) + else: + with pytest.raises(TypeError): + func(value)