Merge pull request #334 from aldanor/bugfix/string-descriptors

Fix format descriptors for string types
This commit is contained in:
Wenzel Jakob 2016-08-15 06:47:13 +02:00 committed by GitHub
commit 6a0a850742
2 changed files with 24 additions and 22 deletions

View File

@ -333,14 +333,14 @@ PYBIND11_RUNTIME_EXCEPTION(reference_cast_error) /// Used internally
/// Format strings for basic number types /// Format strings for basic number types
#define PYBIND11_DECL_FMT(t, v) template<> struct format_descriptor<t> \ #define PYBIND11_DECL_FMT(t, v) template<> struct format_descriptor<t> \
{ static constexpr const char* value = v; /* for backwards compatibility */ \ { static constexpr const char* value = v; /* for backwards compatibility */ \
static constexpr const char* format() { return value; } } static std::string format() { return value; } }
template <typename T, typename SFINAE = void> struct format_descriptor { }; template <typename T, typename SFINAE = void> struct format_descriptor { };
template <typename T> struct format_descriptor<T, typename std::enable_if<std::is_integral<T>::value>::type> { template <typename T> struct format_descriptor<T, typename std::enable_if<std::is_integral<T>::value>::type> {
static constexpr const char value[2] = static constexpr const char value[2] =
{ "bBhHiIqQ"[detail::log2(sizeof(T))*2 + (std::is_unsigned<T>::value ? 1 : 0)], '\0' }; { "bBhHiIqQ"[detail::log2(sizeof(T))*2 + (std::is_unsigned<T>::value ? 1 : 0)], '\0' };
static constexpr const char* format() { return value; } static std::string format() { return value; }
}; };
template <typename T> constexpr const char format_descriptor< template <typename T> constexpr const char format_descriptor<

View File

