Verified Commit 1f2ae251 authored by Justin Carpentier's avatar Justin Carpentier
Browse files

cholesky: support solve with input matrices

parent 51b79afe
Pipeline #15520 passed with stage
in 14 minutes and 23 seconds
/*
* Copyright 2020 INRIA
* Copyright 2020-2021 INRIA
*/
#ifndef __eigenpy_decomposition_ldlt_hpp__
......@@ -23,7 +23,8 @@ namespace eigenpy
typedef _MatrixType MatrixType;
typedef typename MatrixType::Scalar Scalar;
typedef typename MatrixType::RealScalar RealScalar;
typedef Eigen::Matrix<Scalar,Eigen::Dynamic,1,MatrixType::Options> VectorType;
typedef Eigen::Matrix<Scalar,Eigen::Dynamic,1,MatrixType::Options> VectorXs;
typedef Eigen::Matrix<Scalar,Eigen::Dynamic,Eigen::Dynamic,MatrixType::Options> MatrixXs;
typedef Eigen::LDLT<MatrixType> Solver;
template<class PyClass>
......@@ -55,7 +56,7 @@ namespace eigenpy
"Returns the LDLT decomposition matrix.",
bp::return_internal_reference<>())
.def("rankUpdate",(Solver & (Solver::*)(const Eigen::MatrixBase<VectorType> &, const RealScalar &))&Solver::template rankUpdate<VectorType>,
.def("rankUpdate",(Solver & (Solver::*)(const Eigen::MatrixBase<VectorXs> &, const RealScalar &))&Solver::template rankUpdate<VectorXs>,
bp::args("self","vector","sigma"),
bp::return_self<>())
......@@ -78,8 +79,10 @@ namespace eigenpy
#endif
.def("reconstructedMatrix",&Solver::reconstructedMatrix,bp::arg("self"),
"Returns the matrix represented by the decomposition, i.e., it returns the product: L L^*. This function is provided for debug purpose.")
.def("solve",&solve<VectorType>,bp::args("self","b"),
.def("solve",&solve<VectorXs>,bp::args("self","b"),
"Returns the solution x of A x = b using the current decomposition of A.")
.def("solve",&solve<MatrixXs>,bp::args("self","B"),
"Returns the solution X of A X = B using the current decomposition of A where B is a right hand side matrix.")
.def("setZero",&Solver::setZero,bp::arg("self"),
"Clear any existing decomposition.")
......@@ -107,7 +110,7 @@ namespace eigenpy
static MatrixType matrixL(const Solver & self) { return self.matrixL(); }
static MatrixType matrixU(const Solver & self) { return self.matrixU(); }
static VectorType vectorD(const Solver & self) { return self.vectorD(); }
static VectorXs vectorD(const Solver & self) { return self.vectorD(); }
static MatrixType transpositionsP(const Solver & self)
{
......@@ -115,8 +118,8 @@ namespace eigenpy
self.matrixL().rows());
}
template<typename VectorType>
static VectorType solve(const Solver & self, const VectorType & vec)
template<typename MatrixOrVector>
static MatrixOrVector solve(const Solver & self, const MatrixOrVector & vec)
{
return self.solve(vec);
}
......
/*
* Copyright 2020 INRIA
* Copyright 2020-2021 INRIA
*/
#ifndef __eigenpy_decomposition_llt_hpp__
......@@ -23,7 +23,8 @@ namespace eigenpy
typedef _MatrixType MatrixType;
typedef typename MatrixType::Scalar Scalar;
typedef typename MatrixType::RealScalar RealScalar;
typedef Eigen::Matrix<Scalar,Eigen::Dynamic,1,MatrixType::Options> VectorType;
typedef Eigen::Matrix<Scalar,Eigen::Dynamic,1,MatrixType::Options> VectorXs;
typedef Eigen::Matrix<Scalar,Eigen::Dynamic,Eigen::Dynamic,MatrixType::Options> MatrixXs;
typedef Eigen::LLT<MatrixType> Solver;
template<class PyClass>
......@@ -46,10 +47,10 @@ namespace eigenpy
bp::return_internal_reference<>())
#if EIGEN_VERSION_AT_LEAST(3,3,90)
.def("rankUpdate",(Solver& (Solver::*)(const VectorType &, const RealScalar &))&Solver::template rankUpdate<VectorType>,
.def("rankUpdate",(Solver& (Solver::*)(const VectorXs &, const RealScalar &))&Solver::template rankUpdate<VectorXs>,
bp::args("self","vector","sigma"), bp::return_self<>())
#else
.def("rankUpdate",(Solver (Solver::*)(const VectorType &, const RealScalar &))&Solver::template rankUpdate<VectorType>,
.def("rankUpdate",(Solver (Solver::*)(const VectorXs &, const RealScalar &))&Solver::template rankUpdate<VectorXs>,
bp::args("self","vector","sigma"))
#endif
......@@ -72,8 +73,10 @@ namespace eigenpy
#endif
.def("reconstructedMatrix",&Solver::reconstructedMatrix,bp::arg("self"),
"Returns the matrix represented by the decomposition, i.e., it returns the product: L L^*. This function is provided for debug purpose.")
.def("solve",&solve<VectorType>,bp::args("self","b"),
.def("solve",&solve<VectorXs>,bp::args("self","b"),
"Returns the solution x of A x = b using the current decomposition of A.")
.def("solve",&solve<MatrixXs>,bp::args("self","B"),
"Returns the solution X of A X = B using the current decomposition of A where B is a right hand side matrix.")
;
}
......@@ -99,8 +102,8 @@ namespace eigenpy
static MatrixType matrixL(const Solver & self) { return self.matrixL(); }
static MatrixType matrixU(const Solver & self) { return self.matrixU(); }
template<typename VectorType>
static VectorType solve(const Solver & self, const VectorType & vec)
template<typename MatrixOrVector>
static MatrixOrVector solve(const Solver & self, const MatrixOrVector & vec)
{
return self.solve(vec);
}
......
import eigenpy
eigenpy.switchToNumpyArray()
import numpy as np
import numpy.linalg as la
......@@ -16,3 +15,9 @@ D = ldlt.vectorD()
P = ldlt.transpositionsP()
assert eigenpy.is_approx(np.transpose(P).dot(L.dot(np.diag(D).dot(np.transpose(L).dot(P)))),A)
X = np.random.rand(dim,20)
B = A.dot(X)
X_est = ldlt.solve(B)
assert eigenpy.is_approx(X,X_est)
assert eigenpy.is_approx(A.dot(X_est),B)
import eigenpy
eigenpy.switchToNumpyArray()
import numpy as np
import numpy.linalg as la
......@@ -12,5 +11,10 @@ A = (A + A.T)*0.5 + np.diag(10. + np.random.rand(dim))
llt = eigenpy.LLT(A)
L = llt.matrixL()
assert eigenpy.is_approx(L.dot(np.transpose(L)),A)
X = np.random.rand(dim,20)
B = A.dot(X)
X_est = llt.solve(B)
assert eigenpy.is_approx(X,X_est)
assert eigenpy.is_approx(A.dot(X_est),B)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment