mirror of
https://github.com/pybind/pybind11.git
synced 2024-11-22 21:25:13 +00:00
Merge pull request #315 from jagerman/eigen-stride-fix
Fix eigen copying of non-standard stride values
This commit is contained in:
commit
19637536ac
@ -56,6 +56,16 @@ void init_eigen(py::module &m) {
|
|||||||
m.def("cholesky5", &cholesky5);
|
m.def("cholesky5", &cholesky5);
|
||||||
m.def("cholesky6", &cholesky6);
|
m.def("cholesky6", &cholesky6);
|
||||||
|
|
||||||
|
// Returns diagonals: a vector-like object with an inner stride != 1
|
||||||
|
m.def("diagonal", [](const Eigen::Ref<const Eigen::MatrixXd> &x) { return x.diagonal(); });
|
||||||
|
m.def("diagonal_1", [](const Eigen::Ref<const Eigen::MatrixXd> &x) { return x.diagonal<1>(); });
|
||||||
|
m.def("diagonal_n", [](const Eigen::Ref<const Eigen::MatrixXd> &x, int index) { return x.diagonal(index); });
|
||||||
|
|
||||||
|
// Return a block of a matrix (gives non-standard strides)
|
||||||
|
m.def("block", [](const Eigen::Ref<const Eigen::MatrixXd> &x, int start_row, int start_col, int block_rows, int block_cols) {
|
||||||
|
return x.block(start_row, start_col, block_rows, block_cols);
|
||||||
|
});
|
||||||
|
|
||||||
m.def("fixed_r", [mat]() -> FixedMatrixR {
|
m.def("fixed_r", [mat]() -> FixedMatrixR {
|
||||||
return FixedMatrixR(mat);
|
return FixedMatrixR(mat);
|
||||||
});
|
});
|
||||||
|
@ -12,6 +12,8 @@ from example import sparse_passthrough_r, sparse_passthrough_c
|
|||||||
from example import double_row, double_col
|
from example import double_row, double_col
|
||||||
from example import double_mat_cm, double_mat_rm
|
from example import double_mat_cm, double_mat_rm
|
||||||
from example import cholesky1, cholesky2, cholesky3, cholesky4, cholesky5, cholesky6
|
from example import cholesky1, cholesky2, cholesky3, cholesky4, cholesky5, cholesky6
|
||||||
|
from example import diagonal, diagonal_1, diagonal_n
|
||||||
|
from example import block
|
||||||
try:
|
try:
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import scipy
|
import scipy
|
||||||
@ -78,3 +80,11 @@ for chol in [cholesky1, cholesky2, cholesky3, cholesky4, cholesky5, cholesky6]:
|
|||||||
print("cholesky" + str(i) + " " + ("OK" if (mymat == np.array([[1,0,0], [2,3,0], [4,5,6]])).all() else "NOT OKAY"))
|
print("cholesky" + str(i) + " " + ("OK" if (mymat == np.array([[1,0,0], [2,3,0], [4,5,6]])).all() else "NOT OKAY"))
|
||||||
i += 1
|
i += 1
|
||||||
|
|
||||||
|
print("diagonal() %s" % ("OK" if (diagonal(ref) == ref.diagonal()).all() else "FAILED"))
|
||||||
|
print("diagonal_1() %s" % ("OK" if (diagonal_1(ref) == ref.diagonal(1)).all() else "FAILED"))
|
||||||
|
for i in range(-5, 7):
|
||||||
|
print("diagonal_n(%d) %s" % (i, "OK" if (diagonal_n(ref, i) == ref.diagonal(i)).all() else "FAILED"))
|
||||||
|
|
||||||
|
print("block(2,1,3,3) %s" % ("OK" if (block(ref, 2, 1, 3, 3) == ref[2:5, 1:4]).all() else "FAILED"))
|
||||||
|
print("block(1,4,4,2) %s" % ("OK" if (block(ref, 1, 4, 4, 2) == ref[1:, 4:]).all() else "FAILED"))
|
||||||
|
print("block(1,4,3,2) %s" % ("OK" if (block(ref, 1, 4, 3, 2) == ref[1:4, 4:]).all() else "FAILED"))
|
||||||
|
@ -33,3 +33,20 @@ cholesky3 OK
|
|||||||
cholesky4 OK
|
cholesky4 OK
|
||||||
cholesky5 OK
|
cholesky5 OK
|
||||||
cholesky6 OK
|
cholesky6 OK
|
||||||
|
diagonal() OK
|
||||||
|
diagonal_1() OK
|
||||||
|
diagonal_n(-5) OK
|
||||||
|
diagonal_n(-4) OK
|
||||||
|
diagonal_n(-3) OK
|
||||||
|
diagonal_n(-2) OK
|
||||||
|
diagonal_n(-1) OK
|
||||||
|
diagonal_n(0) OK
|
||||||
|
diagonal_n(1) OK
|
||||||
|
diagonal_n(2) OK
|
||||||
|
diagonal_n(3) OK
|
||||||
|
diagonal_n(4) OK
|
||||||
|
diagonal_n(5) OK
|
||||||
|
diagonal_n(6) OK
|
||||||
|
block(2,1,3,3) OK
|
||||||
|
block(1,4,4,2) OK
|
||||||
|
block(1,4,3,2) OK
|
||||||
|
@ -40,8 +40,8 @@ public:
|
|||||||
static constexpr bool value = decltype(test(std::declval<T>()))::value;
|
static constexpr bool value = decltype(test(std::declval<T>()))::value;
|
||||||
};
|
};
|
||||||
|
|
||||||
// Eigen::Ref<Derived> satisfies is_eigen_dense, but isn't constructible, which means we can't load
|
// Eigen::Ref<Derived> satisfies is_eigen_dense, but isn't constructible, so it needs a special
|
||||||
// it (since there is no reference!), but we can cast from it.
|
// type_caster to handle argument copying/forwarding.
|
||||||
template <typename T> class is_eigen_ref {
|
template <typename T> class is_eigen_ref {
|
||||||
private:
|
private:
|
||||||
template<typename Derived> static typename std::enable_if<
|
template<typename Derived> static typename std::enable_if<
|
||||||
@ -126,7 +126,7 @@ struct type_caster<Type, typename std::enable_if<is_eigen_dense<Type>::value &&
|
|||||||
/* Buffer dimensions */
|
/* Buffer dimensions */
|
||||||
{ (size_t) src.size() },
|
{ (size_t) src.size() },
|
||||||
/* Strides (in bytes) for each index */
|
/* Strides (in bytes) for each index */
|
||||||
{ sizeof(Scalar) }
|
{ sizeof(Scalar) * src.innerStride() }
|
||||||
)).release();
|
)).release();
|
||||||
} else {
|
} else {
|
||||||
return array(buffer_info(
|
return array(buffer_info(
|
||||||
@ -142,8 +142,8 @@ struct type_caster<Type, typename std::enable_if<is_eigen_dense<Type>::value &&
|
|||||||
{ (size_t) src.rows(),
|
{ (size_t) src.rows(),
|
||||||
(size_t) src.cols() },
|
(size_t) src.cols() },
|
||||||
/* Strides (in bytes) for each index */
|
/* Strides (in bytes) for each index */
|
||||||
{ sizeof(Scalar) * (rowMajor ? (size_t) src.cols() : 1),
|
{ sizeof(Scalar) * src.rowStride(),
|
||||||
sizeof(Scalar) * (rowMajor ? 1 : (size_t) src.rows()) }
|
sizeof(Scalar) * src.colStride() }
|
||||||
)).release();
|
)).release();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user