@ -17,6 +17,7 @@
#include <cstdlib> #include <cstdlib>
#include <cstring> #include <cstring>
#include <sstream> #include <sstream>
#include <string>
#include <initializer_list> #include <initializer_list>
#if defined(_MSC_VER) #if defined(_MSC_VER)
@ -303,14 +304,14 @@ public:
template <typename T> template <typename T>
struct format_descriptor<T, typename std::enable_if<detail::is_pod_struct<T>::value>::type> { struct format_descriptor<T, typename std::enable_if<detail::is_pod_struct<T>::value>::type> {
static const char *format() { return detail::npy_format_descriptor<T>::format(); } static std::string format() { return detail::npy_format_descriptor<T>::format(); }
}; };
template <size_t N> struct format_descriptor<char[N]> { template <size_t N> struct format_descriptor<char[N]> {
static const char *format() { PYBIND11_DESCR s = detail::_<N>() + detail::_("s"); return s.text(); } static std::string format() { return std::to_string(N) + "s"; }
}; };
template <size_t N> struct format_descriptor<std::array<char, N>> { template <size_t N> struct format_descriptor<std::array<char, N>> {
static const char *format() { PYBIND11_DESCR s = detail::_<N>() + detail::_("s"); return s.text(); } static std::string format() { return std::to_string(N) + "s"; }
}; };
NAMESPACE_BEGIN(detail) NAMESPACE_BEGIN(detail)
@ -367,11 +368,7 @@ DECL_FMT(std::complex<double>, NPY_CDOUBLE_, "complex128");
#define DECL_CHAR_FMT \ #define DECL_CHAR_FMT \
static PYBIND11_DESCR name() { return _("S") + _<N>(); } \ static PYBIND11_DESCR name() { return _("S") + _<N>(); } \
static pybind11::dtype dtype() { \ static pybind11::dtype dtype() { return std::string("S") + std::to_string(N); }
PYBIND11_DESCR fmt = _("S") + _<N>(); \
return pybind11::dtype(fmt.text()); \
} \
static const char *format() { PYBIND11_DESCR s = _<N>() + _("s"); return s.text(); }
template <size_t N> struct npy_format_descriptor<char[N]> { DECL_CHAR_FMT }; template <size_t N> struct npy_format_descriptor<char[N]> { DECL_CHAR_FMT };
template <size_t N> struct npy_format_descriptor<std::array<char, N>> { DECL_CHAR_FMT }; template <size_t N> struct npy_format_descriptor<std::array<char, N>> { DECL_CHAR_FMT };
#undef DECL_CHAR_FMT #undef DECL_CHAR_FMT
@ -380,7 +377,7 @@ struct field_descriptor {
const char *name; const char *name;
size_t offset; size_t offset;
size_t size; size_t size;
const char *format; std::string format;
dtype descr; dtype descr;
}; };
@ -389,15 +386,15 @@ struct npy_format_descriptor<T, typename std::enable_if<is_pod_struct<T>::value>
static PYBIND11_DESCR name() { return _("struct"); } static PYBIND11_DESCR name() { return _("struct"); }
static pybind11::dtype dtype() { static pybind11::dtype dtype() {
if (!dtype_()) if (!dtype_ptr)
pybind11_fail("NumPy: unsupported buffer format!"); pybind11_fail("NumPy: unsupported buffer format!");
return object(dtype_(), true); return object(dtype_ptr, true);
} }
static const char* format() { static std::string format() {
if (!dtype_()) if (!dtype_ptr)
pybind11_fail("NumPy: unsupported buffer format!"); pybind11_fail("NumPy: unsupported buffer format!");
return format_().c_str(); return format_str;
} }
static void register_dtype(std::initializer_list<field_descriptor> fields) { static void register_dtype(std::initializer_list<field_descriptor> fields) {
@ -409,7 +406,7 @@ struct npy_format_descriptor<T, typename std::enable_if<is_pod_struct<T>::value>
formats.append(field.descr); formats.append(field.descr);
offsets.append(int_(field.offset)); offsets.append(int_(field.offset));
} }
dtype_() = pybind11::dtype(names, formats, offsets, sizeof(T)).release().ptr(); dtype_ptr = pybind11::dtype(names, formats, offsets, sizeof(T)).release().ptr();
// There is an existing bug in NumPy (as of v1.11): trailing bytes are // There is an existing bug in NumPy (as of v1.11): trailing bytes are
// not encoded explicitly into the format string. This will supposedly // not encoded explicitly into the format string. This will supposedly
@ -436,20 +433,25 @@ struct npy_format_descriptor<T, typename std::enable_if<is_pod_struct<T>::value>
if (sizeof(T) > offset) if (sizeof(T) > offset)
oss << (sizeof(T) - offset) << 'x'; oss << (sizeof(T) - offset) << 'x';
oss << '}'; oss << '}';
format_() = oss.str(); format_str = oss.str();
// Sanity check: verify that NumPy properly parses our buffer format string // Sanity check: verify that NumPy properly parses our buffer format string
auto& api = npy_api::get(); auto& api = npy_api::get();
auto arr = array(buffer_info(nullptr, sizeof(T), format(), 1, { 0 }, { sizeof(T) })); auto arr = array(buffer_info(nullptr, sizeof(T), format(), 1));
if (!api.PyArray_EquivTypes_(dtype_(), arr.dtype().ptr())) if (!api.PyArray_EquivTypes_(dtype_ptr, arr.dtype().ptr()))
pybind11_fail("NumPy: invalid buffer descriptor!"); pybind11_fail("NumPy: invalid buffer descriptor!");
} }
private: private:
static inline PyObject*& dtype_() { static PyObject *ptr = nullptr; return ptr; } static std::string format_str;
static inline std::string& format_() { static std::string s; return s; } static PyObject* dtype_ptr;
}; };
template <typename T>
std::string npy_format_descriptor<T, typename std::enable_if<is_pod_struct<T>::value>::type>::format_str;
template <typename T>
PyObject* npy_format_descriptor<T, typename std::enable_if<is_pod_struct<T>::value>::type>::dtype_ptr = nullptr;
// Extract name, offset and format descriptor for a struct field // Extract name, offset and format descriptor for a struct field
#define PYBIND11_FIELD_DESCRIPTOR(Type, Field) \ #define PYBIND11_FIELD_DESCRIPTOR(Type, Field) \
::pybind11::detail::field_descriptor { \ ::pybind11::detail::field_descriptor { \