Merge pull request #315 from jagerman/eigen-stride-fix

Fix eigen copying of non-standard stride values
This commit is contained in:
Wenzel Jakob 2016-08-04 19:47:17 +02:00 committed by GitHub
commit 19637536ac
4 changed files with 42 additions and 5 deletions

View File

@ -56,6 +56,16 @@ void init_eigen(py::module &m) {
m.def("cholesky5", &cholesky5); m.def("cholesky5", &cholesky5);
m.def("cholesky6", &cholesky6); m.def("cholesky6", &cholesky6);
// Returns diagonals: a vector-like object with an inner stride != 1
m.def("diagonal", [](const Eigen::Ref<const Eigen::MatrixXd> &x) { return x.diagonal(); });
m.def("diagonal_1", [](const Eigen::Ref<const Eigen::MatrixXd> &x) { return x.diagonal<1>(); });
m.def("diagonal_n", [](const Eigen::Ref<const Eigen::MatrixXd> &x, int index) { return x.diagonal(index); });
// Return a block of a matrix (gives non-standard strides)
m.def("block", [](const Eigen::Ref<const Eigen::MatrixXd> &x, int start_row, int start_col, int block_rows, int block_cols) {
return x.block(start_row, start_col, block_rows, block_cols);
});
m.def("fixed_r", [mat]() -> FixedMatrixR { m.def("fixed_r", [mat]() -> FixedMatrixR {
return FixedMatrixR(mat); return FixedMatrixR(mat);
}); });

View File

@ -12,6 +12,8 @@ from example import sparse_passthrough_r, sparse_passthrough_c
from example import double_row, double_col from example import double_row, double_col
from example import double_mat_cm, double_mat_rm from example import double_mat_cm, double_mat_rm
from example import cholesky1, cholesky2, cholesky3, cholesky4, cholesky5, cholesky6 from example import cholesky1, cholesky2, cholesky3, cholesky4, cholesky5, cholesky6
from example import diagonal, diagonal_1, diagonal_n
from example import block
try: try:
import numpy as np import numpy as np
import scipy import scipy
@ -78,3 +80,11 @@ for chol in [cholesky1, cholesky2, cholesky3, cholesky4, cholesky5, cholesky6]:
print("cholesky" + str(i) + " " + ("OK" if (mymat == np.array([[1,0,0], [2,3,0], [4,5,6]])).all() else "NOT OKAY")) print("cholesky" + str(i) + " " + ("OK" if (mymat == np.array([[1,0,0], [2,3,0], [4,5,6]])).all() else "NOT OKAY"))
i += 1 i += 1
print("diagonal() %s" % ("OK" if (diagonal(ref) == ref.diagonal()).all() else "FAILED"))
print("diagonal_1() %s" % ("OK" if (diagonal_1(ref) == ref.diagonal(1)).all() else "FAILED"))
for i in range(-5, 7):
print("diagonal_n(%d) %s" % (i, "OK" if (diagonal_n(ref, i) == ref.diagonal(i)).all() else "FAILED"))
print("block(2,1,3,3) %s" % ("OK" if (block(ref, 2, 1, 3, 3) == ref[2:5, 1:4]).all() else "FAILED"))
print("block(1,4,4,2) %s" % ("OK" if (block(ref, 1, 4, 4, 2) == ref[1:, 4:]).all() else "FAILED"))
print("block(1,4,3,2) %s" % ("OK" if (block(ref, 1, 4, 3, 2) == ref[1:4, 4:]).all() else "FAILED"))

View File

@ -33,3 +33,20 @@ cholesky3 OK
cholesky4 OK cholesky4 OK
cholesky5 OK cholesky5 OK
cholesky6 OK cholesky6 OK
diagonal() OK
diagonal_1() OK
diagonal_n(-5) OK
diagonal_n(-4) OK
diagonal_n(-3) OK
diagonal_n(-2) OK
diagonal_n(-1) OK
diagonal_n(0) OK
diagonal_n(1) OK
diagonal_n(2) OK
diagonal_n(3) OK
diagonal_n(4) OK
diagonal_n(5) OK
diagonal_n(6) OK
block(2,1,3,3) OK
block(1,4,4,2) OK
block(1,4,3,2) OK

View File

@ -40,8 +40,8 @@ public:
static constexpr bool value = decltype(test(std::declval<T>()))::value; static constexpr bool value = decltype(test(std::declval<T>()))::value;
}; };
// Eigen::Ref<Derived> satisfies is_eigen_dense, but isn't constructible, which means we can't load // Eigen::Ref<Derived> satisfies is_eigen_dense, but isn't constructible, so it needs a special
// it (since there is no reference!), but we can cast from it. // type_caster to handle argument copying/forwarding.
template <typename T> class is_eigen_ref { template <typename T> class is_eigen_ref {
private: private:
template<typename Derived> static typename std::enable_if< template<typename Derived> static typename std::enable_if<
@ -126,7 +126,7 @@ struct type_caster<Type, typename std::enable_if<is_eigen_dense<Type>::value &&
/* Buffer dimensions */ /* Buffer dimensions */
{ (size_t) src.size() }, { (size_t) src.size() },
/* Strides (in bytes) for each index */ /* Strides (in bytes) for each index */
{ sizeof(Scalar) } { sizeof(Scalar) * src.innerStride() }
)).release(); )).release();
} else { } else {
return array(buffer_info( return array(buffer_info(
@ -142,8 +142,8 @@ struct type_caster<Type, typename std::enable_if<is_eigen_dense<Type>::value &&
{ (size_t) src.rows(), { (size_t) src.rows(),
(size_t) src.cols() }, (size_t) src.cols() },
/* Strides (in bytes) for each index */ /* Strides (in bytes) for each index */
{ sizeof(Scalar) * (rowMajor ? (size_t) src.cols() : 1), { sizeof(Scalar) * src.rowStride(),
sizeof(Scalar) * (rowMajor ? 1 : (size_t) src.rows()) } sizeof(Scalar) * src.colStride() }
)).release(); )).release();
} }
} }