diff --git a/example/eigen.cpp b/example/eigen.cpp index b6fa24a42..f99ae3a40 100644 --- a/example/eigen.cpp +++ b/example/eigen.cpp @@ -10,6 +10,19 @@ #include "example.h" #include +Eigen::VectorXf double_col(const Eigen::VectorXf& x) +{ return 2.0f * x; } + +Eigen::RowVectorXf double_row(const Eigen::RowVectorXf& x) +{ return 2.0f * x; } + +Eigen::MatrixXf double_mat_cm(const Eigen::MatrixXf& x) +{ return 2.0f * x; } + +typedef Eigen::Matrix MatrixXfRowMajor; +MatrixXfRowMajor double_mat_rm(const MatrixXfRowMajor& x) +{ return 2.0f * x; } + void init_eigen(py::module &m) { typedef Eigen::Matrix FixedMatrixR; typedef Eigen::Matrix FixedMatrixC; @@ -23,6 +36,11 @@ void init_eigen(py::module &m) { mat << 0, 3, 0, 0, 0, 11, 22, 0, 0, 0, 17, 11, 7, 5, 0, 1, 0, 11, 0, 0, 0, 0, 0, 11, 0, 0, 14, 0, 8, 11; + m.def("double_col", &double_col); + m.def("double_row", &double_row); + m.def("double_mat_cm", &double_mat_cm); + m.def("double_mat_rm", &double_mat_rm); + m.def("fixed_r", [mat]() -> FixedMatrixR { return FixedMatrixR(mat); }); diff --git a/example/eigen.py b/example/eigen.py index accaf236a..9c4e1ef16 100644 --- a/example/eigen.py +++ b/example/eigen.py @@ -9,6 +9,8 @@ from example import dense_r, dense_c from example import dense_passthrough_r, dense_passthrough_c from example import sparse_r, sparse_c from example import sparse_passthrough_r, sparse_passthrough_c +from example import double_row, double_col +from example import double_mat_cm, double_mat_rm import numpy as np ref = np.array( @@ -42,3 +44,22 @@ print("pt_r(sparse_r) = %s" % check(sparse_passthrough_r(sparse_r()))) print("pt_c(sparse_c) = %s" % check(sparse_passthrough_c(sparse_c()))) print("pt_r(sparse_c) = %s" % check(sparse_passthrough_r(sparse_c()))) print("pt_c(sparse_r) = %s" % check(sparse_passthrough_c(sparse_r()))) + +def check_got_vs_ref(got_x, ref_x): + return 'OK' if np.array_equal(got_x, ref_x) else 'NOT OK' + +counting_mat = np.arange(9.0, dtype=np.float32).reshape((3, 3)) +first_row = counting_mat[0, :] +first_col = counting_mat[:, 0] + +print("double_row(first_row) = %s" % check_got_vs_ref(double_row(first_row), 2.0 * first_row)) +print("double_col(first_row) = %s" % check_got_vs_ref(double_col(first_row), 2.0 * first_row)) +print("double_row(first_col) = %s" % check_got_vs_ref(double_row(first_col), 2.0 * first_col)) +print("double_col(first_col) = %s" % check_got_vs_ref(double_col(first_col), 2.0 * first_col)) + +counting_3d = np.arange(27.0, dtype=np.float32).reshape((3, 3, 3)) +slices = [counting_3d[0, :, :], counting_3d[:, 0, :], counting_3d[:, :, 0]] + +for slice_idx, ref_mat in enumerate(slices): + print("double_mat_cm(%d) = %s" % (slice_idx, check_got_vs_ref(double_mat_cm(ref_mat), 2.0 * ref_mat))) + print("double_mat_rm(%d) = %s" % (slice_idx, check_got_vs_ref(double_mat_rm(ref_mat), 2.0 * ref_mat))) diff --git a/example/eigen.ref b/example/eigen.ref index b87f8ede3..bac73be9f 100644 --- a/example/eigen.ref +++ b/example/eigen.ref @@ -16,3 +16,13 @@ pt_r(sparse_r) = OK pt_c(sparse_c) = OK pt_r(sparse_c) = OK pt_c(sparse_r) = OK +double_row(first_row) = OK +double_col(first_row) = OK +double_row(first_col) = OK +double_col(first_col) = OK +double_mat_cm(0) = OK +double_mat_rm(0) = OK +double_mat_cm(1) = OK +double_mat_rm(1) = OK +double_mat_cm(2) = OK +double_mat_rm(2) = OK diff --git a/include/pybind11/eigen.h b/include/pybind11/eigen.h index 718107947..ecad2d547 100644 --- a/include/pybind11/eigen.h +++ b/include/pybind11/eigen.h @@ -61,7 +61,7 @@ struct type_caster::value>::t buffer_info info = buffer.request(); if (info.ndim == 1) { - typedef Eigen::Stride Strides; + typedef Eigen::InnerStride<> Strides; if (!isVector && !(Type::RowsAtCompileTime == Eigen::Dynamic && Type::ColsAtCompileTime == Eigen::Dynamic)) @@ -71,10 +71,13 @@ struct type_caster::value>::t info.shape[0] != (size_t) Type::SizeAtCompileTime) return false; - auto strides = Strides(info.strides[0] / sizeof(Scalar), 0); + auto strides = Strides(info.strides[0] / sizeof(Scalar)); + + Strides::Index n_elts = info.shape[0]; + Strides::Index unity = 1; value = Eigen::Map( - (Scalar *) info.ptr, typename Strides::Index(info.shape[0]), 1, strides); + (Scalar *) info.ptr, rowMajor ? unity : n_elts, rowMajor ? n_elts : unity, strides); } else if (info.ndim == 2) { typedef Eigen::Stride Strides;