From 5fd5074a0b19ff6f28c85513e2ebe01a9b20948e Mon Sep 17 00:00:00 2001 From: Jason Rhinelander Date: Wed, 3 Aug 2016 16:50:22 -0400 Subject: [PATCH] Add support for Eigen::Ref<...> function arguments Eigen::Ref is a common way to pass eigen dense types without needing a template, e.g. the single definition `void func(Eigen::Ref x)` can be called with any double matrix-like object. The current pybind11 eigen support fails with internal errors if attempting to bind a function with an Eigen::Ref<...> argument because Eigen::Ref<...> satisfies the "is_eigen_dense" requirement, but can't compile if actually used: Eigen::Ref<...> itself is not default constructible, and so the argument std::tuple containing an Eigen::Ref<...> isn't constructible, which results in compilation failure. This commit adds support for Eigen::Ref<...> by giving it its own type_caster implementation which consists of an internal type_caster of the referenced type, load/cast methods that dispatch to the internal type_caster, and a unique_ptr to an Eigen::Ref<> instance that gets set during load(). There is, of course, no performance advantage for pybind11-using code of using Eigen::Ref<...>--we are allocating a matrix of the derived type when loading it--but this has the advantage of allowing pybind11 to bind transparently to C++ methods taking Eigen::Refs. --- example/eigen.cpp | 15 +++++++++++++++ example/eigen.py | 8 ++++++++ example/eigen.ref | 6 ++++++ include/pybind11/eigen.h | 35 ++++++++++++++++++++++++++++++++++- 4 files changed, 63 insertions(+), 1 deletion(-) diff --git a/example/eigen.cpp b/example/eigen.cpp index f99ae3a40..728b575fd 100644 --- a/example/eigen.cpp +++ b/example/eigen.cpp @@ -9,6 +9,7 @@ #include "example.h" #include +#include Eigen::VectorXf double_col(const Eigen::VectorXf& x) { return 2.0f * x; } @@ -19,6 +20,14 @@ Eigen::RowVectorXf double_row(const Eigen::RowVectorXf& x) Eigen::MatrixXf double_mat_cm(const Eigen::MatrixXf& x) { return 2.0f * x; } +// Different ways of passing via Eigen::Ref; the first and second are the Eigen-recommended +Eigen::MatrixXd cholesky1(Eigen::Ref &x) { return x.llt().matrixL(); } +Eigen::MatrixXd cholesky2(const Eigen::Ref &x) { return x.llt().matrixL(); } +Eigen::MatrixXd cholesky3(const Eigen::Ref &x) { return x.llt().matrixL(); } +Eigen::MatrixXd cholesky4(Eigen::Ref &x) { return x.llt().matrixL(); } +Eigen::MatrixXd cholesky5(Eigen::Ref x) { return x.llt().matrixL(); } +Eigen::MatrixXd cholesky6(Eigen::Ref x) { return x.llt().matrixL(); } + typedef Eigen::Matrix MatrixXfRowMajor; MatrixXfRowMajor double_mat_rm(const MatrixXfRowMajor& x) { return 2.0f * x; } @@ -40,6 +49,12 @@ void init_eigen(py::module &m) { m.def("double_row", &double_row); m.def("double_mat_cm", &double_mat_cm); m.def("double_mat_rm", &double_mat_rm); + m.def("cholesky1", &cholesky1); + m.def("cholesky2", &cholesky2); + m.def("cholesky3", &cholesky3); + m.def("cholesky4", &cholesky4); + m.def("cholesky5", &cholesky5); + m.def("cholesky6", &cholesky6); m.def("fixed_r", [mat]() -> FixedMatrixR { return FixedMatrixR(mat); diff --git a/example/eigen.py b/example/eigen.py index e69605d20..6cdc3940b 100644 --- a/example/eigen.py +++ b/example/eigen.py @@ -11,6 +11,7 @@ from example import sparse_r, sparse_c from example import sparse_passthrough_r, sparse_passthrough_c from example import double_row, double_col from example import double_mat_cm, double_mat_rm +from example import cholesky1, cholesky2, cholesky3, cholesky4, cholesky5, cholesky6 try: import numpy as np import scipy @@ -70,3 +71,10 @@ slices = [counting_3d[0, :, :], counting_3d[:, 0, :], counting_3d[:, :, 0]] for slice_idx, ref_mat in enumerate(slices): print("double_mat_cm(%d) = %s" % (slice_idx, check_got_vs_ref(double_mat_cm(ref_mat), 2.0 * ref_mat))) print("double_mat_rm(%d) = %s" % (slice_idx, check_got_vs_ref(double_mat_rm(ref_mat), 2.0 * ref_mat))) + +i = 1 +for chol in [cholesky1, cholesky2, cholesky3, cholesky4, cholesky5, cholesky6]: + mymat = chol(np.array([[1,2,4], [2,13,23], [4,23,77]])) + print("cholesky" + str(i) + " " + ("OK" if (mymat == np.array([[1,0,0], [2,3,0], [4,5,6]])).all() else "NOT OKAY")) + i += 1 + diff --git a/example/eigen.ref b/example/eigen.ref index 03091cc24..93e88adb9 100644 --- a/example/eigen.ref +++ b/example/eigen.ref @@ -27,3 +27,9 @@ double_mat_cm(1) = OK double_mat_rm(1) = OK double_mat_cm(2) = OK double_mat_rm(2) = OK +cholesky1 OK +cholesky2 OK +cholesky3 OK +cholesky4 OK +cholesky5 OK +cholesky6 OK diff --git a/include/pybind11/eigen.h b/include/pybind11/eigen.h index 987b54702..b9f22b049 100644 --- a/include/pybind11/eigen.h +++ b/include/pybind11/eigen.h @@ -40,6 +40,19 @@ public: static constexpr bool value = decltype(test(std::declval()))::value; }; +// Eigen::Ref satisfies is_eigen_dense, but isn't constructible, which means we can't load +// it (since there is no reference!), but we can cast from it. +template class is_eigen_ref { +private: + template static typename std::enable_if< + std::is_same::type, Eigen::Ref>::value, + Derived>::type test(const Eigen::Ref &); + static void test(...); +public: + typedef decltype(test(std::declval())) Derived; + static constexpr bool value = !std::is_void::value; +}; + template class is_eigen_sparse { private: template static std::true_type test(const Eigen::SparseMatrixBase &); @@ -49,7 +62,7 @@ public: }; template -struct type_caster::value>::type> { +struct type_caster::value && !is_eigen_ref::value>::type> { typedef typename Type::Scalar Scalar; static constexpr bool rowMajor = Type::Flags & Eigen::RowMajorBit; static constexpr bool isVector = Type::IsVectorAtCompileTime; @@ -149,6 +162,26 @@ protected: static PYBIND11_DESCR cols() { return _(); } }; +template +struct type_caster::value && is_eigen_ref::value>::type> { +private: + using Derived = typename std::remove_const::Derived>::type; + using DerivedCaster = type_caster; + DerivedCaster derived_caster; +protected: + std::unique_ptr value; +public: + bool load(handle src, bool convert) { if (derived_caster.load(src, convert)) { value.reset(new Type(derived_caster.operator Derived&())); return true; } return false; } + static handle cast(const Type &src, return_value_policy policy, handle parent) { return DerivedCaster::cast(src, policy, parent); } + static handle cast(const Type *src, return_value_policy policy, handle parent) { return DerivedCaster::cast(*src, policy, parent); } + + static PYBIND11_DESCR name() { return DerivedCaster::name(); } + + operator Type*() { return value.get(); } + operator Type&() { if (!value) pybind11_fail("Eigen::Ref<...> value not loaded"); return *value; } + template using cast_op_type = pybind11::detail::cast_op_type<_T>; +}; + template struct type_caster::value>::type> { typedef typename Type::Scalar Scalar;