Skip to content

Commit 1f2ae25

Browse files
committed
cholesky: support solve with input matrices
1 parent 51b79af commit 1f2ae25

File tree

4 files changed

+32
-17
lines changed

4 files changed

+32
-17
lines changed

include/eigenpy/decompositions/LDLT.hpp

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2020 INRIA
2+
* Copyright 2020-2021 INRIA
33
*/
44

55
#ifndef __eigenpy_decomposition_ldlt_hpp__
@@ -23,7 +23,8 @@ namespace eigenpy
2323
typedef _MatrixType MatrixType;
2424
typedef typename MatrixType::Scalar Scalar;
2525
typedef typename MatrixType::RealScalar RealScalar;
26-
typedef Eigen::Matrix<Scalar,Eigen::Dynamic,1,MatrixType::Options> VectorType;
26+
typedef Eigen::Matrix<Scalar,Eigen::Dynamic,1,MatrixType::Options> VectorXs;
27+
typedef Eigen::Matrix<Scalar,Eigen::Dynamic,Eigen::Dynamic,MatrixType::Options> MatrixXs;
2728
typedef Eigen::LDLT<MatrixType> Solver;
2829

2930
template<class PyClass>
@@ -55,7 +56,7 @@ namespace eigenpy
5556
"Returns the LDLT decomposition matrix.",
5657
bp::return_internal_reference<>())
5758

