Skip to content

Commit c2b3fdc

Browse files
committed
Rewrite permutate w/o LAPACK
1 parent 5f7bbbd commit c2b3fdc

File tree

3 files changed

+36
-47
lines changed

3 files changed

+36
-47
lines changed

src/matrix.rs

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -33,10 +33,13 @@ pub trait Matrix: Sized {
3333
fn qr(self) -> Result<(Self, Self), LapackError>;
3434
/// LU decomposition
3535
fn lu(self) -> Result<(Self::Permutator, Self, Self), LapackError>;
36-
/// permutate matrix
37-
fn permutate_column(self, p: &Self::Permutator) -> Self;
38-
/// permutate matrix
39-
fn permutate_row(self, p: &Self::Permutator) -> Self;
36+
/// permutate matrix (inplace)
37+
fn permutate(&mut self, p: &Self::Permutator);
38+
/// permutate matrix (outplace)
39+
fn permutated(mut self, p: &Self::Permutator) -> Self {
40+
self.permutate(p);
41+
self
42+
}
4043
}
4144

4245
impl<A> Matrix for Array<A, (Ix, Ix)>
@@ -186,17 +189,16 @@ impl<A> Matrix for Array<A, (Ix, Ix)>
186189
}
187190
}
188191
}
189-
fn permutate_column(self, p: &Self::Permutator) -> Self {
190-
let (n, m) = self.size();
191-
let pd = ImplSolve::permutate_column(self.layout(), n, m, self.clone().into_raw_vec(), p);
192-
match self.layout() {
193-
Layout::ColumnMajor => Array::from_vec(pd).into_shape((m, n)).unwrap().reversed_axes(),
194-
Layout::RowMajor => Array::from_vec(pd).into_shape((m, n)).unwrap(),
192+
fn permutate(&mut self, ipiv: &Self::Permutator) {
193+
let (_, m) = self.size();
194+
for (i, j_) in ipiv.iter().enumerate().rev() {
195+
let j = (j_ - 1) as usize;
196+
if i == j {
197+
continue;
198+
}
199+
for k in 0..m {
200+
self.swap((i, k), (j, k));
201+
}
195202
}
196203
}
197-
fn permutate_row(self, p: &Self::Permutator) -> Self {
198-
let (n, m) = self.size();
199-
let pd = ImplSolve::permutate_column(self.layout(), m, n, self.clone().into_raw_vec(), p);
200-
Array::from_vec(pd).into_shape((n, m)).unwrap()
201-
}
202204
}

src/solve.rs

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,6 @@ pub trait ImplSolve: Sized {
1212
n: usize,
1313
a: Vec<Self>)
1414
-> Result<(Vec<i32>, Vec<Self>), LapackError>;
15-
fn permutate_column(layout: Layout,
16-
m: usize,
17-
n: usize,
18-
a: Vec<Self>,
19-
p: &Vec<i32>)
20-
-> Vec<Self>;
2115
}
2216

2317
macro_rules! impl_solve {
@@ -51,13 +45,6 @@ impl ImplSolve for $scalar {
5145
Err(From::from(info))
5246
}
5347
}
54-
fn permutate_column(layout: Layout, m: usize, n: usize, mut a: Vec<Self>, p: &Vec<i32>) -> Vec<Self> {
55-
let n = n as i32;
56-
let m = m as i32;
57-
let k = p.len() as i32;
58-
$laswp(layout, n, &mut a, m, 1, k, p, -1);
59-
a
60-
}
6148
}
6249
}} // end macro_rules
6350

tests/lu.rs

