Skip to content

Commit a83df31

Browse files
committed
Take slice of QR
1 parent 65337f4 commit a83df31

File tree

2 files changed

+23
-6
lines changed

2 files changed

+23
-6
lines changed

src/traits.rs

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
pub use impl2::LapackScalar;
33
pub use impl2::NormType;
44

5+
use num_traits::Zero;
56
use ndarray::*;
67

78
use super::types::*;
@@ -39,16 +40,32 @@ pub trait QR<Q, R> {
3940
fn qr2(self) -> Result<(Q, R)>;
4041
}
4142

42-
impl<A, Sq, Sr> QR<ArrayBase<Sq, Ix2>, ArrayBase<Sr, Ix2>> for ArrayBase<Sq, Ix2>
43-
where A: LapackScalar,
44-
Sq: DataMut<Elem = A>,
45-
Sr: DataOwned<Elem = A>
43+
impl<A, S, Sq, Sr> QR<ArrayBase<Sq, Ix2>, ArrayBase<Sr, Ix2>> for ArrayBase<S, Ix2>
44+
where A: LapackScalar + Copy + Zero,
45+
S: DataMut<Elem = A>,
46+
Sq: DataOwned<Elem = A> + DataMut,
47+
Sr: DataOwned<Elem = A> + DataMut
4648
{
4749
fn qr2(mut self) -> Result<(ArrayBase<Sq, Ix2>, ArrayBase<Sr, Ix2>)> {
50+
let n = self.rows();
51+
let m = self.cols();
52+
let k = ::std::cmp::min(n, m);
4853
let l = self.layout()?;
54+
// calc QR decomposition
4955
let r = A::qr(l, self.as_allocated_mut()?)?;
50-
let r = reconstruct(l, r)?;
56+
let r: Array2<_> = reconstruct(l, r)?;
5157
let q = self;
58+
// get slice
59+
let qv = q.slice(s![..n as isize, ..k as isize]);
60+
let mut q = unsafe { ArrayBase::uninitialized((n, k)) };
61+
q.assign(&qv);
62+
let rv = r.slice(s![..k as isize, ..m as isize]);
63+
let mut r = ArrayBase::zeros((k, m));
64+
for ((i, j), val) in r.indexed_iter_mut() {
65+
if i <= j {
66+
*val = rv[(i, j)];
67+
}
68+
}
5269
Ok((q, r))
5370
}
5471
}

tests/qr.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ fn $funcname() {
1010
let a = $random($n, $m, $t);
1111
let ans = a.clone();
1212
println!("a = \n{:?}", &a);
13-
let (q, r) : (_, Array2<f64>) = a.qr2().unwrap();
13+
let (q, r) : (Array2<_>, Array2<_>) = a.qr2().unwrap();
1414
println!("q = \n{:?}", &q);
1515
println!("r = \n{:?}", &r);
1616
assert_close_l2!(&q.t().dot(&q), &Array::eye(min($n, $m)), 1e-7);

0 commit comments

Comments
 (0)