58-
.def("rankUpdate",(Solver & (Solver::*)(const Eigen::MatrixBase<VectorType> &, const RealScalar &))&Solver::template rankUpdate<VectorType>,
59+
.def("rankUpdate",(Solver & (Solver::*)(const Eigen::MatrixBase<VectorXs> &, const RealScalar &))&Solver::template rankUpdate<VectorXs>,
5960
bp::args("self","vector","sigma"),
6061
bp::return_self<>())
6162

@@ -78,8 +79,10 @@ namespace eigenpy
7879
#endif
7980
.def("reconstructedMatrix",&Solver::reconstructedMatrix,bp::arg("self"),
8081
"Returns the matrix represented by the decomposition, i.e., it returns the product: L L^*. This function is provided for debug purpose.")
81-
.def("solve",&solve<VectorType>,bp::args("self","b"),
82+
.def("solve",&solve<VectorXs>,bp::args("self","b"),
8283
"Returns the solution x of A x = b using the current decomposition of A.")
84+
.def("solve",&solve<MatrixXs>,bp::args("self","B"),
85+
"Returns the solution X of A X = B using the current decomposition of A where B is a right hand side matrix.")
8386

8487
.def("setZero",&Solver::setZero,bp::arg("self"),
8588
"Clear any existing decomposition.")
@@ -107,16 +110,16 @@ namespace eigenpy
107110

108111
static MatrixType matrixL(const Solver & self) { return self.matrixL(); }
109112
static MatrixType matrixU(const Solver & self) { return self.matrixU(); }
110-
static VectorType vectorD(const Solver & self) { return self.vectorD(); }
113+
static VectorXs vectorD(const Solver & self) { return self.vectorD(); }
111114

112115
static MatrixType transpositionsP(const Solver & self)
113116
{
114117
return self.transpositionsP() * MatrixType::Identity(self.matrixL().rows(),
115118
self.matrixL().rows());
116119
}
117120

118-
template<typename VectorType>
119-
static VectorType solve(const Solver & self, const VectorType & vec)
121+
template<typename MatrixOrVector>
122+
static MatrixOrVector solve(const Solver & self, const MatrixOrVector & vec)
120123
{
121124
return self.solve(vec);
122125
}

include/eigenpy/decompositions/LLT.hpp

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2020 INRIA
2+
* Copyright 2020-2021 INRIA
33
*/
44

55
#ifndef __eigenpy_decomposition_llt_hpp__
@@ -23,7 +23,8 @@ namespace eigenpy
2323
typedef _MatrixType MatrixType;
2424
typedef typename MatrixType::Scalar Scalar;
2525
typedef typename MatrixType::RealScalar RealScalar;
26-
typedef Eigen::Matrix<Scalar,Eigen::Dynamic,1,MatrixType::Options> VectorType;
26+
typedef Eigen::Matrix<Scalar,Eigen::Dynamic,1,MatrixType::Options> VectorXs;
27+
typedef Eigen::Matrix<Scalar,Eigen::Dynamic,Eigen::Dynamic,MatrixType::Options> MatrixXs;
2728
typedef Eigen::LLT<MatrixType> Solver;
2829

2930
template<class PyClass>
@@ -46,10 +47,10 @@ namespace eigenpy
4647
bp::return_internal_reference<>())
4748

4849
#if EIGEN_VERSION_AT_LEAST(3,3,90)
49-
.def("rankUpdate",(Solver& (Solver::*)(const VectorType &, const RealScalar &))&Solver::template rankUpdate<VectorType>,
50+
.def("rankUpdate",(Solver& (Solver::*)(const VectorXs &, const RealScalar &))&Solver::template rankUpdate<VectorXs>,
5051
bp::args("self","vector","sigma"), bp::return_self<>())
5152
#else
52-
.def("rankUpdate",(Solver (Solver::*)(const VectorType &, const RealScalar &))&Solver::template rankUpdate<VectorType>,
53+
.def("rankUpdate",(Solver (Solver::*)(const VectorXs &, const RealScalar &))&Solver::template rankUpdate<VectorXs>,
5354
bp::args("self","vector","sigma"))
5455
#endif
5556

@@ -72,8 +73,10 @@ namespace eigenpy
7273
#endif
7374
.def("reconstructedMatrix",&Solver::reconstructedMatrix,bp::arg("self"),
7475
"Returns the matrix represented by the decomposition, i.e., it returns the product: L L^*. This function is provided for debug purpose.")
75-
.def("solve",&solve<VectorType>,bp::args("self","b"),
76+
.def("solve",&solve<VectorXs>,bp::args("self","b"),
7677
"Returns the solution x of A x = b using the current decomposition of A.")
78+
.def("solve",&solve<MatrixXs>,bp::args("self","B"),
79+
"Returns the solution X of A X = B using the current decomposition of A where B is a right hand side matrix.")
7780
;
7881
}
7982

@@ -99,8 +102,8 @@ namespace eigenpy
99102
static MatrixType matrixL(const Solver & self) { return self.matrixL(); }
100103
static MatrixType matrixU(const Solver & self) { return self.matrixU(); }
101104

102-
template<typename VectorType>
103-
static VectorType solve(const Solver & self, const VectorType & vec)
105+
template<typename MatrixOrVector>
106+
static MatrixOrVector solve(const Solver & self, const MatrixOrVector & vec)
104107
{
105108
return self.solve(vec);
106109
}

unittest/python/test_LDLT.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import eigenpy
2-
eigenpy.switchToNumpyArray()
32

43
import numpy as np
54
import numpy.linalg as la
@@ -16,3 +15,9 @@
1615
P = ldlt.transpositionsP()
1716

1817
assert eigenpy.is_approx(np.transpose(P).dot(L.dot(np.diag(D).dot(np.transpose(L).dot(P)))),A)
18+
19+
X = np.random.rand(dim,20)
20+
B = A.dot(X)
21+
X_est = ldlt.solve(B)
22+
assert eigenpy.is_approx(X,X_est)
23+
assert eigenpy.is_approx(A.dot(X_est),B)

unittest/python/test_LLT.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import eigenpy
2-
eigenpy.switchToNumpyArray()
32

43
import numpy as np
54
import numpy.linalg as la
@@ -12,5 +11,10 @@
1211
llt = eigenpy.LLT(A)
1312

1413
L = llt.matrixL()
15-
1614
assert eigenpy.is_approx(L.dot(np.transpose(L)),A)
15+
16+
X = np.random.rand(dim,20)
17+
B = A.dot(X)
18+
X_est = llt.solve(B)
19+
assert eigenpy.is_approx(X,X_est)
20+
assert eigenpy.is_approx(A.dot(X_est),B)

0 commit comments

Comments
 (0)