Skip to content

Commit fd73d09

Browse files
authored
Fix for #864 - Adding tolerance to pseudo inverse matrix (#865)
1 parent 1189911 commit fd73d09

File tree

2 files changed

+51
-8
lines changed

2 files changed

+51
-8
lines changed

math/src/main/codegen/breeze/linalg/pinv.scala

Lines changed: 30 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14,18 +14,40 @@ import breeze.macros.expand
1414
* Solving A^T (AX-B) = 0 for X yields
1515
* A^T AX = A^T B
1616
* => X = (A^T A)^(-1) A^T B
17+
*
18+
* @param v: Matrix to be pseudo-inverted
19+
* @param rcond: Cutoff for small singular values. Singular values less than or equal to rcond * largest_singular_value
20+
* are set to zero. Default: 1e-15. To deactivate this option, set rcond to zero.
1721
*/
1822
object pinv extends UFunc with pinvLowPrio {
1923

24+
private val DEFAULT_RCOND = 1e-15f
25+
2026
@expand
2127
@expand.valify
2228
implicit def pinvFromSVD[@expand.args(Float, Double) T]: Impl[DenseMatrix[T], DenseMatrix[T]] = {
2329
new Impl[DenseMatrix[T], DenseMatrix[T]] {
24-
// http://en.wikipedia.org/wiki/Singular_value_decomposition#Applications_of_the_SVD
30+
// No rcond passed as parameter, use default value
2531
override def apply(v: DenseMatrix[T]): DenseMatrix[T] = {
32+
val rcond: T = DEFAULT_RCOND
33+
pinv(v, rcond)
34+
}
35+
}
36+
}
37+
38+
39+
@expand
40+
@expand.valify
41+
implicit def pinvFromSVDRcond[@expand.args(Float, Double) T]: Impl2[DenseMatrix[T], T, DenseMatrix[T]] = {
42+
new Impl2[DenseMatrix[T], T, DenseMatrix[T]] {
43+
// http://en.wikipedia.org/wiki/Singular_value_decomposition#Applications_of_the_SVD
44+
override def apply(v: DenseMatrix[T], rcond: T): DenseMatrix[T] = {
45+
require(rcond >= 0, "rcond must be non-negative")
46+
2647
val svd.SVD(s, svs, d) = svd(v)
48+
val cutoff = max(svs) * rcond
2749
val vi = svs.map { v =>
28-
if (v == 0.0) 0 else 1 / v
50+
if (v <= cutoff) 0 else 1 / v
2951
}
3052

3153
val svDiag = DenseMatrix.tabulate[T](s.cols, d.rows) { (i, j) =>
@@ -58,12 +80,12 @@ trait pinvLowPrio { self: pinv.type =>
5880
* @return
5981
*/
6082
implicit def implFromTransposeAndSolve[T, TransT, MulRes, Result](
61-
implicit numericT: T => NumericOps[T],
62-
trans: CanTranspose[T, TransT],
63-
numericTrans: TransT => NumericOps[TransT],
64-
mul: OpMulMatrix.Impl2[TransT, T, MulRes],
65-
numericMulRes: MulRes => NumericOps[MulRes],
66-
solve: OpSolveMatrixBy.Impl2[MulRes, TransT, Result]): Impl[T, Result] = {
83+
implicit numericT: T => NumericOps[T],
84+
trans: CanTranspose[T, TransT],
85+
numericTrans: TransT => NumericOps[TransT],
86+
mul: OpMulMatrix.Impl2[TransT, T, MulRes],
87+
numericMulRes: MulRes => NumericOps[MulRes],
88+
solve: OpSolveMatrixBy.Impl2[MulRes, TransT, Result]): Impl[T, Result] = {
6789
new Impl[T, Result] {
6890
def apply(X: T): Result = {
6991
(X.t * X) \ X.t

math/src/test/scala/breeze/linalg/LinearAlgebraTest.scala

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,27 @@ class LinearAlgebraTest extends AnyFunSuite with Checkers with DoubleImplicits {
159159
matricesNearlyEqual(mi * m, eye)
160160
}
161161

162+
test("pinv test singular values cutoff doubles: #864") {
163+
val m = DenseMatrix((0.1d, 0.1d, 0d), (0.1d, 0.1d, 0d), (0d, 0d, 0d))
164+
val mi = pinv(m)
165+
val expected = DenseMatrix((2.5d, 2.5d, 0d), (2.5d, 2.5d, 0d), (0d, 0d, 0d))
166+
matricesNearlyEqual(mi, expected)
167+
}
168+
169+
test("pinv test singular values cutoff floats: #864") {
170+
// Cutoff needs to be to a more strict value than for Doubles
171+
val m = DenseMatrix((0.1f, 0.1f, 0f), (0.1f, 0.1f, 0f), (0f, 0f, 0f))
172+
val mi = pinv(m, 1e-7f)
173+
val expected = DenseMatrix((2.5f, 2.5f, 0f), (2.5f, 2.5f, 0f), (0f, 0f, 0f))
174+
matricesNearlyEqual_Float(mi, expected)
175+
}
176+
177+
test("pinv test bad conditioned matrix, no cutoff: #864") {
178+
val m = DenseMatrix((0.1d, 0.1d, 0d), (0.1d, 0.1d, 0d), (0d, 0d, 0d))
179+
val mi = pinv(m, 0.0)
180+
mi(0,0) should be > 1e10
181+
}
182+
162183
test("cross") {
163184
// specific example; with prime elements
164185
val (v1, v2, r) = (DenseVector(13, 3, 7), DenseVector(5, 11, 17), DenseVector(-26, -186, 128))

0 commit comments

Comments
 (0)