Lines changed: 19 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ fn permutate() {
2121
let a = arr2(&[[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]]);
2222
println!("a= \n{:?}", &a);
2323
let p = vec![2, 2, 3]; // replace 1-2
24-
let pa = a.permutate_column(&p);
24+
let pa = a.permutated(&p);
2525
println!("permutated = \n{:?}", &pa);
2626
all_close(pa, arr2(&[[4., 5., 6.], [1., 2., 3.], [7., 8., 9.]]))
2727
}
@@ -31,7 +31,7 @@ fn permutate_t() {
3131
let a = arr2(&[[1., 4., 7.], [2., 5., 8.], [3., 6., 9.]]).reversed_axes();
3232
println!("a= \n{:?}", &a);
3333
let p = vec![2, 2, 3]; // replace 1-2
34-
let pa = a.permutate_column(&p);
34+
let pa = a.permutated(&p);
3535
println!("permutated = \n{:?}", &pa);
3636
all_close(pa, arr2(&[[4., 5., 6.], [1., 2., 3.], [7., 8., 9.]]))
3737
}
@@ -42,7 +42,7 @@ fn permutate_3x4() {
4242
println!("a= \n{:?}", &a);
4343
let p = vec![1, 3, 3]; // replace 2-3
4444
println!("permutation = \n{:?}", &p);
45-
let pa = a.permutate_column(&p);
45+
let pa = a.permutated(&p);
4646
println!("permutated = \n{:?}", &pa);
4747
all_close(pa,
4848
arr2(&[[1., 4., 7., 10.], [3., 6., 9., 12.], [2., 5., 8., 11.]]));
@@ -54,7 +54,7 @@ fn permutate_3x4_t() {
5454
println!("a= \n{:?}", &a);
5555
let p = vec![1, 3, 3]; // replace 2-3
5656
println!("permutation = \n{:?}", &p);
57-
let pa = a.permutate_column(&p);
57+
let pa = a.permutated(&p);
5858
println!("permutated = \n{:?}", &pa);
5959
all_close(pa,
6060
arr2(&[[1., 2., 3., 4.], [9., 10., 11., 12.], [5., 6., 7., 8.]]));
@@ -66,7 +66,7 @@ fn permutate_4x3() {
6666
println!("a= \n{:?}", &a);
6767
let p = vec![4, 2, 3, 4]; // replace 1-4
6868
println!("permutation = \n{:?}", &p);
69-
let pa = a.permutate_column(&p);
69+
let pa = a.permutated(&p);
7070
println!("permutated = \n{:?}", &pa);
7171
all_close(pa,
7272
arr2(&[[4., 8., 12.], [2., 6., 10.], [3., 7., 11.], [1., 5., 9.]]))
@@ -78,7 +78,7 @@ fn permutate_4x3_t() {
7878
println!("a= \n{:?}", &a);
7979
let p = vec![4, 2, 3, 4]; // replace 1-4
8080
println!("permutation = \n{:?}", &p);
81-
let pa = a.permutate_column(&p);
81+
let pa = a.permutated(&p);
8282
println!("permutated = \n{:?}", &pa);
8383
all_close(pa,
8484
arr2(&[[10., 11., 12.], [4., 5., 6.], [7., 8., 9.], [1., 2., 3.]]))
@@ -98,7 +98,7 @@ fn lu_square_upper() {
9898
println!("P = \n{:?}", &p);
9999
println!("L = \n{:?}", &l);
100100
println!("U = \n{:?}", &u);
101-
all_close(l.dot(&u).permutate_column(&p), a);
101+
all_close(l.dot(&u).permutated(&p), a);
102102
}
103103

104104
#[test]
@@ -115,7 +115,7 @@ fn lu_square_upper_t() {
115115
println!("P = \n{:?}", &p);
116116
println!("L = \n{:?}", &l);
117117
println!("U = \n{:?}", &u);
118-
all_close(l.dot(&u).permutate_column(&p), a);
118+
all_close(l.dot(&u).permutated(&p), a);
119119
}
120120

121121
#[test]
@@ -133,7 +133,7 @@ fn lu_square_lower() {
133133
println!("L = \n{:?}", &l);
134134
println!("U = \n{:?}", &u);
135135
println!("LU = \n{:?}", l.dot(&u));
136-
all_close(l.dot(&u).permutate_column(&p), a);
136+
all_close(l.dot(&u).permutated(&p), a);
137137
}
138138

139139
#[test]
@@ -151,7 +151,7 @@ fn lu_square_lower_t() {
151151
println!("L = \n{:?}", &l);
152152
println!("U = \n{:?}", &u);
153153
println!("LU = \n{:?}", l.dot(&u));
154-
all_close(l.dot(&u).permutate_column(&p), a);
154+
all_close(l.dot(&u).permutated(&p), a);
155155
}
156156

157157
#[test]
@@ -164,7 +164,7 @@ fn lu_square() {
164164
println!("L = \n{:?}", &l);
165165
println!("U = \n{:?}", &u);
166166
println!("LU = \n{:?}", l.dot(&u));
167-
all_close(l.dot(&u).permutate_column(&p), a);
167+
all_close(l.dot(&u).permutated(&p), a);
168168
}
169169

170170
#[test]
@@ -176,7 +176,7 @@ fn lu_square_t() {
176176
println!("P = \n{:?}", &p);
177177
println!("L = \n{:?}", &l);
178178
println!("U = \n{:?}", &u);
179-
all_close(l.dot(&u).permutate_column(&p), a);
179+
all_close(l.dot(&u).permutated(&p), a);
180180
}
181181

182182
// #[test]
@@ -189,7 +189,7 @@ fn lu_square_t() {
189189
// println!("L = \n{:?}", &l);
190190
// println!("U = \n{:?}", &u);
191191
// println!("LU = \n{:?}", l.dot(&u));
192-
// all_close(l.dot(&u).permutate_column(&p), a);
192+
// all_close(l.dot(&u).permutated(&p), a);
193193
// }
194194
//
195195
// #[test]
@@ -201,7 +201,7 @@ fn lu_square_t() {
201201
// println!("P = \n{:?}", &p);
202202
// println!("L = \n{:?}", &l);
203203
// println!("U = \n{:?}", &u);
204-
// all_close(l.dot(&u).permutate_column(&p), a);
204+
// all_close(l.dot(&u).permutated(&p), a);
205205
// }
206206

207207
// #[test]
@@ -219,7 +219,7 @@ fn lu_square_t() {
219219
// println!("L = \n{:?}", &l);
220220
// println!("U = \n{:?}", &u);
221221
// println!("LU = \n{:?}", l.dot(&u));
222-
// all_close(l.dot(&u).permutate_column(&p), a);
222+
// all_close(l.dot(&u).permutated(&p), a);
223223
// }
224224
//
225225
// #[test]
@@ -237,7 +237,7 @@ fn lu_square_t() {
237237
// println!("L = \n{:?}", &l);
238238
// println!("U = \n{:?}", &u);
239239
// println!("LU = \n{:?}", l.dot(&u));
240-
// all_close(l.dot(&u).permutate_column(&p), a);
240+
// all_close(l.dot(&u).permutated(&p), a);
241241
// }
242242

243243
#[test]
@@ -255,7 +255,7 @@ fn lu_4x3_upper_t() {
255255
println!("L = \n{:?}", &l);
256256
println!("U = \n{:?}", &u);
257257
println!("LU = \n{:?}", l.dot(&u));
258-
all_close(l.dot(&u).permutate_column(&p), a);
258+
all_close(l.dot(&u).permutated(&p), a);
259259
}
260260

261261
// #[test]
@@ -268,7 +268,7 @@ fn lu_4x3_upper_t() {
268268
// println!("L = \n{:?}", &l);
269269
// println!("U = \n{:?}", &u);
270270
// println!("LU = \n{:?}", l.dot(&u));
271-
// all_close(l.dot(&u).permutate_column(&p), a);
271+
// all_close(l.dot(&u).permutated(&p), a);
272272
// }
273273
//
274274
// #[test]
@@ -280,5 +280,5 @@ fn lu_4x3_upper_t() {
280280
// println!("P = \n{:?}", &p);
281281
// println!("L = \n{:?}", &l);
282282
// println!("U = \n{:?}", &u);
283-
// all_close(l.dot(&u).permutate_column(&p), a);
283+
// all_close(l.dot(&u).permutated(&p), a);
284284
// }

0 commit comments

Comments
 (0)