Skip to content

Commit 4e65a47

Browse files
committed
impl SVD_
1 parent ddc3bb9 commit 4e65a47

File tree

1 file changed

+52
-5
lines changed

1 file changed

+52
-5
lines changed

src/impl2/svd.rs

Lines changed: 52 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,66 @@
11
//! Implement Operator norms for matrices
22
33
use lapack::c;
4+
use num_traits::Zero;
45

56
use types::*;
67
use error::*;
78
use layout::Layout;
89

910
#[repr(u8)]
10-
pub enum FlagSVD {
11+
enum FlagSVD {
1112
All = b'A',
12-
OverWrite = b'O',
13-
Separately = b'S',
13+
// OverWrite = b'O',
14+
// Separately = b'S',
1415
No = b'N',
1516
}
1617

17-
pub trait SVD_: Sized {
18-
fn svd(Layout, u_flag: FlagSVD, v_flag: FlagSVD, a: &[Self]) -> Result<()>;
18+
pub struct SVDOutput<A: AssociatedReal> {
19+
pub s: Vec<A::Real>,
20+
pub u: Option<Vec<A>>,
21+
pub vt: Option<Vec<A>>,
1922
}
23+
24+
pub trait SVD_: AssociatedReal {
25+
fn svd(Layout, calc_u: bool, calc_vt: bool, a: &mut [Self]) -> Result<SVDOutput<Self>>;
26+
}
27+
28+
macro_rules! impl_svd {
29+
($scalar:ty, $gesvd:path) => {
30+
31+
impl SVD_ for $scalar {
32+
fn svd(l: Layout, calc_u: bool, calc_vt: bool, mut a: &mut [Self]) -> Result<SVDOutput<Self>> {
33+
let (n, m) = l.size();
34+
let k = ::std::cmp::min(n, m);
35+
let lda = l.lda();
36+
let (ju, ldu, mut u) = if calc_u {
37+
(FlagSVD::All, m, vec![Self::zero(); (m*m) as usize])
38+
} else {
39+
(FlagSVD::No, 0, Vec::new())
40+
};
41+
let (jvt, ldvt, mut vt) = if calc_vt {
42+
(FlagSVD::All, m, vec![Self::zero(); (m*m) as usize])
43+
} else {
44+
(FlagSVD::No, 0, Vec::new())
45+
};
46+
let mut s = vec![Self::Real::zero(); k as usize];
47+
let mut superb = vec![Self::Real::zero(); (k-2) as usize];
48+
let info = $gesvd(l.lapacke_layout(), ju as u8, jvt as u8, m, n, &mut a, lda, &mut s, &mut u, ldu, &mut vt, ldvt, &mut superb);
49+
if info == 0 {
50+
Ok(SVDOutput {
51+
s: s,
52+
u: if ldu > 0 { Some(u) } else { None },
53+
vt: if ldvt > 0 { Some(vt) } else { None },
54+
})
55+
} else {
56+
Err(LapackError::new(info).into())
57+
}
58+
}
59+
}
60+
61+
}} // impl_svd!
62+
63+
impl_svd!(f64, c::dgesvd);
64+
impl_svd!(f32, c::sgesvd);
65+
impl_svd!(c64, c::zgesvd);
66+
impl_svd!(c32, c::cgesvd);

0 commit comments

Comments
 (0)