diff --git a/example/eigen.cpp b/example/eigen.cpp index b6fa24a42..f99ae3a40 100644 --- a/example/eigen.cpp +++ b/example/eigen.cpp @@ -10,6 +10,19 @@ #include "example.h" #include +Eigen::VectorXf double_col(const Eigen::VectorXf& x) +{ return 2.0f * x; } + +Eigen::RowVectorXf double_row(const Eigen::RowVectorXf& x) +{ return 2.0f * x; } + +Eigen::MatrixXf double_mat_cm(const Eigen::MatrixXf& x) +{ return 2.0f * x; } + +typedef Eigen::Matrix MatrixXfRowMajor; +MatrixXfRowMajor double_mat_rm(const MatrixXfRowMajor& x) +{ return 2.0f * x; } + void init_eigen(py::module &m) { typedef Eigen::Matrix FixedMatrixR; typedef Eigen::Matrix FixedMatrixC; @@ -23,6 +36,11 @@ void init_eigen(py::module &m) { mat << 0, 3, 0, 0, 0, 11, 22, 0, 0, 0, 17, 11, 7, 5, 0, 1, 0, 11, 0, 0, 0, 0, 0, 11, 0, 0, 14, 0, 8, 11; + m.def("double_col", &double_col); + m.def("double_row", &double_row); + m.def("double_mat_cm", &double_mat_cm); + m.def("double_mat_rm", &double_mat_rm); + m.def("fixed_r", [mat]() -> FixedMatrixR { return FixedMatrixR(mat); }); diff --git a/example/eigen.py b/example/eigen.py index accaf236a..7ff9aed5d 100644 --- a/example/eigen.py +++ b/example/eigen.py @@ -9,6 +9,8 @@ from example import dense_r, dense_c from example import dense_passthrough_r, dense_passthrough_c from example import sparse_r, sparse_c from example import sparse_passthrough_r, sparse_passthrough_c +from example import double_row, double_col +from example import double_mat_cm, double_mat_rm import numpy as np ref = np.array( @@ -20,7 +22,9 @@ ref = np.array( def check(mat): - return 'OK' if np.sum(mat - ref) == 0 else 'NOT OK' + return 'OK' if np.sum(abs(mat - ref)) == 0 else 'NOT OK' + +print("should_give_NOT_OK = %s" % check(ref[:, ::-1])) print("fixed_r = %s" % check(fixed_r())) print("fixed_c = %s" % check(fixed_c())) @@ -42,3 +46,22 @@ print("pt_r(sparse_r) = %s" % check(sparse_passthrough_r(sparse_r()))) print("pt_c(sparse_c) = %s" % check(sparse_passthrough_c(sparse_c()))) print("pt_r(sparse_c) = %s" % check(sparse_passthrough_r(sparse_c()))) print("pt_c(sparse_r) = %s" % check(sparse_passthrough_c(sparse_r()))) + +def check_got_vs_ref(got_x, ref_x): + return 'OK' if np.array_equal(got_x, ref_x) else 'NOT OK' + +counting_mat = np.arange(9.0, dtype=np.float32).reshape((3, 3)) +first_row = counting_mat[0, :] +first_col = counting_mat[:, 0] + +print("double_row(first_row) = %s" % check_got_vs_ref(double_row(first_row), 2.0 * first_row)) +print("double_col(first_row) = %s" % check_got_vs_ref(double_col(first_row), 2.0 * first_row)) +print("double_row(first_col) = %s" % check_got_vs_ref(double_row(first_col), 2.0 * first_col)) +print("double_col(first_col) = %s" % check_got_vs_ref(double_col(first_col), 2.0 * first_col)) + +counting_3d = np.arange(27.0, dtype=np.float32).reshape((3, 3, 3)) +slices = [counting_3d[0, :, :], counting_3d[:, 0, :], counting_3d[:, :, 0]] + +for slice_idx, ref_mat in enumerate(slices): + print("double_mat_cm(%d) = %s" % (slice_idx, check_got_vs_ref(double_mat_cm(ref_mat), 2.0 * ref_mat))) + print("double_mat_rm(%d) = %s" % (slice_idx, check_got_vs_ref(double_mat_rm(ref_mat), 2.0 * ref_mat))) diff --git a/example/eigen.ref b/example/eigen.ref index b87f8ede3..03091cc24 100644 --- a/example/eigen.ref +++ b/example/eigen.ref @@ -1,3 +1,4 @@ +should_give_NOT_OK = NOT OK fixed_r = OK fixed_c = OK pt_r(fixed_r) = OK @@ -16,3 +17,13 @@ pt_r(sparse_r) = OK pt_c(sparse_c) = OK pt_r(sparse_c) = OK pt_c(sparse_r) = OK +double_row(first_row) = OK +double_col(first_row) = OK +double_row(first_col) = OK +double_col(first_col) = OK +double_mat_cm(0) = OK +double_mat_rm(0) = OK +double_mat_cm(1) = OK +double_mat_rm(1) = OK +double_mat_cm(2) = OK +double_mat_rm(2) = OK diff --git a/example/run_test.py b/example/run_test.py index 70ce4a6c0..c31ea0989 100755 --- a/example/run_test.py +++ b/example/run_test.py @@ -68,6 +68,6 @@ else: print('Test "%s" FAILED!' % name) print('--- output') print('+++ reference') - print(''.join(difflib.ndiff(output.splitlines(keepends=True), - reference.splitlines(keepends=True)))) + print(''.join(difflib.ndiff(output.splitlines(True), + reference.splitlines(True)))) exit(-1) diff --git a/include/pybind11/eigen.h b/include/pybind11/eigen.h index fa4e9f642..7a0fe9302 100644 --- a/include/pybind11/eigen.h +++ b/include/pybind11/eigen.h @@ -61,7 +61,7 @@ struct type_caster::value>::t buffer_info info = buffer.request(); if (info.ndim == 1) { - typedef Eigen::Stride Strides; + typedef Eigen::InnerStride<> Strides; if (!isVector && !(Type::RowsAtCompileTime == Eigen::Dynamic && Type::ColsAtCompileTime == Eigen::Dynamic)) @@ -71,10 +71,13 @@ struct type_caster::value>::t info.shape[0] != (size_t) Type::SizeAtCompileTime) return false; - auto strides = Strides(info.strides[0] / sizeof(Scalar), 0); + auto strides = Strides(info.strides[0] / sizeof(Scalar)); + + Strides::Index n_elts = info.shape[0]; + Strides::Index unity = 1; value = Eigen::Map( - (Scalar *) info.ptr, typename Strides::Index(info.shape[0]), 1, strides); + (Scalar *) info.ptr, rowMajor ? unity : n_elts, rowMajor ? n_elts : unity, strides); } else if (info.ndim == 2) { typedef Eigen::Stride Strides;