Skip to content

Commit a07846b

Browse files
committed
Replace SVD
1 parent ec4ff5a commit a07846b

File tree

7 files changed

+53
-27
lines changed

7 files changed

+53
-27
lines changed

src/impl2/mod.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,5 +7,5 @@ pub use self::opnorm::*;
77
pub use self::qr::*;
88
pub use self::svd::*;
99

10-
pub trait LapackScalar: OperatorNorm_ + QR_ {}
11-
impl<A> LapackScalar for A where A: OperatorNorm_ + QR_ {}
10+
pub trait LapackScalar: OperatorNorm_ + QR_ + SVD_ {}
11+
impl<A> LapackScalar for A where A: OperatorNorm_ + QR_ + SVD_ {}

src/impl2/svd.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ macro_rules! impl_svd {
3030

3131
impl SVD_ for $scalar {
3232
fn svd(l: Layout, calc_u: bool, calc_vt: bool, mut a: &mut [Self]) -> Result<SVDOutput<Self>> {
33-
let (n, m) = l.size();
33+
let (m, n) = l.size();
3434
let k = ::std::cmp::min(n, m);
3535
let lda = l.lda();
3636
let (ju, ldu, mut u) = if calc_u {
@@ -39,7 +39,7 @@ impl SVD_ for $scalar {
3939
(FlagSVD::No, 0, Vec::new())
4040
};
4141
let (jvt, ldvt, mut vt) = if calc_vt {
42-
(FlagSVD::All, m, vec![Self::zero(); (m*m) as usize])
42+
(FlagSVD::All, n, vec![Self::zero(); (n*n) as usize])
4343
} else {
4444
(FlagSVD::No, 0, Vec::new())
4545
};

src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ pub mod impls;
4949
pub mod impl2;
5050

5151
pub mod qr;
52+
pub mod svd;
5253
pub mod opnorm;
5354

5455
pub mod vector;

src/matrix.rs

Lines changed: 0 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,6 @@ pub trait Matrix: Sized {
2121
fn size(&self) -> (usize, usize);
2222
/// Layout (C/Fortran) of matrix
2323
fn layout(&self) -> Result<Layout, StrideError>;
24-
/// singular-value decomposition (SVD)
25-
fn svd(self) -> Result<(Self, Self::Vector, Self), LinalgError>;
2624
/// LU decomposition
2725
fn lu(self) -> Result<(Self::Permutator, Self, Self), LinalgError>;
2826
/// permutate matrix (inplace)
@@ -74,18 +72,6 @@ impl<A: MFloat> Matrix for Array<A, Ix2> {
7472
fn layout(&self) -> Result<Layout, StrideError> {
7573
check_layout(self.strides())
7674
}
77-
fn svd(self) -> Result<(Self, Self::Vector, Self), LinalgError> {
78-
let (n, m) = self.size();
79-
let layout = self.layout()?;
80-
let (u, s, vt) = ImplSVD::svd(layout, m, n, self.clone().into_raw_vec())?;
81-
let sv = Array::from_vec(s);
82-
let ua = Array::from_vec(u).into_shape((n, n)).unwrap();
83-
let va = Array::from_vec(vt).into_shape((m, m)).unwrap();
84-
match layout {
85-
Layout::RowMajor => Ok((ua, sv, va)),
86-
Layout::ColumnMajor => Ok((ua.reversed_axes(), sv, va.reversed_axes())),
87-
}
88-
}
8975
fn lu(self) -> Result<(Self::Permutator, Self, Self), LinalgError> {
9076
let (n, m) = self.size();
9177
let k = min(n, m);
@@ -129,10 +115,6 @@ impl<A: MFloat> Matrix for RcArray<A, Ix2> {
129115
fn layout(&self) -> Result<Layout, StrideError> {
130116
check_layout(self.strides())
131117
}
132-
fn svd(self) -> Result<(Self, Self::Vector, Self), LinalgError> {
133-
let (u, s, v) = self.into_owned().svd()?;
134-
Ok((u.into_shared(), s.into_shared(), v.into_shared()))
135-
}
136118
fn lu(self) -> Result<(Self::Permutator, Self, Self), LinalgError> {
137119
let (p, l, u) = self.into_owned().lu()?;
138120
Ok((p, l.into_shared(), u.into_shared()))

src/prelude.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,4 +7,5 @@ pub use util::*;
77
pub use assert::*;
88

99
pub use qr::*;
10+
pub use svd::*;
1011
pub use opnorm::*;

src/svd.rs

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
2+
use ndarray::*;
3+
4+
use super::error::*;
5+
use super::layout::{Layout, AllocatedArray, AllocatedArrayMut};
6+
use impl2::LapackScalar;
7+
8+
pub trait SVD<U, S, VT> {
9+
fn svd(self, calc_u: bool, calc_vt: bool) -> Result<(Option<U>, S, Option<VT>)>;
10+
}
11+
12+
impl<A, S, Su, Svt, Ss> SVD<ArrayBase<Su, Ix2>, ArrayBase<Ss, Ix1>, ArrayBase<Svt, Ix2>> for ArrayBase<S, Ix2>
13+
where A: LapackScalar,
14+
S: DataMut<Elem = A>,
15+
Su: DataOwned<Elem = A>,
16+
Svt: DataOwned<Elem = A>,
17+
Ss: DataOwned<Elem = A::Real>
18+
{
19+
fn svd(mut self,
20+
calc_u: bool,
21+
calc_vt: bool)
22+
-> Result<(Option<ArrayBase<Su, Ix2>>, ArrayBase<Ss, Ix1>, Option<ArrayBase<Svt, Ix2>>)> {
23+
let n = self.rows();
24+
let m = self.cols();
25+
let l = self.layout()?;
26+
let svd_res = A::svd(l, calc_u, calc_vt, self.as_allocated_mut()?)?;
27+
let (u, vt) = match l {
28+
Layout::C(_) => {
29+
(svd_res.u.map(|u| ArrayBase::from_shape_vec((n, n), u).unwrap()),
30+
svd_res.vt.map(|vt| ArrayBase::from_shape_vec((m, m), vt).unwrap()))
31+
}
32+
Layout::F(_) => {
33+
(svd_res.u.map(|u| ArrayBase::from_shape_vec((n, n).f(), u).unwrap()),
34+
svd_res.vt.map(|vt| ArrayBase::from_shape_vec((m, m).f(), vt).unwrap()))
35+
}
36+
};
37+
let s = ArrayBase::from_vec(svd_res.s);
38+
Ok((u, s, vt))
39+
}
40+
}

tests/svd.rs

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,13 @@ fn $funcname() {
99
use ndarray_linalg::prelude::*;
1010
let a = $random($n, $m, $t);
1111
let answer = a.clone();
12-
println!("a = \n{}", &a);
13-
let (u, s, vt) = a.svd().unwrap();
14-
println!("u = \n{}", &u);
15-
println!("s = \n{}", &s);
16-
println!("v = \n{}", &vt);
12+
println!("a = \n{:?}", &a);
13+
let (u, s, vt): (_, Array1<_>, _) = a.svd(true, true).unwrap();
14+
let u: Array2<_> = u.unwrap();
15+
let vt: Array2<_> = vt.unwrap();
16+
println!("u = \n{:?}", &u);
17+
println!("s = \n{:?}", &s);
18+
println!("v = \n{:?}", &vt);
1719
let mut sm = Array::zeros(($n, $m));
1820
for i in 0..min($n, $m) {
1921
sm[(i, i)] = s[i];

0 commit comments

Comments
 (0)