2016-05-05 18:33:54 +00:00
|
|
|
/*
|
|
|
|
pybind11/eigen.h: Transparent conversion for dense and sparse Eigen matrices
|
|
|
|
|
|
|
|
Copyright (c) 2016 Wenzel Jakob <wenzel.jakob@epfl.ch>
|
|
|
|
|
|
|
|
All rights reserved. Use of this source code is governed by a
|
|
|
|
BSD-style license that can be found in the LICENSE file.
|
|
|
|
*/
|
|
|
|
|
|
|
|
#pragma once
|
|
|
|
|
|
|
|
#include "numpy.h"
|
2016-05-29 11:40:40 +00:00
|
|
|
|
2016-09-07 14:37:40 +00:00
|
|
|
#if defined(__INTEL_COMPILER)
|
|
|
|
# pragma warning(disable: 1682) // implicit conversion of a 64-bit integral type to a smaller integral type (potential portability problem)
|
|
|
|
#elif defined(__GNUG__) || defined(__clang__)
|
2016-05-29 11:40:40 +00:00
|
|
|
# pragma GCC diagnostic push
|
|
|
|
# pragma GCC diagnostic ignored "-Wconversion"
|
2016-05-30 09:37:03 +00:00
|
|
|
# pragma GCC diagnostic ignored "-Wdeprecated-declarations"
|
2016-12-14 00:09:08 +00:00
|
|
|
# if __GNUC__ >= 7
|
|
|
|
# pragma GCC diagnostic ignored "-Wint-in-bool-context"
|
|
|
|
# endif
|
2016-05-29 11:40:40 +00:00
|
|
|
#endif
|
|
|
|
|
2016-05-05 18:33:54 +00:00
|
|
|
#include <Eigen/Core>
|
|
|
|
#include <Eigen/SparseCore>
|
|
|
|
|
|
|
|
#if defined(_MSC_VER)
|
2016-12-14 00:09:08 +00:00
|
|
|
# pragma warning(push)
|
|
|
|
# pragma warning(disable: 4127) // warning C4127: Conditional expression is constant
|
2016-05-05 18:33:54 +00:00
|
|
|
#endif
|
|
|
|
|
|
|
|
NAMESPACE_BEGIN(pybind11)
|
|
|
|
NAMESPACE_BEGIN(detail)
|
|
|
|
|
2016-09-24 21:54:02 +00:00
|
|
|
// Trait (via is_template_base_of): T derives from some instantiation of Eigen::DenseBase.
template <class T>
using is_eigen_dense = is_template_base_of<Eigen::DenseBase, T>;

// Trait (via is_template_base_of): T derives from some instantiation of Eigen::SparseMatrixBase.
template <class T>
using is_eigen_sparse = is_template_base_of<Eigen::SparseMatrixBase, T>;

// Trait (via is_template_base_of): T derives from some instantiation of Eigen::RefBase; used to
// route Eigen::Ref<...> to its dedicated caster instead of the generic dense one.
template <class T>
using is_eigen_ref = is_template_base_of<Eigen::RefBase, T>;
|
2016-05-05 18:33:54 +00:00
|
|
|
|
2016-08-04 19:24:41 +00:00
|
|
|
// Test for objects inheriting from EigenBase<Derived> that aren't captured by the above. This
|
|
|
|
// basically covers anything that can be assigned to a dense matrix but that doesn't have a typical
|
|
|
|
// matrix data layout that can be copied from their .data(). For example, DiagonalMatrix and
|
|
|
|
// SelfAdjointView fall into this category.
|
Change all_of_t/any_of_t to all_of/any_of, add none_of
This replaces the current `all_of_t<Pred, Ts...>` with `all_of<Ts...>`,
with previous use of `all_of_t<Pred, Ts...>` becoming
`all_of<Pred<Ts>...>` (and similarly for `any_of_t`). It also adds a
`none_of<Ts...>`, a shortcut for `negation<any_of<Ts...>>`.
This allows `all_of` and `any_of` to be used a bit more flexibly, e.g.
in cases where several predicates need to be tested for the same type
instead of the same predicate for multiple types.
This commit replaces the implementation with a more efficient version
for non-MSVC. For MSVC, this changes the workaround to use the
built-in, recursive std::conjunction/std::disjunction instead.
This also removes the `count_t` since `any_of_t` and `all_of_t` were the
only things using it.
This commit also rearranges some of the future std imports to use actual
`std` implementations for C++14/17 features when under the appropriate
compiler mode, as we were already doing for a few things (like
index_sequence). Most of these aren't saving much (the implementation
for enable_if_t, for example, is trivial), but I think it makes the
intention of the code instantly clear. It also enables MSVC's native
std::index_sequence support.
2016-12-12 23:11:49 +00:00
|
|
|
// Trait: T derives from some Eigen::EigenBase instantiation but is neither dense nor sparse
// (the "special" Eigen types such as DiagonalMatrix or SelfAdjointView).
template <class T>
using is_eigen_base = all_of<
    is_template_base_of<Eigen::EigenBase, T>,
    none_of<is_eigen_dense<T>, is_eigen_sparse<T>>
