Skip to content

Commit 2be2c42

Browse files
authored
Merge pull request #33 from termoshtt/decompositions
Decompositions
2 parents 3e18871 + 0e2d9a2 commit 2be2c42

File tree

13 files changed

+342
-72
lines changed

13 files changed

+342
-72
lines changed

src/impl2/mod.rs

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,21 @@
11

22
pub mod opnorm;
3+
pub mod qr;
4+
pub mod svd;
5+
36
pub use self::opnorm::*;
7+
pub use self::qr::*;
8+
pub use self::svd::*;
9+
10+
use super::error::*;
11+
12+
pub trait LapackScalar: OperatorNorm_ + QR_ + SVD_ {}
13+
impl<A> LapackScalar for A where A: OperatorNorm_ + QR_ + SVD_ {}
414

5-
pub trait LapackScalar: OperatorNorm_ {}
6-
impl<A> LapackScalar for A where A: OperatorNorm_ {}
15+
pub fn into_result<T>(info: i32, val: T) -> Result<T> {
16+
if info == 0 {
17+
Ok(val)
18+
} else {
19+
Err(LapackError::new(info).into())
20+
}
21+
}

src/impl2/opnorm.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ use lapack::c;
44
use lapack::c::Layout::ColumnMajor as cm;
55

66
use types::*;
7-
use layout::*;
7+
use layout::Layout;
88

99
#[repr(u8)]
1010
pub enum NormType {

src/impl2/qr.rs

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
//! Implement QR decomposition
2+
3+
use std::cmp::min;
4+
use num_traits::Zero;
5+
use lapack::c;
6+
7+
use types::*;
8+
use error::*;
9+
use layout::Layout;
10+
11+
use super::into_result;
12+
13+
pub trait QR_: Sized {
14+
fn householder(Layout, a: &mut [Self]) -> Result<Vec<Self>>;
15+
fn q(Layout, a: &mut [Self], tau: &[Self]) -> Result<()>;
16+
fn qr(Layout, a: &mut [Self]) -> Result<Vec<Self>>;
17+
}
18+
19+
macro_rules! impl_qr {
20+
($scalar:ty, $qrf:path, $gqr:path) => {
21+
impl QR_ for $scalar {
22+
fn householder(l: Layout, mut a: &mut [Self]) -> Result<Vec<Self>> {
23+
let (row, col) = l.size();
24+
let k = min(row, col);
25+
let mut tau = vec![Self::zero(); k as usize];
26+
let info = $qrf(l.lapacke_layout(), row, col, &mut a, l.lda(), &mut tau);
27+
into_result(info, tau)
28+
}
29+
30+
fn q(l: Layout, mut a: &mut [Self], tau: &[Self]) -> Result<()> {
31+
let (row, col) = l.size();
32+
let k = min(row, col);
33+
let info = $gqr(l.lapacke_layout(), row, k, k, &mut a, l.lda(), &tau);
34+
into_result(info, ())
35+
}
36+
37+
fn qr(l: Layout, mut a: &mut [Self]) -> Result<Vec<Self>> {
38+
let tau = Self::householder(l, a)?;
39+
let r = Vec::from(&*a);
40+
Self::q(l, a, &tau)?;
41+
Ok(r)
42+
}
43+
}
44+
}} // endmacro
45+
46+
impl_qr!(f64, c::dgeqrf, c::dorgqr);
47+
impl_qr!(f32, c::sgeqrf, c::sorgqr);
48+
impl_qr!(c64, c::zgeqrf, c::zungqr);
49+
impl_qr!(c32, c::cgeqrf, c::cungqr);

src/impl2/svd.rs

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
//! Implement Operator norms for matrices
2+
3+
use lapack::c;
4+
use num_traits::Zero;
5+
6+
use types::*;
7+
use error::*;
8+
use layout::Layout;
9+
10+
use super::into_result;
11+
12+
#[repr(u8)]
13+
enum FlagSVD {
14+
All = b'A',
15+
// OverWrite = b'O',
16+
// Separately = b'S',
17+
No = b'N',
18+
}
19+
20+
pub struct SVDOutput<A: AssociatedReal> {
21+
pub s: Vec<A::Real>,
22+
pub u: Option<Vec<A>>,
23+
pub vt: Option<Vec<A>>,
24+
}
25+
26+
pub trait SVD_: AssociatedReal {
27+
fn svd(Layout, calc_u: bool, calc_vt: bool, a: &mut [Self]) -> Result<SVDOutput<Self>>;
28+
}
29+
30+
macro_rules! impl_svd {
31+
($scalar:ty, $gesvd:path) => {
32+
33+
impl SVD_ for $scalar {
34+
fn svd(l: Layout, calc_u: bool, calc_vt: bool, mut a: &mut [Self]) -> Result<SVDOutput<Self>> {
35+
let (m, n) = l.size();
36+
let k = ::std::cmp::min(n, m);
37+
let lda = l.lda();
38+
let (ju, ldu, mut u) = if calc_u {
39+
(FlagSVD::All, m, vec![Self::zero(); (m*m) as usize])
40+
} else {
41+
(FlagSVD::No, 0, Vec::new())
42+
};
43+
let (jvt, ldvt, mut vt) = if calc_vt {
44+
(FlagSVD::All, n, vec![Self::zero(); (n*n) as usize])
45+
} else {
46+
(FlagSVD::No, 0, Vec::new())
47+
};
48+
let mut s = vec![Self::Real::zero(); k as usize];
49+
let mut superb = vec![Self::Real::zero(); (k-2) as usize];
50+
let info = $gesvd(l.lapacke_layout(), ju as u8, jvt as u8, m, n, &mut a, lda, &mut s, &mut u, ldu, &mut vt, ldvt, &mut superb);
51+
into_result(info, SVDOutput {
52+
s: s,
53+
u: if ldu > 0 { Some(u) } else { None },
54+
vt: if ldvt > 0 { Some(vt) } else { None },
55+
})
56+
}
57+
}
58+
59+
}} // impl_svd!
60+
61+
impl_svd!(f64, c::dgesvd);
62+
impl_svd!(f32, c::sgesvd);
63+
impl_svd!(c64, c::zgesvd);
64+
impl_svd!(c32, c::cgesvd);

src/layout.rs

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
11

22
use ndarray::*;
3+
use lapack::c;
34

45
use super::error::*;
56

67
pub type LDA = i32;
78
pub type Col = i32;
89
pub type Row = i32;
910

11+
#[derive(Debug, Clone, Copy)]
1012
pub enum Layout {
1113
C((Row, LDA)),
1214
F((Col, LDA)),
@@ -19,6 +21,27 @@ impl Layout {
1921
Layout::F((col, lda)) => (lda, col),
2022
}
2123
}
24+
25+
pub fn resized(&self, row: Row, col: Col) -> Layout {
26+
match *self {
27+
Layout::C(_) => Layout::C((row, col)),
28+
Layout::F(_) => Layout::F((col, row)),
29+
}
30+
}
31+
32+
pub fn lda(&self) -> LDA {
33+
match *self {
34+
Layout::C((_, lda)) => lda,
35+
Layout::F((_, lda)) => lda,
36+
}
37+
}
38+
39+
pub fn lapacke_layout(&self) -> c::Layout {
40+
match *self {
41+
Layout::C(_) => c::Layout::RowMajor,
42+
Layout::F(_) => c::Layout::ColumnMajor,
43+
}
44+
}
2245
}
2346

