/* pybind/numpy.h: Basic NumPy support, auto-vectorization support Copyright (c) 2015 Wenzel Jakob All rights reserved. Use of this source code is governed by a BSD-style license that can be found in the LICENSE file. */ #pragma once #include #if defined(_MSC_VER) #pragma warning(push) #pragma warning(disable: 4127) // warning C4127: Conditional expression is constant #endif NAMESPACE_BEGIN(pybind) template struct npy_format_descriptor { }; class array : public buffer { public: struct API { enum Entries { API_PyArray_Type = 2, API_PyArray_DescrFromType = 45, API_PyArray_FromAny = 69, API_PyArray_NewCopy = 85, API_PyArray_NewFromDescr = 94, NPY_C_CONTIGUOUS = 0x0001, NPY_F_CONTIGUOUS = 0x0002, NPY_NPY_ARRAY_FORCECAST = 0x0010, NPY_ENSURE_ARRAY = 0x0040, NPY_BOOL=0, NPY_BYTE, NPY_UBYTE, NPY_SHORT, NPY_USHORT, NPY_INT, NPY_UINT, NPY_LONG, NPY_ULONG, NPY_LONGLONG, NPY_ULONGLONG, NPY_FLOAT, NPY_DOUBLE, NPY_LONGDOUBLE, NPY_CFLOAT, NPY_CDOUBLE, NPY_CLONGDOUBLE }; static API lookup() { PyObject *numpy = PyImport_ImportModule("numpy.core.multiarray"); PyObject *capsule = numpy ? PyObject_GetAttrString(numpy, "_ARRAY_API") : nullptr; void **api_ptr = (void **) (capsule ? PyCapsule_GetPointer(capsule, NULL) : nullptr); Py_XDECREF(capsule); Py_XDECREF(numpy); if (api_ptr == nullptr) throw std::runtime_error("Could not acquire pointer to NumPy API!"); API api; api.PyArray_Type = (decltype(api.PyArray_Type)) api_ptr[API_PyArray_Type]; api.PyArray_DescrFromType = (decltype(api.PyArray_DescrFromType)) api_ptr[API_PyArray_DescrFromType]; api.PyArray_FromAny = (decltype(api.PyArray_FromAny)) api_ptr[API_PyArray_FromAny]; api.PyArray_NewCopy = (decltype(api.PyArray_NewCopy)) api_ptr[API_PyArray_NewCopy]; api.PyArray_NewFromDescr = (decltype(api.PyArray_NewFromDescr)) api_ptr[API_PyArray_NewFromDescr]; return api; } bool PyArray_Check(PyObject *obj) const { return (bool) PyObject_TypeCheck(obj, PyArray_Type); } PyObject *(*PyArray_DescrFromType)(int); PyObject *(*PyArray_NewFromDescr) (PyTypeObject *, PyObject *, int, Py_intptr_t *, Py_intptr_t *, void *, int, PyObject *); PyObject *(*PyArray_NewCopy)(PyObject *, int); PyTypeObject *PyArray_Type; PyObject *(*PyArray_FromAny) (PyObject *, PyObject *, int, int, int, PyObject *); }; PYBIND_OBJECT_DEFAULT(array, buffer, lookup_api().PyArray_Check) template array(size_t size, const Type *ptr) { API& api = lookup_api(); PyObject *descr = api.PyArray_DescrFromType(npy_format_descriptor::value); if (descr == nullptr) throw std::runtime_error("NumPy: unsupported buffer format!"); Py_intptr_t shape = (Py_intptr_t) size; PyObject *tmp = api.PyArray_NewFromDescr( api.PyArray_Type, descr, 1, &shape, nullptr, (void *) ptr, 0, nullptr); if (tmp == nullptr) throw std::runtime_error("NumPy: unable to create array!"); m_ptr = api.PyArray_NewCopy(tmp, -1 /* any order */); Py_DECREF(tmp); if (m_ptr == nullptr) throw std::runtime_error("NumPy: unable to copy array!"); } array(const buffer_info &info) { API& api = lookup_api(); if (info.format.size() != 1) throw std::runtime_error("Unsupported buffer format!"); int fmt = (int) info.format[0]; if (info.format == "Zd") fmt = API::NPY_CDOUBLE; else if (info.format == "Zf") fmt = API::NPY_CFLOAT; PyObject *descr = api.PyArray_DescrFromType(fmt); if (descr == nullptr) throw std::runtime_error("NumPy: unsupported buffer format '" + info.format + "'!"); PyObject *tmp = api.PyArray_NewFromDescr( api.PyArray_Type, descr, info.ndim, (Py_intptr_t *) &info.shape[0], (Py_intptr_t *) &info.strides[0], info.ptr, 0, nullptr); if (tmp == nullptr) throw std::runtime_error("NumPy: unable to create array!"); m_ptr = api.PyArray_NewCopy(tmp, -1 /* any order */); Py_DECREF(tmp); if (m_ptr == nullptr) throw std::runtime_error("NumPy: unable to copy array!"); } protected: static API &lookup_api() { static API api = API::lookup(); return api; } }; template class array_dtype : public array { public: PYBIND_OBJECT_CVT(array_dtype, array, is_non_null, m_ptr = ensure(m_ptr)); array_dtype() : array() { } static bool is_non_null(PyObject *ptr) { return ptr != nullptr; } PyObject *ensure(PyObject *ptr) { API &api = lookup_api(); PyObject *descr = api.PyArray_DescrFromType(npy_format_descriptor::value); return api.PyArray_FromAny(ptr, descr, 0, 0, API::NPY_C_CONTIGUOUS | API::NPY_ENSURE_ARRAY | API::NPY_NPY_ARRAY_FORCECAST, nullptr); } }; NAMESPACE_BEGIN(detail) PYBIND_TYPE_CASTER_PYTYPE(array) PYBIND_TYPE_CASTER_PYTYPE(array_dtype) PYBIND_TYPE_CASTER_PYTYPE(array_dtype) PYBIND_TYPE_CASTER_PYTYPE(array_dtype) PYBIND_TYPE_CASTER_PYTYPE(array_dtype) PYBIND_TYPE_CASTER_PYTYPE(array_dtype) PYBIND_TYPE_CASTER_PYTYPE(array_dtype) PYBIND_TYPE_CASTER_PYTYPE(array_dtype) PYBIND_TYPE_CASTER_PYTYPE(array_dtype) PYBIND_TYPE_CASTER_PYTYPE(array_dtype) PYBIND_TYPE_CASTER_PYTYPE(array_dtype) PYBIND_TYPE_CASTER_PYTYPE(array_dtype>) PYBIND_TYPE_CASTER_PYTYPE(array_dtype>) PYBIND_TYPE_CASTER_PYTYPE(array_dtype) NAMESPACE_END(detail) #define DECL_FMT(t, n) template<> struct npy_format_descriptor { enum { value = array::API::n }; } DECL_FMT(int8_t, NPY_BYTE); DECL_FMT(uint8_t, NPY_UBYTE); DECL_FMT(int16_t, NPY_SHORT); DECL_FMT(uint16_t, NPY_USHORT); DECL_FMT(int32_t, NPY_INT); DECL_FMT(uint32_t, NPY_UINT); DECL_FMT(int64_t, NPY_LONGLONG); DECL_FMT(uint64_t, NPY_ULONGLONG); DECL_FMT(float, NPY_FLOAT); DECL_FMT(double, NPY_DOUBLE); DECL_FMT(bool, NPY_BOOL); DECL_FMT(std::complex, NPY_CFLOAT); DECL_FMT(std::complex, NPY_CDOUBLE); #undef DECL_FMT template std::function...)> vectorize(func_type &&f, return_type (*) (args_type ...), detail::index_sequence) { return [f](array_dtype... args) -> array { /* Request buffers from all parameters */ const size_t N = sizeof...(args_type); std::array buffers {{ args.request()... }}; /* Determine dimensions parameters of output array */ int ndim = 0; size_t count = 0; std::vector shape; for (size_t i=0; i count) { ndim = buffers[i].ndim; shape = buffers[i].shape; count = buffers[i].count; } } std::vector strides(ndim); if (ndim > 0) { strides[ndim-1] = sizeof(return_type); for (int i=ndim-1; i>0; --i) strides[i-1] = strides[i] * shape[i]; } /* Check if the parameters are actually compatible */ for (size_t i=0; i result(count); for (size_t i=0; i::value(), ndim, shape, strides)); }; } template std::function...)> vectorize(func_type &&f, return_type (*f_) (args_type ...) = nullptr) { return vectorize(f, f_, typename detail::make_index_sequence::type()); } template std::function...)> vectorize(return_type (*f) (args_type ...)) { return vectorize(f, f); } template auto vectorize(func &&f) -> decltype( vectorize(std::forward(f), (typename detail::remove_class::type::operator())>::type *) nullptr)) { return vectorize(std::forward(f), (typename detail::remove_class::type::operator())>::type *) nullptr); } NAMESPACE_END(pybind) #if defined(_MSC_VER) #pragma warning(pop) #endif