>;
|
2016-08-04 19:24:41 +00:00
|
|
|
|
2016-05-05 18:33:54 +00:00
|
|
|
template<typename Type>
|
2016-09-12 15:36:43 +00:00
|
|
|
struct type_caster<Type, enable_if_t<is_eigen_dense<Type>::value && !is_eigen_ref<Type>::value>> {
|
2016-05-05 18:33:54 +00:00
|
|
|
typedef typename Type::Scalar Scalar;
|
|
|
|
static constexpr bool rowMajor = Type::Flags & Eigen::RowMajorBit;
|
2016-05-20 10:00:56 +00:00
|
|
|
static constexpr bool isVector = Type::IsVectorAtCompileTime;
|
2016-05-05 18:33:54 +00:00
|
|
|
|
|
|
|
bool load(handle src, bool) {
|
2016-11-16 00:35:22 +00:00
|
|
|
auto buf = array_t<Scalar>::ensure(src);
|
2016-10-23 12:50:08 +00:00
|
|
|
if (!buf)
|
2016-08-29 01:41:05 +00:00
|
|
|
return false;
|
2016-05-05 18:33:54 +00:00
|
|
|
|
2016-08-29 01:41:05 +00:00
|
|
|
if (buf.ndim() == 1) {
|
2016-07-05 19:05:10 +00:00
|
|
|
typedef Eigen::InnerStride<> Strides;
|
2016-05-20 10:00:56 +00:00
|
|
|
if (!isVector &&
|
2016-05-05 18:33:54 +00:00
|
|
|
!(Type::RowsAtCompileTime == Eigen::Dynamic &&
|
|
|
|
Type::ColsAtCompileTime == Eigen::Dynamic))
|
|
|
|
return false;
|
|
|
|
|
|
|
|
if (Type::SizeAtCompileTime != Eigen::Dynamic &&
|
2016-08-29 01:41:05 +00:00
|
|
|
buf.shape(0) != (size_t) Type::SizeAtCompileTime)
|
2016-05-05 18:33:54 +00:00
|
|
|
return false;
|
|
|
|
|
2016-08-29 01:41:05 +00:00
|
|
|
Strides::Index n_elts = (Strides::Index) buf.shape(0);
|
2016-07-05 19:05:10 +00:00
|
|
|
Strides::Index unity = 1;
|
2016-05-05 18:33:54 +00:00
|
|
|
|
|
|
|
value = Eigen::Map<Type, 0, Strides>(
|
2016-08-29 01:41:05 +00:00
|
|
|
buf.mutable_data(),
|
|
|
|
rowMajor ? unity : n_elts,
|
|
|
|
rowMajor ? n_elts : unity,
|
|
|
|
Strides(buf.strides(0) / sizeof(Scalar))
|
|
|
|
);
|
|
|
|
} else if (buf.ndim() == 2) {
|
2016-05-05 18:33:54 +00:00
|
|
|
typedef Eigen::Stride<Eigen::Dynamic, Eigen::Dynamic> Strides;
|
|
|
|
|
2016-08-29 01:41:05 +00:00
|
|
|
if ((Type::RowsAtCompileTime != Eigen::Dynamic && buf.shape(0) != (size_t) Type::RowsAtCompileTime) ||
|
|
|
|
(Type::ColsAtCompileTime != Eigen::Dynamic && buf.shape(1) != (size_t) Type::ColsAtCompileTime))
|
2016-05-05 18:33:54 +00:00
|
|
|
return false;
|
|
|
|
|
|
|
|
value = Eigen::Map<Type, 0, Strides>(
|
2016-08-29 01:41:05 +00:00
|
|
|
buf.mutable_data(),
|
|
|
|
typename Strides::Index(buf.shape(0)),
|
|
|
|
typename Strides::Index(buf.shape(1)),
|
|
|
|
Strides(buf.strides(rowMajor ? 0 : 1) / sizeof(Scalar),
|
|
|
|
buf.strides(rowMajor ? 1 : 0) / sizeof(Scalar))
|
|
|
|
);
|
2016-05-05 18:33:54 +00:00
|
|
|
} else {
|
|
|
|
return false;
|
|
|
|
}
|
|
|
|
return true;
|
|
|
|
}
|
|
|
|
|
|
|
|
static handle cast(const Type &src, return_value_policy /* policy */, handle /* parent */) {
|
2016-05-20 10:00:56 +00:00
|
|
|
if (isVector) {
|
2016-08-15 00:24:59 +00:00
|
|
|
return array(
|
|
|
|
{ (size_t) src.size() }, // shape
|
|
|
|
{ sizeof(Scalar) * static_cast<size_t>(src.innerStride()) }, // strides
|
|
|
|
src.data() // data
|
|
|
|
).release();
|
2016-05-20 10:00:56 +00:00
|
|
|
} else {
|
2016-08-15 00:24:59 +00:00
|
|
|
return array(
|
|
|
|
{ (size_t) src.rows(), // shape
|
2016-05-20 10:00:56 +00:00
|
|
|
(size_t) src.cols() },
|
2016-08-15 00:24:59 +00:00
|
|
|
{ sizeof(Scalar) * static_cast<size_t>(src.rowStride()), // strides
|
|
|
|
sizeof(Scalar) * static_cast<size_t>(src.colStride()) },
|
|
|
|
src.data() // data
|
|
|
|
).release();
|
2016-05-20 10:00:56 +00:00
|
|
|
}
|
2016-05-05 18:33:54 +00:00
|
|
|
}
|
|
|
|
|
2016-08-03 23:40:40 +00:00
|
|
|
PYBIND11_TYPE_CASTER(Type, _("numpy.ndarray[") + npy_format_descriptor<Scalar>::name() +
|
|
|
|
_("[") + rows() + _(", ") + cols() + _("]]"));
|
2016-05-05 18:33:54 +00:00
|
|
|
|
2016-05-24 19:39:41 +00:00
|
|
|
protected:
|
2016-09-12 15:36:43 +00:00
|
|
|
template <typename T = Type, enable_if_t<T::RowsAtCompileTime == Eigen::Dynamic, int> = 0>
|
2016-05-05 18:33:54 +00:00
|
|
|
static PYBIND11_DESCR rows() { return _("m"); }
|
2016-09-12 15:36:43 +00:00
|
|
|
template <typename T = Type, enable_if_t<T::RowsAtCompileTime != Eigen::Dynamic, int> = 0>
|
2016-05-05 18:33:54 +00:00
|
|
|
static PYBIND11_DESCR rows() { return _<T::RowsAtCompileTime>(); }
|
2016-09-12 15:36:43 +00:00
|
|
|
template <typename T = Type, enable_if_t<T::ColsAtCompileTime == Eigen::Dynamic, int> = 0>
|
2016-05-05 18:33:54 +00:00
|
|
|
static PYBIND11_DESCR cols() { return _("n"); }
|
2016-09-12 15:36:43 +00:00
|
|
|
template <typename T = Type, enable_if_t<T::ColsAtCompileTime != Eigen::Dynamic, int> = 0>
|
2016-05-05 18:33:54 +00:00
|
|
|
static PYBIND11_DESCR cols() { return _<T::ColsAtCompileTime>(); }
|
|
|
|
};
|
|
|
|
|
2016-09-24 21:54:02 +00:00
|
|
|
// Eigen::Ref<Derived> satisfies is_eigen_dense, but isn't constructable, so it needs a special
|
|
|
|
// type_caster to handle argument copying/forwarding.
|
|
|
|
template <typename CVDerived, int Options, typename StrideType>
|
|
|
|
struct type_caster<Eigen::Ref<CVDerived, Options, StrideType>> {
|
2016-08-04 19:24:41 +00:00
|
|
|
protected:
|
2016-09-24 21:54:02 +00:00
|
|
|
using Type = Eigen::Ref<CVDerived, Options, StrideType>;
|
|
|
|
using Derived = typename std::remove_const<CVDerived>::type;
|
2017-01-03 10:52:05 +00:00
|
|
|
using DerivedCaster = make_caster<Derived>;
|
2016-08-03 20:50:22 +00:00
|
|
|
DerivedCaster derived_caster;
|
|
|
|
std::unique_ptr<Type> value;
|
|
|
|
public:
|
|
|
|
bool load(handle src, bool convert) { if (derived_caster.load(src, convert)) { value.reset(new Type(derived_caster.operator Derived&())); return true; } return false; }
|
|
|
|
static handle cast(const Type &src, return_value_policy policy, handle parent) { return DerivedCaster::cast(src, policy, parent); }
|
|
|
|
static handle cast(const Type *src, return_value_policy policy, handle parent) { return DerivedCaster::cast(*src, policy, parent); }
|
|
|
|
|
|
|
|
static PYBIND11_DESCR name() { return DerivedCaster::name(); }
|
|
|
|
|
|
|
|
operator Type*() { return value.get(); }
|
|
|
|
operator Type&() { if (!value) pybind11_fail("Eigen::Ref<...> value not loaded"); return *value; }
|
|
|
|
template <typename _T> using cast_op_type = pybind11::detail::cast_op_type<_T>;
|
|
|
|
};
|
|
|
|
|
2016-08-04 19:24:41 +00:00
|
|
|
// type_caster for special matrix types (e.g. DiagonalMatrix): load() is not supported, but we can
|
|
|
|
// cast them into the python domain by first copying to a regular Eigen::Matrix, then casting that.
|
|
|
|
template <typename Type>
|
2016-09-12 15:36:43 +00:00
|
|
|
struct type_caster<Type, enable_if_t<is_eigen_base<Type>::value && !is_eigen_ref<Type>::value>> {
|
2016-08-04 19:24:41 +00:00
|
|
|
protected:
|
|
|
|
using Matrix = Eigen::Matrix<typename Type::Scalar, Eigen::Dynamic, Eigen::Dynamic>;
|
2017-01-03 10:52:05 +00:00
|
|
|
using MatrixCaster = make_caster<Matrix>;
|
2016-08-04 19:24:41 +00:00
|
|
|
public:
|
|
|
|
[[noreturn]] bool load(handle, bool) { pybind11_fail("Unable to load() into specialized EigenBase object"); }
|
|
|
|
static handle cast(const Type &src, return_value_policy policy, handle parent) { return MatrixCaster::cast(Matrix(src), policy, parent); }
|
|
|
|
static handle cast(const Type *src, return_value_policy policy, handle parent) { return MatrixCaster::cast(Matrix(*src), policy, parent); }
|
|
|
|
|
|
|
|
static PYBIND11_DESCR name() { return MatrixCaster::name(); }
|
|
|
|
|
|
|
|
[[noreturn]] operator Type*() { pybind11_fail("Loading not supported for specialized EigenBase object"); }
|
|
|
|
[[noreturn]] operator Type&() { pybind11_fail("Loading not supported for specialized EigenBase object"); }
|
|
|
|
template <typename _T> using cast_op_type = pybind11::detail::cast_op_type<_T>;
|
|
|
|
};
|
|
|
|
|
2016-05-05 18:33:54 +00:00
|
|
|
template<typename Type>
|
2016-09-12 15:36:43 +00:00
|
|
|
struct type_caster<Type, enable_if_t<is_eigen_sparse<Type>::value>> {
|
2016-05-05 18:33:54 +00:00
|
|
|
typedef typename Type::Scalar Scalar;
|
|
|
|
typedef typename std::remove_reference<decltype(*std::declval<Type>().outerIndexPtr())>::type StorageIndex;
|
|
|
|
typedef typename Type::Index Index;
|
|
|
|
static constexpr bool rowMajor = Type::Flags & Eigen::RowMajorBit;
|
|
|
|
|
|
|
|
bool load(handle src, bool) {
|
2016-05-10 14:59:01 +00:00
|
|
|
if (!src)
|
|
|
|
return false;
|
|
|
|
|
2016-10-28 01:08:15 +00:00
|
|
|
auto obj = reinterpret_borrow<object>(src);
|
2016-05-05 18:33:54 +00:00
|
|
|
object sparse_module = module::import("scipy.sparse");
|
|
|
|
object matrix_type = sparse_module.attr(
|
|
|
|
rowMajor ? "csr_matrix" : "csc_matrix");
|
|
|
|
|
|
|
|
if (obj.get_type() != matrix_type.ptr()) {
|
|
|
|
try {
|
2016-05-08 12:34:09 +00:00
|
|
|
obj = matrix_type(obj);
|
2016-05-05 18:33:54 +00:00
|
|
|
} catch (const error_already_set &) {
|
|
|
|
return false;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2016-08-29 01:41:05 +00:00
|
|
|
auto values = array_t<Scalar>((object) obj.attr("data"));
|
|
|
|
auto innerIndices = array_t<StorageIndex>((object) obj.attr("indices"));
|
|
|
|
auto outerIndices = array_t<StorageIndex>((object) obj.attr("indptr"));
|
2016-05-05 18:33:54 +00:00
|
|
|
auto shape = pybind11::tuple((pybind11::object) obj.attr("shape"));
|
|
|
|
auto nnz = obj.attr("nnz").cast<Index>();
|
|
|
|
|
2016-10-23 12:50:08 +00:00
|
|
|
if (!values || !innerIndices || !outerIndices)
|
2016-05-05 18:33:54 +00:00
|
|
|
return false;
|
|
|
|
|
|
|
|
value = Eigen::MappedSparseMatrix<Scalar, Type::Flags, StorageIndex>(
|
2016-08-29 01:41:05 +00:00
|
|
|
shape[0].cast<Index>(), shape[1].cast<Index>(), nnz,
|
|
|
|
outerIndices.mutable_data(), innerIndices.mutable_data(), values.mutable_data());
|
2016-05-05 18:33:54 +00:00
|
|
|
|
|
|
|
return true;
|
|
|
|
}
|
|
|
|
|
|
|
|
static handle cast(const Type &src, return_value_policy /* policy */, handle /* parent */) {
|
|
|
|
const_cast<Type&>(src).makeCompressed();
|
|
|
|
|
|
|
|
object matrix_type = module::import("scipy.sparse").attr(
|
|
|
|
rowMajor ? "csr_matrix" : "csc_matrix");
|
|
|
|
|
2016-08-15 00:24:59 +00:00
|
|
|
array data((size_t) src.nonZeros(), src.valuePtr());
|
|
|
|
array outerIndices((size_t) (rowMajor ? src.rows() : src.cols()) + 1, src.outerIndexPtr());
|
|
|
|
array innerIndices((size_t) src.nonZeros(), src.innerIndexPtr());
|
2016-05-05 18:33:54 +00:00
|
|
|
|
2016-05-08 12:34:09 +00:00
|
|
|
return matrix_type(
|
2016-05-05 18:33:54 +00:00
|
|
|
std::make_tuple(data, innerIndices, outerIndices),
|
|
|
|
std::make_pair(src.rows(), src.cols())
|
|
|
|
).release();
|
|
|
|
}
|
|
|
|
|
2016-08-03 23:40:40 +00:00
|
|
|
PYBIND11_TYPE_CASTER(Type, _<(Type::Flags & Eigen::RowMajorBit) != 0>("scipy.sparse.csr_matrix[", "scipy.sparse.csc_matrix[")
|
2016-07-06 04:40:54 +00:00
|
|
|
+ npy_format_descriptor<Scalar>::name() + _("]"));
|
2016-05-05 18:33:54 +00:00
|
|
|
};
|
|
|
|
|
|
|
|
NAMESPACE_END(detail)
|
|
|
|
NAMESPACE_END(pybind11)
|
|
|
|
|
2016-12-14 00:09:08 +00:00
|
|
|
#if defined(__GNUG__) || defined(__clang__)
|
|
|
|
# pragma GCC diagnostic pop
|
|
|
|
#elif defined(_MSC_VER)
|
|
|
|
# pragma warning(pop)
|
2016-05-05 18:33:54 +00:00
|
|
|
#endif
|