eigen.h: return compile time vectors as 1D NumPy arrays

This commit is contained in:
Wenzel Jakob 2016-05-20 12:00:56 +02:00
parent b47a9de035
commit a970a579b2

View File

@ -41,6 +41,7 @@ template<typename Type>
struct type_caster<Type, typename std::enable_if<is_eigen_dense<Type>::value>::type> { struct type_caster<Type, typename std::enable_if<is_eigen_dense<Type>::value>::type> {
typedef typename Type::Scalar Scalar; typedef typename Type::Scalar Scalar;
static constexpr bool rowMajor = Type::Flags & Eigen::RowMajorBit; static constexpr bool rowMajor = Type::Flags & Eigen::RowMajorBit;
static constexpr bool isVector = Type::IsVectorAtCompileTime;
bool load(handle src, bool) { bool load(handle src, bool) {
array_t<Scalar> buffer(src, true); array_t<Scalar> buffer(src, true);
@ -50,7 +51,7 @@ struct type_caster<Type, typename std::enable_if<is_eigen_dense<Type>::value>::t
buffer_info info = buffer.request(); buffer_info info = buffer.request();
if (info.ndim == 1) { if (info.ndim == 1) {
typedef Eigen::Stride<Eigen::Dynamic, 0> Strides; typedef Eigen::Stride<Eigen::Dynamic, 0> Strides;
if (!Type::IsVectorAtCompileTime && if (!isVector &&
!(Type::RowsAtCompileTime == Eigen::Dynamic && !(Type::RowsAtCompileTime == Eigen::Dynamic &&
Type::ColsAtCompileTime == Eigen::Dynamic)) Type::ColsAtCompileTime == Eigen::Dynamic))
return false; return false;
@ -87,7 +88,8 @@ struct type_caster<Type, typename std::enable_if<is_eigen_dense<Type>::value>::t
} }
static handle cast(const Type &src, return_value_policy /* policy */, handle /* parent */) { static handle cast(const Type &src, return_value_policy /* policy */, handle /* parent */) {
array result(buffer_info( if (isVector) {
return array(buffer_info(
/* Pointer to buffer */ /* Pointer to buffer */
const_cast<Scalar *>(src.data()), const_cast<Scalar *>(src.data()),
/* Size of one scalar */ /* Size of one scalar */
@ -95,15 +97,30 @@ struct type_caster<Type, typename std::enable_if<is_eigen_dense<Type>::value>::t
/* Python struct-style format descriptor */ /* Python struct-style format descriptor */
format_descriptor<Scalar>::value, format_descriptor<Scalar>::value,
/* Number of dimensions */ /* Number of dimensions */
2, 1,
/* Buffer dimensions */
{ (size_t) src.size() },
/* Strides (in bytes) for each index */
{ sizeof(Scalar) }
)).release();
} else {
return array(buffer_info(
/* Pointer to buffer */
const_cast<Scalar *>(src.data()),
/* Size of one scalar */
sizeof(Scalar),
/* Python struct-style format descriptor */
format_descriptor<Scalar>::value,
/* Number of dimensions */
isVector ? 1 : 2,
/* Buffer dimensions */ /* Buffer dimensions */
{ (size_t) src.rows(), { (size_t) src.rows(),
(size_t) src.cols() }, (size_t) src.cols() },
/* Strides (in bytes) for each index */ /* Strides (in bytes) for each index */
{ sizeof(Scalar) * (rowMajor ? src.cols() : 1), { sizeof(Scalar) * (rowMajor ? src.cols() : 1),
sizeof(Scalar) * (rowMajor ? 1 : src.rows()) } sizeof(Scalar) * (rowMajor ? 1 : src.rows()) }
)); )).release();
return result.release(); }
} }
template <typename _T> using cast_op_type = pybind11::detail::cast_op_type<_T>; template <typename _T> using cast_op_type = pybind11::detail::cast_op_type<_T>;