Skip to content

Commit 133de4a

Browse files
committed
Use layout at SVD
1 parent 31da57b commit 133de4a

File tree

2 files changed

+23
-59
lines changed

2 files changed

+23
-59
lines changed

src/matrix.rs

Lines changed: 8 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -89,23 +89,15 @@ impl<A> Matrix for Array<A, Ix2>
8989
ImplNorm::norm_f(m, n, self.clone().into_raw_vec())
9090
}
9191
fn svd(self) -> Result<(Self, Self::Vector, Self), LinalgError> {
92-
let strides = self.strides();
93-
let (m, n) = if strides[0] > strides[1] {
94-
self.size()
95-
} else {
96-
let (n, m) = self.size();
97-
(m, n)
98-
};
99-
let (u, s, vt) = try!(ImplSVD::svd(m, n, self.clone().into_raw_vec()));
92+
let (n, m) = self.size();
93+
let layout = self.layout()?;
94+
let (u, s, vt) = ImplSVD::svd(layout, m, n, self.clone().into_raw_vec())?;
10095
let sv = Array::from_vec(s);
101-
if strides[0] > strides[1] {
102-
let ua = Array::from_vec(u).into_shape((n, n)).unwrap();
103-
let va = Array::from_vec(vt).into_shape((m, m)).unwrap();
104-
Ok((va, sv, ua))
105-
} else {
106-
let ua = Array::from_vec(u).into_shape((n, n)).unwrap().reversed_axes();
107-
let va = Array::from_vec(vt).into_shape((m, m)).unwrap().reversed_axes();
108-
Ok((ua, sv, va))
96+
let ua = Array::from_vec(u).into_shape((n, n)).unwrap();
97+
let va = Array::from_vec(vt).into_shape((m, m)).unwrap();
98+
match layout {
99+
Layout::RowMajor => Ok((ua, sv, va)),
100+
Layout::ColumnMajor => Ok((ua.reversed_axes(), sv, va.reversed_axes())),
109101
}
110102
}
111103
fn qr(self) -> Result<(Self, Self), LinalgError> {

src/svd.rs

Lines changed: 15 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -1,65 +1,37 @@
11
//! Implement SVD
22
3-
use lapack::fortran::*;
3+
use std::cmp::min;
4+
use lapack::c::*;
45
use num_traits::Zero;
56

67
use error::LapackError;
78

89
pub trait ImplSVD: Sized {
9-
fn svd(n: usize, m: usize, mut a: Vec<Self>) -> Result<(Vec<Self>, Vec<Self>, Vec<Self>), LapackError>;
10+
fn svd(layout: Layout,
11+
n: usize,
12+
m: usize,
13+
mut a: Vec<Self>)
14+
-> Result<(Vec<Self>, Vec<Self>, Vec<Self>), LapackError>;
1015
}
1116

1217
macro_rules! impl_svd {
1318
($scalar:ty, $gesvd:path) => {
1419
impl ImplSVD for $scalar {
15-
fn svd(n: usize,
16-
m: usize,
17-
mut a: Vec<Self>)
18-
-> Result<(Vec<Self>, Vec<Self>, Vec<Self>), LapackError> {
19-
let mut info = 0;
20+
fn svd(layout: Layout, n: usize, m: usize, mut a: Vec<Self>) -> Result<(Vec<Self>, Vec<Self>, Vec<Self>), LapackError> {
21+
let k = min(n, m);
2022
let n = n as i32;
2123
let m = m as i32;
22-
let lda = m;
24+
let lda = match layout {
25+
Layout::RowMajor => n,
26+
Layout::ColumnMajor => m,
27+
};
2328
let ldu = m;
2429
let ldvt = n;
25-
let lwork = -1;
26-
let lw_default = 1000;
2730
let mut u = vec![Self::zero(); (ldu * m) as usize];
2831
let mut vt = vec![Self::zero(); (ldvt * n) as usize];
2932
let mut s = vec![Self::zero(); n as usize];
30-
let mut work = vec![Self::zero(); lw_default];
31-
$gesvd('A' as u8,
32-
'A' as u8,
33-
m,
34-
n,
35-
&mut a,
36-
lda,
37-
&mut s,
38-
&mut u,
39-
ldu,
40-
&mut vt,
41-
ldvt,
42-
&mut work,
43-
lwork,
44-
&mut info); // calc optimal work
45-
let lwork = work[0] as i32;
46-
if lwork > lw_default as i32 {
47-
work = vec![Self::zero(); lwork as usize];
48-
}
49-
$gesvd('A' as u8,
50-
'A' as u8,
51-
m,
52-
n,
53-
&mut a,
54-
lda,
55-
&mut s,
56-
&mut u,
57-
ldu,
58-
&mut vt,
59-
ldvt,
60-
&mut work,
61-
lwork,
62-
&mut info);
33+
let mut superb = vec![Self::zero(); k-2];
34+
let info = $gesvd(layout, 'A' as u8, 'A' as u8, m, n, &mut a, lda, &mut s, &mut u, ldu, &mut vt, ldvt, &mut superb);
6335
if info == 0 {
6436
Ok((u, s, vt))
6537
} else {

0 commit comments

Comments
 (0)