2447
pub trait AllocatedArray {
@@ -28,6 +51,10 @@ pub trait AllocatedArray {
2851
fn as_allocated(&self) -> Result<&[Self::Scalar]>;
2952
}
3053

54+
pub trait AllocatedArrayMut: AllocatedArray {
55+
fn as_allocated_mut(&mut self) -> Result<&mut [Self::Scalar]>;
56+
}
57+
3158
impl<A, S> AllocatedArray for ArrayBase<S, Ix2>
3259
where S: Data<Elem = A>
3360
{
@@ -60,3 +87,21 @@ impl<A, S> AllocatedArray for ArrayBase<S, Ix2>
6087
Ok(slice)
6188
}
6289
}
90+
91+
impl<A, S> AllocatedArrayMut for ArrayBase<S, Ix2>
92+
where S: DataMut<Elem = A>
93+
{
94+
fn as_allocated_mut(&mut self) -> Result<&mut [A]> {
95+
let slice = self.as_slice_memory_order_mut().ok_or(MemoryContError::new())?;
96+
Ok(slice)
97+
}
98+
}
99+
100+
pub fn reconstruct<A, S>(l: Layout, a: Vec<A>) -> Result<ArrayBase<S, Ix2>>
101+
where S: DataOwned<Elem = A>
102+
{
103+
Ok(match l {
104+
Layout::C((row, col)) => ArrayBase::from_shape_vec((row as usize, col as usize), a)?,
105+
Layout::F((col, row)) => ArrayBase::from_shape_vec((row as usize, col as usize).f(), a)?,
106+
})
107+
}

src/lib.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,9 @@ pub mod layout;
4848
pub mod impls;
4949
pub mod impl2;
5050

51-
pub mod traits;
51+
pub mod qr;
52+
pub mod svd;
53+
pub mod opnorm;
5254

5355
pub mod vector;
5456
pub mod matrix;

src/matrix.rs

Lines changed: 2 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,11 @@ use ndarray::DataMut;
66
use lapack::c::Layout;
77

88
use super::error::{LinalgError, StrideError};
9-
use super::impls::qr::ImplQR;
109
use super::impls::svd::ImplSVD;
1110
use super::impls::solve::ImplSolve;
1211

13-
pub trait MFloat: ImplQR + ImplSVD + ImplSolve + NdFloat {}
14-
impl<A: ImplQR + ImplSVD + ImplSolve + NdFloat> MFloat for A {}
12+
pub trait MFloat: ImplSVD + ImplSolve + NdFloat {}
13+
impl<A: ImplSVD + ImplSolve + NdFloat> MFloat for A {}
1514

