mirror of
https://github.com/pybind/pybind11.git
synced 2025-01-19 01:15:52 +00:00
Merge pull request #315 from jagerman/eigen-stride-fix
Fix eigen copying of non-standard stride values
This commit is contained in:
commit
19637536ac
@ -56,6 +56,16 @@ void init_eigen(py::module &m) {
|
||||
m.def("cholesky5", &cholesky5);
|
||||
m.def("cholesky6", &cholesky6);
|
||||
|
||||
// Returns diagonals: a vector-like object with an inner stride != 1
|
||||
m.def("diagonal", [](const Eigen::Ref<const Eigen::MatrixXd> &x) { return x.diagonal(); });
|
||||
m.def("diagonal_1", [](const Eigen::Ref<const Eigen::MatrixXd> &x) { return x.diagonal<1>(); });
|
||||
m.def("diagonal_n", [](const Eigen::Ref<const Eigen::MatrixXd> &x, int index) { return x.diagonal(index); });
|
||||
|
||||
// Return a block of a matrix (gives non-standard strides)
|
||||
m.def("block", [](const Eigen::Ref<const Eigen::MatrixXd> &x, int start_row, int start_col, int block_rows, int block_cols) {
|
||||
return x.block(start_row, start_col, block_rows, block_cols);
|
||||
});
|
||||
|
||||
m.def("fixed_r", [mat]() -> FixedMatrixR {
|
||||
return FixedMatrixR(mat);
|
||||
});
|
||||
|
@ -12,6 +12,8 @@ from example import sparse_passthrough_r, sparse_passthrough_c
|
||||
from example import double_row, double_col
|
||||
from example import double_mat_cm, double_mat_rm
|
||||
from example import cholesky1, cholesky2, cholesky3, cholesky4, cholesky5, cholesky6
|
||||
from example import diagonal, diagonal_1, diagonal_n
|
||||
from example import block
|
||||
try:
|
||||
import numpy as np
|
||||
import scipy
|
||||
@ -78,3 +80,11 @@ for chol in [cholesky1, cholesky2, cholesky3, cholesky4, cholesky5, cholesky6]:
|
||||
print("cholesky" + str(i) + " " + ("OK" if (mymat == np.array([[1,0,0], [2,3,0], [4,5,6]])).all() else "NOT OKAY"))
|
||||
i += 1
|
||||
|
||||
print("diagonal() %s" % ("OK" if (diagonal(ref) == ref.diagonal()).all() else "FAILED"))
|
||||
print("diagonal_1() %s" % ("OK" if (diagonal_1(ref) == ref.diagonal(1)).all() else "FAILED"))
|
||||
for i in range(-5, 7):
|
||||
print("diagonal_n(%d) %s" % (i, "OK" if (diagonal_n(ref, i) == ref.diagonal(i)).all() else "FAILED"))
|
||||
|
||||
print("block(2,1,3,3) %s" % ("OK" if (block(ref, 2, 1, 3, 3) == ref[2:5, 1:4]).all() else "FAILED"))
|
||||
print("block(1,4,4,2) %s" % ("OK" if (block(ref, 1, 4, 4, 2) == ref[1:, 4:]).all() else "FAILED"))
|
||||
print("block(1,4,3,2) %s" % ("OK" if (block(ref, 1, 4, 3, 2) == ref[1:4, 4:]).all() else "FAILED"))
|
||||
|
@ -33,3 +33,20 @@ cholesky3 OK
|
||||
cholesky4 OK
|
||||
cholesky5 OK
|
||||
cholesky6 OK
|
||||
diagonal() OK
|
||||
diagonal_1() OK
|
||||
diagonal_n(-5) OK
|
||||
diagonal_n(-4) OK
|
||||
diagonal_n(-3) OK
|
||||
diagonal_n(-2) OK
|
||||
diagonal_n(-1) OK
|
||||
diagonal_n(0) OK
|
||||
diagonal_n(1) OK
|
||||
diagonal_n(2) OK
|
||||
diagonal_n(3) OK
|
||||
diagonal_n(4) OK
|
||||
diagonal_n(5) OK
|
||||
diagonal_n(6) OK
|
||||
block(2,1,3,3) OK
|
||||
block(1,4,4,2) OK
|
||||
block(1,4,3,2) OK
|
||||
|
@ -40,8 +40,8 @@ public:
|
||||
static constexpr bool value = decltype(test(std::declval<T>()))::value;
|
||||
};
|
||||
|
||||
// Eigen::Ref<Derived> satisfies is_eigen_dense, but isn't constructible, which means we can't load
|
||||
// it (since there is no reference!), but we can cast from it.
|
||||
// Eigen::Ref<Derived> satisfies is_eigen_dense, but isn't constructible, so it needs a special
|
||||
// type_caster to handle argument copying/forwarding.
|
||||
template <typename T> class is_eigen_ref {
|
||||
private:
|
||||
template<typename Derived> static typename std::enable_if<
|
||||
@ -126,7 +126,7 @@ struct type_caster<Type, typename std::enable_if<is_eigen_dense<Type>::value &&
|
||||
/* Buffer dimensions */
|
||||
{ (size_t) src.size() },
|
||||
/* Strides (in bytes) for each index */
|
||||
{ sizeof(Scalar) }
|
||||
{ sizeof(Scalar) * src.innerStride() }
|
||||
)).release();
|
||||
} else {
|
||||
return array(buffer_info(
|
||||
@ -142,8 +142,8 @@ struct type_caster<Type, typename std::enable_if<is_eigen_dense<Type>::value &&
|
||||
{ (size_t) src.rows(),
|
||||
(size_t) src.cols() },
|
||||
/* Strides (in bytes) for each index */
|
||||
{ sizeof(Scalar) * (rowMajor ? (size_t) src.cols() : 1),
|
||||
sizeof(Scalar) * (rowMajor ? 1 : (size_t) src.rows()) }
|
||||
{ sizeof(Scalar) * src.rowStride(),
|
||||
sizeof(Scalar) * src.colStride() }
|
||||
)).release();
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user