Skip to content

Commit eb2ad98

Browse files
committed
Use layout in QR factorization
1 parent 133de4a commit eb2ad98

File tree

2 files changed

+29
-102
lines changed

2 files changed

+29
-102
lines changed

src/matrix.rs

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -96,19 +96,16 @@ impl<A> Matrix for Array<A, Ix2>
9696
let ua = Array::from_vec(u).into_shape((n, n)).unwrap();
9797
let va = Array::from_vec(vt).into_shape((m, m)).unwrap();
9898
match layout {
99-
Layout::RowMajor => Ok((ua, sv, va)),
99+
Layout::RowMajor => Ok((ua, sv, va)),
100100
Layout::ColumnMajor => Ok((ua.reversed_axes(), sv, va.reversed_axes())),
101101
}
102102
}
103103
fn qr(self) -> Result<(Self, Self), LinalgError> {
104104
let (n, m) = self.size();
105105
let strides = self.strides();
106106
let k = min(n, m);
107-
let (q, r) = if strides[0] < strides[1] {
108-
try!(ImplQR::qr(m, n, self.clone().into_raw_vec()))
109-
} else {
110-
try!(ImplQR::lq(n, m, self.clone().into_raw_vec()))
111-
};
107+
let layout = self.layout()?;
108+
let (q, r) = try!(ImplQR::qr(layout, m, n, self.clone().into_raw_vec()));
112109
let (qa, ra) = if strides[0] < strides[1] {
113110
(Array::from_vec(q).into_shape((m, n)).unwrap().reversed_axes(),
114111
Array::from_vec(r).into_shape((m, n)).unwrap().reversed_axes())

src/qr.rs

Lines changed: 26 additions & 96 deletions
Original file line numberDiff line numberDiff line change
@@ -1,111 +1,41 @@
11
//! Implement QR decomposition
22
33
use std::cmp::min;
4-
use lapack::fortran::*;
4+
use lapack::c::*;
55
use num_traits::Zero;
66

77
use error::LapackError;
88

99
pub trait ImplQR: Sized {
10-
fn qr(n: usize, m: usize, mut a: Vec<Self>) -> Result<(Vec<Self>, Vec<Self>), LapackError>;
11-
fn lq(n: usize, m: usize, mut a: Vec<Self>) -> Result<(Vec<Self>, Vec<Self>), LapackError>;
10+
fn qr(layout: Layout, n: usize, m: usize, mut a: Vec<Self>) -> Result<(Vec<Self>, Vec<Self>), LapackError>;
1211
}
1312

1413
macro_rules! impl_qr {
15-
($geqrf:path, $orgqr:path, $gelqf:path, $orglq:path) => {
16-
// XXX These codes are most same, but the argument of $orgqr and $orglq are different!
17-
fn qr(n: usize, m: usize, mut a: Vec<Self>) -> Result<(Vec<Self>, Vec<Self>), LapackError> {
18-
let n = n as i32;
19-
let m = m as i32;
20-
let mut info = 0;
21-
let k = min(m, n);
22-
let lda = m;
23-
let lw_default = 1000;
24-
let mut tau = vec![Self::zero(); k as usize];
25-
let mut work = vec![Self::zero(); lw_default];
26-
// estimate lwork
27-
$geqrf(m, n, &mut a, lda, &mut tau, &mut work, -1, &mut info);
28-
let lwork_r = work[0] as i32;
29-
if lwork_r > lw_default as i32 {
30-
work = vec![Self::zero(); lwork_r as usize];
31-
}
32-
// calc R
33-
$geqrf(m, n, &mut a, lda, &mut tau, &mut work, lwork_r, &mut info);
34-
if info != 0 {
35-
return Err(From::from(info));
36-
}
37-
let r = a.clone();
38-
// re-estimate lwork
39-
$orgqr(m, k, k, &mut a, lda, &mut tau, &mut work, -1, &mut info);
40-
let lwork_q = work[0] as i32;
41-
if lwork_q > lwork_r {
42-
work = vec![Self::zero(); lwork_q as usize];
43-
}
44-
// calc Q
45-
$orgqr(m,
46-
k,
47-
k,
48-
&mut a,
49-
lda,
50-
&mut tau,
51-
&mut work,
52-
lwork_q,
53-
&mut info);
54-
if info == 0 {
55-
Ok((a, r))
56-
} else {
57-
Err(From::from(info))
58-
}
59-
}
60-
fn lq(n: usize, m: usize, mut a: Vec<Self>) -> Result<(Vec<Self>, Vec<Self>), LapackError> {
61-
let n = n as i32;
62-
let m = m as i32;
63-
let mut info = 0;
64-
let k = min(m, n);
65-
let lda = m;
66-
let lw_default = 1000;
67-
let mut tau = vec![Self::zero(); k as usize];
68-
let mut work = vec![Self::zero(); lw_default];
69-
// estimate lwork
70-
$gelqf(m, n, &mut a, lda, &mut tau, &mut work, -1, &mut info);
71-
let lwork_r = work[0] as i32;
72-
if lwork_r > lw_default as i32 {
73-
work = vec![Self::zero(); lwork_r as usize];
74-
}
75-
// calc R
76-
$gelqf(m, n, &mut a, lda, &mut tau, &mut work, lwork_r, &mut info);
77-
if info != 0 {
78-
return Err(From::from(info));
79-
}
80-
let r = a.clone();
81-
// re-estimate lwork
82-
$orglq(k, n, k, &mut a, lda, &mut tau, &mut work, -1, &mut info);
83-
let lwork_q = work[0] as i32;
84-
if lwork_q > lwork_r {
85-
work = vec![Self::zero(); lwork_q as usize];
86-
}
87-
// calc Q
88-
$orglq(k,
89-
n,
90-
k,
91-
&mut a,
92-
lda,
93-
&mut tau,
94-
&mut work,
95-
lwork_q,
96-
&mut info);
97-
if info == 0 {
98-
Ok((a, r))
99-
} else {
100-
Err(From::from(info))
14+
($scalar:ty, $geqrf:path, $orgqr:path) => {
15+
impl ImplQR for $scalar {
16+
fn qr(layout: Layout, n: usize, m: usize, mut a: Vec<Self>) -> Result<(Vec<Self>, Vec<Self>), LapackError> {
17+
let n = n as i32;
18+
let m = m as i32;
19+
let k = min(m, n);
20+
let lda = match layout {
21+
Layout::ColumnMajor => m,
22+
Layout::RowMajor => n,
23+
};
24+
let mut tau = vec![Self::zero(); k as usize];
25+
let info = $geqrf(layout, m, n, &mut a, lda, &mut tau);
26+
if info != 0 {
27+
return Err(From::from(info));
28+
}
29+
let r = a.clone();
30+
let info = $orgqr(layout, m, k, k, &mut a, lda, &mut tau);
31+
if info == 0 {
32+
Ok((a, r))
33+
} else {
34+
Err(From::from(info))
35+
}
10136
}
10237
}
10338
}} // endmacro
10439

105-
impl ImplQR for f64 {
106-
impl_qr!(dgeqrf, dorgqr, dgelqf, dorglq);
107-
}
108-
109-
impl ImplQR for f32 {
110-
impl_qr!(sgeqrf, sorgqr, sgelqf, sorglq);
111-
}
40+
impl_qr!(f64, dgeqrf, dorgqr);
41+
impl_qr!(f32, sgeqrf, sorgqr);

0 commit comments

Comments
 (0)