Skip to content

Commit 51680fe

Browse files
committed
Implement Matrix for RcArray<A, Ix2>
1 parent 53cb90f commit 51680fe

File tree

1 file changed

+71
-22
lines changed

1 file changed

+71
-22
lines changed

src/matrix.rs

Lines changed: 71 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
33
use std::cmp::min;
44
use ndarray::prelude::*;
5+
use ndarray::DataMut;
56
use lapack::c::Layout;
67

78
use error::{LinalgError, StrideError};
@@ -43,6 +44,35 @@ pub trait Matrix: Sized {
4344
}
4445
}
4546

47+
fn check_layout(strides: &[Ixs]) -> Result<Layout, StrideError> {
48+
if min(strides[0], strides[1]) != 1 {
49+
return Err(StrideError {
50+
s0: strides[0],
51+
s1: strides[1],
52+
});;
53+
}
54+
if strides[0] < strides[1] {
55+
Ok(Layout::ColumnMajor)
56+
} else {
57+
Ok(Layout::RowMajor)
58+
}
59+
}
60+
61+
fn permutate<A: NdFloat, S>(mut a: &mut ArrayBase<S, Ix2>, ipiv: &Vec<i32>)
62+
where S: DataMut<Elem = A>
63+
{
64+
let m = a.cols();
65+
for (i, j_) in ipiv.iter().enumerate().rev() {
66+
let j = (j_ - 1) as usize;
67+
if i == j {
68+
continue;
69+
}
70+
for k in 0..m {
71+
a.swap((i, k), (j, k));
72+
}
73+
}
74+
}
75+
4676
impl<A: MFloat> Matrix for Array<A, Ix2> {
4777
type Scalar = A;
4878
type Vector = Array<A, Ix1>;
@@ -52,18 +82,7 @@ impl<A: MFloat> Matrix for Array<A, Ix2> {
5282
(self.rows(), self.cols())
5383
}
5484
fn layout(&self) -> Result<Layout, StrideError> {
55-
let strides = self.strides();
56-
if min(strides[0], strides[1]) != 1 {
57-
return Err(StrideError {
58-
s0: strides[0],
59-
s1: strides[1],
60-
});;
61-
}
62-
if strides[0] < strides[1] {
63-
Ok(Layout::ColumnMajor)
64-
} else {
65-
Ok(Layout::RowMajor)
66-
}
85+
check_layout(self.strides())
6786
}
6887
fn norm_1(&self) -> Self::Scalar {
6988
let (m, n) = self.size();
@@ -159,15 +178,45 @@ impl<A: MFloat> Matrix for Array<A, Ix2> {
159178
Ok((p, lm, am))
160179
}
161180
fn permutate(&mut self, ipiv: &Self::Permutator) {
162-
let (_, m) = self.size();
163-
for (i, j_) in ipiv.iter().enumerate().rev() {
164-
let j = (j_ - 1) as usize;
165-
if i == j {
166-
continue;
167-
}
168-
for k in 0..m {
169-
self.swap((i, k), (j, k));
170-
}
171-
}
181+
permutate(self, ipiv);
182+
}
183+
}
184+
185+
impl<A: MFloat> Matrix for RcArray<A, Ix2> {
186+
type Scalar = A;
187+
type Vector = RcArray<A, Ix1>;
188+
type Permutator = Vec<i32>;
189+
fn size(&self) -> (usize, usize) {
190+
(self.rows(), self.cols())
191+
}
192+
fn layout(&self) -> Result<Layout, StrideError> {
193+
check_layout(self.strides())
194+
}
195+
fn norm_1(&self) -> Self::Scalar {
196+
// TODO remove clone by into_owned()
197+
self.to_owned().norm_1()
198+
}
199+
fn norm_i(&self) -> Self::Scalar {
200+
// TODO remove clone by into_owned()
201+
self.to_owned().norm_i()
202+
}
203+
fn norm_f(&self) -> Self::Scalar {
204+
// TODO remove clone by into_owned()
205+
self.to_owned().norm_f()
206+
}
207+
fn svd(self) -> Result<(Self, Self::Vector, Self), LinalgError> {
208+
let (u, s, v) = self.to_owned().svd()?;
209+
Ok((u.into_shared(), s.into_shared(), v.into_shared()))
210+
}
211+
fn qr(self) -> Result<(Self, Self), LinalgError> {
212+
let (q, r) = self.to_owned().qr()?;
213+
Ok((q.into_shared(), r.into_shared()))
214+
}
215+
fn lu(self) -> Result<(Self::Permutator, Self, Self), LinalgError> {
216+
let (p, l, u) = self.to_owned().lu()?;
217+
Ok((p, l.into_shared(), u.into_shared()))
218+
}
219+
fn permutate(&mut self, ipiv: &Self::Permutator) {
220+
permutate(self, ipiv);
172221
}
173222
}

0 commit comments

Comments
 (0)