|
1 | 1 | //! Implement Operator norms for matrices
|
2 | 2 |
|
3 | 3 | use lapack::c;
|
| 4 | +use num_traits::Zero; |
4 | 5 |
|
5 | 6 | use types::*;
|
6 | 7 | use error::*;
|
7 | 8 | use layout::Layout;
|
8 | 9 |
|
9 | 10 | #[repr(u8)]
|
10 |
| -pub enum FlagSVD { |
| 11 | +enum FlagSVD { |
11 | 12 | All = b'A',
|
12 |
| - OverWrite = b'O', |
13 |
| - Separately = b'S', |
| 13 | + // OverWrite = b'O', |
| 14 | + // Separately = b'S', |
14 | 15 | No = b'N',
|
15 | 16 | }
|
16 | 17 |
|
17 |
| -pub trait SVD_: Sized { |
18 |
| - fn svd(Layout, u_flag: FlagSVD, v_flag: FlagSVD, a: &[Self]) -> Result<()>; |
| 18 | +pub struct SVDOutput<A: AssociatedReal> { |
| 19 | + pub s: Vec<A::Real>, |
| 20 | + pub u: Option<Vec<A>>, |
| 21 | + pub vt: Option<Vec<A>>, |
19 | 22 | }
|
| 23 | + |
| 24 | +pub trait SVD_: AssociatedReal { |
| 25 | + fn svd(Layout, calc_u: bool, calc_vt: bool, a: &mut [Self]) -> Result<SVDOutput<Self>>; |
| 26 | +} |
| 27 | + |
| 28 | +macro_rules! impl_svd { |
| 29 | + ($scalar:ty, $gesvd:path) => { |
| 30 | + |
| 31 | +impl SVD_ for $scalar { |
| 32 | + fn svd(l: Layout, calc_u: bool, calc_vt: bool, mut a: &mut [Self]) -> Result<SVDOutput<Self>> { |
| 33 | + let (n, m) = l.size(); |
| 34 | + let k = ::std::cmp::min(n, m); |
| 35 | + let lda = l.lda(); |
| 36 | + let (ju, ldu, mut u) = if calc_u { |
| 37 | + (FlagSVD::All, m, vec![Self::zero(); (m*m) as usize]) |
| 38 | + } else { |
| 39 | + (FlagSVD::No, 0, Vec::new()) |
| 40 | + }; |
| 41 | + let (jvt, ldvt, mut vt) = if calc_vt { |
| 42 | + (FlagSVD::All, m, vec![Self::zero(); (m*m) as usize]) |
| 43 | + } else { |
| 44 | + (FlagSVD::No, 0, Vec::new()) |
| 45 | + }; |
| 46 | + let mut s = vec![Self::Real::zero(); k as usize]; |
| 47 | + let mut superb = vec![Self::Real::zero(); (k-2) as usize]; |
| 48 | + let info = $gesvd(l.lapacke_layout(), ju as u8, jvt as u8, m, n, &mut a, lda, &mut s, &mut u, ldu, &mut vt, ldvt, &mut superb); |
| 49 | + if info == 0 { |
| 50 | + Ok(SVDOutput { |
| 51 | + s: s, |
| 52 | + u: if ldu > 0 { Some(u) } else { None }, |
| 53 | + vt: if ldvt > 0 { Some(vt) } else { None }, |
| 54 | + }) |
| 55 | + } else { |
| 56 | + Err(LapackError::new(info).into()) |
| 57 | + } |
| 58 | + } |
| 59 | +} |
| 60 | + |
| 61 | +}} // impl_svd! |
| 62 | + |
| 63 | +impl_svd!(f64, c::dgesvd); |
| 64 | +impl_svd!(f32, c::sgesvd); |
| 65 | +impl_svd!(c64, c::zgesvd); |
| 66 | +impl_svd!(c32, c::cgesvd); |
0 commit comments