mirror of
https://github.com/pybind/pybind11.git
synced 2025-01-31 15:20:34 +00:00
Merge pull request #312 from jagerman/eigen-ref-args
Add support for Eigen::Ref<...> function arguments
This commit is contained in:
commit
39ff2d0140
@ -9,6 +9,7 @@
|
||||
|
||||
#include "example.h"
|
||||
#include <pybind11/eigen.h>
|
||||
#include <Eigen/Cholesky>
|
||||
|
||||
Eigen::VectorXf double_col(const Eigen::VectorXf& x)
|
||||
{ return 2.0f * x; }
|
||||
@ -19,6 +20,14 @@ Eigen::RowVectorXf double_row(const Eigen::RowVectorXf& x)
|
||||
Eigen::MatrixXf double_mat_cm(const Eigen::MatrixXf& x)
|
||||
{ return 2.0f * x; }
|
||||
|
||||
// Different ways of passing via Eigen::Ref; the first and second are the Eigen-recommended
|
||||
Eigen::MatrixXd cholesky1(Eigen::Ref<Eigen::MatrixXd> &x) { return x.llt().matrixL(); }
|
||||
Eigen::MatrixXd cholesky2(const Eigen::Ref<const Eigen::MatrixXd> &x) { return x.llt().matrixL(); }
|
||||
Eigen::MatrixXd cholesky3(const Eigen::Ref<Eigen::MatrixXd> &x) { return x.llt().matrixL(); }
|
||||
Eigen::MatrixXd cholesky4(Eigen::Ref<const Eigen::MatrixXd> &x) { return x.llt().matrixL(); }
|
||||
Eigen::MatrixXd cholesky5(Eigen::Ref<Eigen::MatrixXd> x) { return x.llt().matrixL(); }
|
||||
Eigen::MatrixXd cholesky6(Eigen::Ref<const Eigen::MatrixXd> x) { return x.llt().matrixL(); }
|
||||
|
||||
typedef Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor> MatrixXfRowMajor;
|
||||
MatrixXfRowMajor double_mat_rm(const MatrixXfRowMajor& x)
|
||||
{ return 2.0f * x; }
|
||||
@ -40,6 +49,12 @@ void init_eigen(py::module &m) {
|
||||
m.def("double_row", &double_row);
|
||||
m.def("double_mat_cm", &double_mat_cm);
|
||||
m.def("double_mat_rm", &double_mat_rm);
|
||||
m.def("cholesky1", &cholesky1);
|
||||
m.def("cholesky2", &cholesky2);
|
||||
m.def("cholesky3", &cholesky3);
|
||||
m.def("cholesky4", &cholesky4);
|
||||
m.def("cholesky5", &cholesky5);
|
||||
m.def("cholesky6", &cholesky6);
|
||||
|
||||
m.def("fixed_r", [mat]() -> FixedMatrixR {
|
||||
return FixedMatrixR(mat);
|
||||
|
@ -11,6 +11,7 @@ from example import sparse_r, sparse_c
|
||||
from example import sparse_passthrough_r, sparse_passthrough_c
|
||||
from example import double_row, double_col
|
||||
from example import double_mat_cm, double_mat_rm
|
||||
from example import cholesky1, cholesky2, cholesky3, cholesky4, cholesky5, cholesky6
|
||||
try:
|
||||
import numpy as np
|
||||
import scipy
|
||||
@ -70,3 +71,10 @@ slices = [counting_3d[0, :, :], counting_3d[:, 0, :], counting_3d[:, :, 0]]
|
||||
for slice_idx, ref_mat in enumerate(slices):
|
||||
print("double_mat_cm(%d) = %s" % (slice_idx, check_got_vs_ref(double_mat_cm(ref_mat), 2.0 * ref_mat)))
|
||||
print("double_mat_rm(%d) = %s" % (slice_idx, check_got_vs_ref(double_mat_rm(ref_mat), 2.0 * ref_mat)))
|
||||
|
||||
i = 1
|
||||
for chol in [cholesky1, cholesky2, cholesky3, cholesky4, cholesky5, cholesky6]:
|
||||
mymat = chol(np.array([[1,2,4], [2,13,23], [4,23,77]]))
|
||||
print("cholesky" + str(i) + " " + ("OK" if (mymat == np.array([[1,0,0], [2,3,0], [4,5,6]])).all() else "NOT OKAY"))
|
||||
i += 1
|
||||
|
||||
|
@ -27,3 +27,9 @@ double_mat_cm(1) = OK
|
||||
double_mat_rm(1) = OK
|
||||
double_mat_cm(2) = OK
|
||||
double_mat_rm(2) = OK
|
||||
cholesky1 OK
|
||||
cholesky2 OK
|
||||
cholesky3 OK
|
||||
cholesky4 OK
|
||||
cholesky5 OK
|
||||
cholesky6 OK
|
||||
|
@ -40,6 +40,19 @@ public:
|
||||
static constexpr bool value = decltype(test(std::declval<T>()))::value;
|
||||
};
|
||||
|
||||
// Eigen::Ref<Derived> satisfies is_eigen_dense, but isn't constructible, which means we can't load
|
||||
// it (since there is no reference!), but we can cast from it.
|
||||
template <typename T> class is_eigen_ref {
|
||||
private:
|
||||
template<typename Derived> static typename std::enable_if<
|
||||
std::is_same<typename std::remove_const<T>::type, Eigen::Ref<Derived>>::value,
|
||||
Derived>::type test(const Eigen::Ref<Derived> &);
|
||||
static void test(...);
|
||||
public:
|
||||
typedef decltype(test(std::declval<T>())) Derived;
|
||||
static constexpr bool value = !std::is_void<Derived>::value;
|
||||
};
|
||||
|
||||
template <typename T> class is_eigen_sparse {
|
||||
private:
|
||||
template<typename Derived> static std::true_type test(const Eigen::SparseMatrixBase<Derived> &);
|
||||
@ -49,7 +62,7 @@ public:
|
||||
};
|
||||
|
||||
template<typename Type>
|
||||
struct type_caster<Type, typename std::enable_if<is_eigen_dense<Type>::value>::type> {
|
||||
struct type_caster<Type, typename std::enable_if<is_eigen_dense<Type>::value && !is_eigen_ref<Type>::value>::type> {
|
||||
typedef typename Type::Scalar Scalar;
|
||||
static constexpr bool rowMajor = Type::Flags & Eigen::RowMajorBit;
|
||||
static constexpr bool isVector = Type::IsVectorAtCompileTime;
|
||||
@ -149,6 +162,26 @@ protected:
|
||||
static PYBIND11_DESCR cols() { return _<T::ColsAtCompileTime>(); }
|
||||
};
|
||||
|
||||
template<typename Type>
|
||||
struct type_caster<Type, typename std::enable_if<is_eigen_dense<Type>::value && is_eigen_ref<Type>::value>::type> {
|
||||
private:
|
||||
using Derived = typename std::remove_const<typename is_eigen_ref<Type>::Derived>::type;
|
||||
using DerivedCaster = type_caster<Derived>;
|
||||
DerivedCaster derived_caster;
|
||||
protected:
|
||||
std::unique_ptr<Type> value;
|
||||
public:
|
||||
bool load(handle src, bool convert) { if (derived_caster.load(src, convert)) { value.reset(new Type(derived_caster.operator Derived&())); return true; } return false; }
|
||||
static handle cast(const Type &src, return_value_policy policy, handle parent) { return DerivedCaster::cast(src, policy, parent); }
|
||||
static handle cast(const Type *src, return_value_policy policy, handle parent) { return DerivedCaster::cast(*src, policy, parent); }
|
||||
|
||||
static PYBIND11_DESCR name() { return DerivedCaster::name(); }
|
||||
|
||||
operator Type*() { return value.get(); }
|
||||
operator Type&() { if (!value) pybind11_fail("Eigen::Ref<...> value not loaded"); return *value; }
|
||||
template <typename _T> using cast_op_type = pybind11::detail::cast_op_type<_T>;
|
||||
};
|
||||
|
||||
template<typename Type>
|
||||
struct type_caster<Type, typename std::enable_if<is_eigen_sparse<Type>::value>::type> {
|
||||
typedef typename Type::Scalar Scalar;
|
||||
|
Loading…
Reference in New Issue
Block a user