1615
/// Methods for general matrices
1716
pub trait Matrix: Sized {
@@ -22,10 +21,6 @@ pub trait Matrix: Sized {
2221
fn size(&self) -> (usize, usize);
2322
/// Layout (C/Fortran) of matrix
2423
fn layout(&self) -> Result<Layout, StrideError>;
25-
/// singular-value decomposition (SVD)
26-
fn svd(self) -> Result<(Self, Self::Vector, Self), LinalgError>;
27-
/// QR decomposition
28-
fn qr(self) -> Result<(Self, Self), LinalgError>;
2924
/// LU decomposition
3025
fn lu(self) -> Result<(Self::Permutator, Self, Self), LinalgError>;
3126
/// permutate matrix (inplace)
@@ -77,49 +72,6 @@ impl<A: MFloat> Matrix for Array<A, Ix2> {
7772
fn layout(&self) -> Result<Layout, StrideError> {
7873
check_layout(self.strides())
7974
}
80-
fn svd(self) -> Result<(Self, Self::Vector, Self), LinalgError> {
81-
let (n, m) = self.size();
82-
let layout = self.layout()?;
83-
let (u, s, vt) = ImplSVD::svd(layout, m, n, self.clone().into_raw_vec())?;
84-
let sv = Array::from_vec(s);
85-
let ua = Array::from_vec(u).into_shape((n, n)).unwrap();
86-
let va = Array::from_vec(vt).into_shape((m, m)).unwrap();
87-
match layout {
88-
Layout::RowMajor => Ok((ua, sv, va)),
89-
Layout::ColumnMajor => Ok((ua.reversed_axes(), sv, va.reversed_axes())),
90-
}
91-
}
92-
fn qr(self) -> Result<(Self, Self), LinalgError> {
93-
let (n, m) = self.size();
94-
let strides = self.strides();
95-
let k = min(n, m);
96-
let layout = self.layout()?;
97-
let (q, r) = ImplQR::qr(layout, m, n, self.clone().into_raw_vec())?;
98-
let (qa, ra) = if strides[0] < strides[1] {
99-
(Array::from_vec(q).into_shape((m, n)).unwrap().reversed_axes(),
100-
Array::from_vec(r).into_shape((m, n)).unwrap().reversed_axes())
101-
} else {
102-
(Array::from_vec(q).into_shape((n, m)).unwrap(), Array::from_vec(r).into_shape((n, m)).unwrap())
103-
};
104-
let qm = if m > k {
105-
let (qsl, _) = qa.view().split_at(Axis(1), k);
106-
qsl.to_owned()
107-
} else {
108-
qa
109-
};
110-
let mut rm = if n > k {
111-
let (rsl, _) = ra.view().split_at(Axis(0), k);
112-
rsl.to_owned()
113-
} else {
114-
ra
115-
};
116-
for ((i, j), val) in rm.indexed_iter_mut() {
117-
if i > j {
118-
*val = A::zero();
119-
}
120-
}
121-
Ok((qm, rm))
122-
}
12375
fn lu(self) -> Result<(Self::Permutator, Self, Self), LinalgError> {
12476
let (n, m) = self.size();
12577
let k = min(n, m);
@@ -163,14 +115,6 @@ impl<A: MFloat> Matrix for RcArray<A, Ix2> {
163115
fn layout(&self) -> Result<Layout, StrideError> {
164116
check_layout(self.strides())
165117
}
166-
fn svd(self) -> Result<(Self, Self::Vector, Self), LinalgError> {
167-
let (u, s, v) = self.into_owned().svd()?;
168-
Ok((u.into_shared(), s.into_shared(), v.into_shared()))
169-
}
170-
fn qr(self) -> Result<(Self, Self), LinalgError> {
171-
let (q, r) = self.into_owned().qr()?;
172-
Ok((q.into_shared(), r.into_shared()))
173-
}
174118
fn lu(self) -> Result<(Self::Permutator, Self, Self), LinalgError> {
175119
let (p, l, u) = self.into_owned().lu()?;
176120
Ok((p, l.into_shared(), u.into_shared()))

src/traits.rs renamed to src/opnorm.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11

2-
pub use impl2::LapackScalar;
3-
pub use impl2::NormType;
4-
52
use ndarray::*;
63

74
use super::types::*;
85
use super::error::*;
96
use super::layout::*;
107

8+
pub use impl2::NormType;
9+
use impl2::LapackScalar;
10+
1111
pub trait OperationNorm {
1212
type Output;
1313
fn opnorm(&self, t: NormType) -> Self::Output;

src/prelude.rs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,4 +5,7 @@ pub use hermite::HermiteMatrix;
55
pub use triangular::*;
66
pub use util::*;
77
pub use assert::*;
8-
pub use traits::*;
8+
9+
pub use qr::*;
10+
pub use svd::*;
11+
pub use opnorm::*;

0 commit comments

Comments
 (0)