mirror of
https://github.com/pybind/pybind11.git
synced 2024-11-22 21:25:13 +00:00
Eigen support for special matrix objects
Functions returning specialized Eigen matrices like Eigen::DiagonalMatrix and Eigen::SelfAdjointView--which inherit from EigenBase but not DenseBase--isn't currently allowed; such classes are explicitly copyable into a Matrix (by definition), and so we can support functions that return them by copying the value into a Matrix then casting that resulting dense Matrix into a numpy.ndarray. This commit does exactly that.
This commit is contained in:
parent
19637536ac
commit
9ffb3dda5f
@ -1098,6 +1098,14 @@ pybind11 will automatically and transparently convert
|
|||||||
1. Static and dynamic Eigen dense vectors and matrices to instances of
|
1. Static and dynamic Eigen dense vectors and matrices to instances of
|
||||||
``numpy.ndarray`` (and vice versa).
|
``numpy.ndarray`` (and vice versa).
|
||||||
|
|
||||||
|
1. Returned matrix expressions such as blocks (including columns or rows) and
|
||||||
|
diagonals will be converted to ``numpy.ndarray`` of the expression
|
||||||
|
values.
|
||||||
|
|
||||||
|
1. Returned matrix-like objects such as Eigen::DiagonalMatrix or
|
||||||
|
Eigen::SelfAdjointView will be converted to ``numpy.ndarray`` containing the
|
||||||
|
expressed value.
|
||||||
|
|
||||||
1. Eigen sparse vectors and matrices to instances of
|
1. Eigen sparse vectors and matrices to instances of
|
||||||
``scipy.sparse.csr_matrix``/``scipy.sparse.csc_matrix`` (and vice versa).
|
``scipy.sparse.csr_matrix``/``scipy.sparse.csc_matrix`` (and vice versa).
|
||||||
|
|
||||||
@ -1107,11 +1115,14 @@ them somehow, in which case the information won't be propagated to the caller.
|
|||||||
|
|
||||||
.. code-block:: cpp
|
.. code-block:: cpp
|
||||||
|
|
||||||
/* The Python bindings of this function won't replicate
|
/* The Python bindings of these functions won't replicate
|
||||||
the intended effect of modifying the function argument */
|
the intended effect of modifying the function arguments */
|
||||||
void scale_by_2(Eigen::Vector3f &v) {
|
void scale_by_2(Eigen::Vector3f &v) {
|
||||||
v *= 2;
|
v *= 2;
|
||||||
}
|
}
|
||||||
|
void scale_by_2(Eigen::Ref<Eigen::MatrixXd> &v) {
|
||||||
|
v *= 2;
|
||||||
|
}
|
||||||
|
|
||||||
To see why this is, refer to the section on :ref:`opaque` (although that
|
To see why this is, refer to the section on :ref:`opaque` (although that
|
||||||
section specifically covers STL data types, the underlying issue is the same).
|
section specifically covers STL data types, the underlying issue is the same).
|
||||||
|
@ -66,6 +66,22 @@ void init_eigen(py::module &m) {
|
|||||||
return x.block(start_row, start_col, block_rows, block_cols);
|
return x.block(start_row, start_col, block_rows, block_cols);
|
||||||
});
|
});
|
||||||
|
|
||||||
|
// Returns a DiagonalMatrix with diagonal (1,2,3,...)
|
||||||
|
m.def("incr_diag", [](int k) {
|
||||||
|
Eigen::DiagonalMatrix<int, Eigen::Dynamic> m(k);
|
||||||
|
for (int i = 0; i < k; i++) m.diagonal()[i] = i+1;
|
||||||
|
return m;
|
||||||
|
});
|
||||||
|
|
||||||
|
// Returns a SelfAdjointView referencing the lower triangle of m
|
||||||
|
m.def("symmetric_lower", [](const Eigen::MatrixXi &m) {
|
||||||
|
return m.selfadjointView<Eigen::Lower>();
|
||||||
|
});
|
||||||
|
// Returns a SelfAdjointView referencing the lower triangle of m
|
||||||
|
m.def("symmetric_upper", [](const Eigen::MatrixXi &m) {
|
||||||
|
return m.selfadjointView<Eigen::Upper>();
|
||||||
|
});
|
||||||
|
|
||||||
m.def("fixed_r", [mat]() -> FixedMatrixR {
|
m.def("fixed_r", [mat]() -> FixedMatrixR {
|
||||||
return FixedMatrixR(mat);
|
return FixedMatrixR(mat);
|
||||||
});
|
});
|
||||||
|
@ -14,6 +14,7 @@ from example import double_mat_cm, double_mat_rm
|
|||||||
from example import cholesky1, cholesky2, cholesky3, cholesky4, cholesky5, cholesky6
|
from example import cholesky1, cholesky2, cholesky3, cholesky4, cholesky5, cholesky6
|
||||||
from example import diagonal, diagonal_1, diagonal_n
|
from example import diagonal, diagonal_1, diagonal_n
|
||||||
from example import block
|
from example import block
|
||||||
|
from example import incr_diag, symmetric_upper, symmetric_lower
|
||||||
try:
|
try:
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import scipy
|
import scipy
|
||||||
@ -88,3 +89,20 @@ for i in range(-5, 7):
|
|||||||
print("block(2,1,3,3) %s" % ("OK" if (block(ref, 2, 1, 3, 3) == ref[2:5, 1:4]).all() else "FAILED"))
|
print("block(2,1,3,3) %s" % ("OK" if (block(ref, 2, 1, 3, 3) == ref[2:5, 1:4]).all() else "FAILED"))
|
||||||
print("block(1,4,4,2) %s" % ("OK" if (block(ref, 1, 4, 4, 2) == ref[1:, 4:]).all() else "FAILED"))
|
print("block(1,4,4,2) %s" % ("OK" if (block(ref, 1, 4, 4, 2) == ref[1:, 4:]).all() else "FAILED"))
|
||||||
print("block(1,4,3,2) %s" % ("OK" if (block(ref, 1, 4, 3, 2) == ref[1:4, 4:]).all() else "FAILED"))
|
print("block(1,4,3,2) %s" % ("OK" if (block(ref, 1, 4, 3, 2) == ref[1:4, 4:]).all() else "FAILED"))
|
||||||
|
|
||||||
|
print("incr_diag %s" % ("OK" if (incr_diag(7) == np.diag([1,2,3,4,5,6,7])).all() else "FAILED"))
|
||||||
|
|
||||||
|
asymm = np.array([
|
||||||
|
[1, 2, 3, 4],
|
||||||
|
[5, 6, 7, 8],
|
||||||
|
[9, 10,11,12],
|
||||||
|
[13,14,15,16]])
|
||||||
|
symm_lower = np.array(asymm)
|
||||||
|
symm_upper = np.array(asymm)
|
||||||
|
for i in range(4):
|
||||||
|
for j in range(i+1, 4):
|
||||||
|
symm_lower[i,j] = symm_lower[j,i]
|
||||||
|
symm_upper[j,i] = symm_upper[i,j]
|
||||||
|
|
||||||
|
print("symmetric_lower %s" % ("OK" if (symmetric_lower(asymm) == symm_lower).all() else "FAILED"))
|
||||||
|
print("symmetric_upper %s" % ("OK" if (symmetric_upper(asymm) == symm_upper).all() else "FAILED"))
|
||||||
|
@ -50,3 +50,6 @@ diagonal_n(6) OK
|
|||||||
block(2,1,3,3) OK
|
block(2,1,3,3) OK
|
||||||
block(1,4,4,2) OK
|
block(1,4,4,2) OK
|
||||||
block(1,4,3,2) OK
|
block(1,4,3,2) OK
|
||||||
|
incr_diag OK
|
||||||
|
symmetric_lower OK
|
||||||
|
symmetric_upper OK
|
||||||
|
@ -61,6 +61,19 @@ public:
|
|||||||
static constexpr bool value = decltype(test(std::declval<T>()))::value;
|
static constexpr bool value = decltype(test(std::declval<T>()))::value;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// Test for objects inheriting from EigenBase<Derived> that aren't captured by the above. This
|
||||||
|
// basically covers anything that can be assigned to a dense matrix but that don't have a typical
|
||||||
|
// matrix data layout that can be copied from their .data(). For example, DiagonalMatrix and
|
||||||
|
// SelfAdjointView fall into this category.
|
||||||
|
template <typename T> class is_eigen_base {
|
||||||
|
private:
|
||||||
|
template<typename Derived> static std::true_type test(const Eigen::EigenBase<Derived> &);
|
||||||
|
static std::false_type test(...);
|
||||||
|
public:
|
||||||
|
static constexpr bool value = !is_eigen_dense<T>::value && !is_eigen_sparse<T>::value &&
|
||||||
|
decltype(test(std::declval<T>()))::value;
|
||||||
|
};
|
||||||
|
|
||||||
template<typename Type>
|
template<typename Type>
|
||||||
struct type_caster<Type, typename std::enable_if<is_eigen_dense<Type>::value && !is_eigen_ref<Type>::value>::type> {
|
struct type_caster<Type, typename std::enable_if<is_eigen_dense<Type>::value && !is_eigen_ref<Type>::value>::type> {
|
||||||
typedef typename Type::Scalar Scalar;
|
typedef typename Type::Scalar Scalar;
|
||||||
@ -164,11 +177,10 @@ protected:
|
|||||||
|
|
||||||
template<typename Type>
|
template<typename Type>
|
||||||
struct type_caster<Type, typename std::enable_if<is_eigen_dense<Type>::value && is_eigen_ref<Type>::value>::type> {
|
struct type_caster<Type, typename std::enable_if<is_eigen_dense<Type>::value && is_eigen_ref<Type>::value>::type> {
|
||||||
private:
|
protected:
|
||||||
using Derived = typename std::remove_const<typename is_eigen_ref<Type>::Derived>::type;
|
using Derived = typename std::remove_const<typename is_eigen_ref<Type>::Derived>::type;
|
||||||
using DerivedCaster = type_caster<Derived>;
|
using DerivedCaster = type_caster<Derived>;
|
||||||
DerivedCaster derived_caster;
|
DerivedCaster derived_caster;
|
||||||
protected:
|
|
||||||
std::unique_ptr<Type> value;
|
std::unique_ptr<Type> value;
|
||||||
public:
|
public:
|
||||||
bool load(handle src, bool convert) { if (derived_caster.load(src, convert)) { value.reset(new Type(derived_caster.operator Derived&())); return true; } return false; }
|
bool load(handle src, bool convert) { if (derived_caster.load(src, convert)) { value.reset(new Type(derived_caster.operator Derived&())); return true; } return false; }
|
||||||
@ -182,6 +194,25 @@ public:
|
|||||||
template <typename _T> using cast_op_type = pybind11::detail::cast_op_type<_T>;
|
template <typename _T> using cast_op_type = pybind11::detail::cast_op_type<_T>;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// type_caster for special matrix types (e.g. DiagonalMatrix): load() is not supported, but we can
|
||||||
|
// cast them into the python domain by first copying to a regular Eigen::Matrix, then casting that.
|
||||||
|
template <typename Type>
|
||||||
|
struct type_caster<Type, typename std::enable_if<is_eigen_base<Type>::value && !is_eigen_ref<Type>::value>::type> {
|
||||||
|
protected:
|
||||||
|
using Matrix = Eigen::Matrix<typename Type::Scalar, Eigen::Dynamic, Eigen::Dynamic>;
|
||||||
|
using MatrixCaster = type_caster<Matrix>;
|
||||||
|
public:
|
||||||
|
[[noreturn]] bool load(handle, bool) { pybind11_fail("Unable to load() into specialized EigenBase object"); }
|
||||||
|
static handle cast(const Type &src, return_value_policy policy, handle parent) { return MatrixCaster::cast(Matrix(src), policy, parent); }
|
||||||
|
static handle cast(const Type *src, return_value_policy policy, handle parent) { return MatrixCaster::cast(Matrix(*src), policy, parent); }
|
||||||
|
|
||||||
|
static PYBIND11_DESCR name() { return MatrixCaster::name(); }
|
||||||
|
|
||||||
|
[[noreturn]] operator Type*() { pybind11_fail("Loading not supported for specialized EigenBase object"); }
|
||||||
|
[[noreturn]] operator Type&() { pybind11_fail("Loading not supported for specialized EigenBase object"); }
|
||||||
|
template <typename _T> using cast_op_type = pybind11::detail::cast_op_type<_T>;
|
||||||
|
};
|
||||||
|
|
||||||
template<typename Type>
|
template<typename Type>
|
||||||
struct type_caster<Type, typename std::enable_if<is_eigen_sparse<Type>::value>::type> {
|
struct type_caster<Type, typename std::enable_if<is_eigen_sparse<Type>::value>::type> {
|
||||||
typedef typename Type::Scalar Scalar;
|
typedef typename Type::Scalar Scalar;
|
||||||
|
Loading…
Reference in New Issue
Block a user