Skip to content

Commit 28facde

Browse files
committed
Add test for solve_upper
1 parent 53ccc1e commit 28facde

File tree

3 files changed

+66
-12
lines changed

3 files changed

+66
-12
lines changed

src/solve.rs

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ pub trait ImplSolve: Sized {
2121
fn solve_triangle(layout: Layout,
2222
uplo: u8,
2323
size: usize,
24-
a: &Vec<Self>,
24+
a: &[Self],
2525
b: Vec<Self>)
2626
-> Result<Vec<Self>, LapackError>;
2727
}
@@ -65,10 +65,14 @@ impl ImplSolve for $scalar {
6565
Err(From::from(info))
6666
}
6767
}
68-
fn solve_triangle(layout: Layout, uplo: u8, size: usize, a: &Vec<Self>, mut b: Vec<Self>) -> Result<Vec<Self>, LapackError> {
68+
fn solve_triangle(layout: Layout, uplo: u8, size: usize, a: &[Self], mut b: Vec<Self>) -> Result<Vec<Self>, LapackError> {
6969
let n = size as i32;
7070
let lda = n;
71-
let info = $trtrs(layout, uplo, 'N' as u8, 'N' as u8, n, 1, a, lda, &mut b, n);
71+
let ldb = match layout {
72+
Layout::ColumnMajor => n,
73+
Layout::RowMajor => 1,
74+
};
75+
let info = $trtrs(layout, uplo, 'N' as u8, 'N' as u8, n, 1, a, lda, &mut b, ldb);
7276
if info == 0 {
7377
Ok(b)
7478
} else {

src/triangular.rs

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,8 @@ use solve::ImplSolve;
1414
pub trait TriangularMatrix: Matrix + SquareMatrix {
1515
/// solve a triangular system with upper triangular matrix
1616
fn solve_upper(&self, Self::Vector) -> Result<Self::Vector, LinalgError>;
17-
/// solve a triangular system with lower triangular matrix
18-
fn solve_lower(&self, Self::Vector) -> Result<Self::Vector, LinalgError>;
17+
// /// solve a triangular system with lower triangular matrix
18+
// fn solve_lower(&self, Self::Vector) -> Result<Self::Vector, LinalgError>;
1919
}
2020

2121
impl<A> TriangularMatrix for Array<A, Ix2>
@@ -25,12 +25,8 @@ impl<A> TriangularMatrix for Array<A, Ix2>
2525
self.check_square()?;
2626
let (n, _) = self.size();
2727
let layout = self.layout()?;
28-
let x = ImplSolve::solve_triangle(layout, 'U' as u8, n, self.as_slice().unwrap(), b)?;
29-
}
30-
fn solve_lower(&self, b: Self::Vector) -> Result<Self::Vector, LinalgError> {
31-
self.check_square()?;
32-
let (n, _) = self.size();
33-
let layout = self.layout()?;
34-
let x = ImplSolve::solve_triangle(layout, 'U' as u8, n, self.as_slice().unwrap(), b)?;
28+
let a = self.as_slice_memory_order().unwrap();
29+
let x = ImplSolve::solve_triangle(layout, 'U' as u8, n, a, b.into_raw_vec())?;
30+
Ok(Array::from_vec(x))
3531
}
3632
}

tests/triangular.rs

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
2+
extern crate rand;
3+
extern crate ndarray;
4+
extern crate ndarray_rand;
5+
extern crate ndarray_linalg;
6+
7+
use ndarray::prelude::*;
8+
use ndarray_linalg::prelude::*;
9+
use rand::distributions::*;
10+
use ndarray_rand::RandomExt;
11+
12+
fn all_close<D: Dimension>(a: Array<f64, D>, b: Array<f64, D>) {
13+
if !a.all_close(&b, 1.0e-7) {
14+
panic!("\nTwo matrices are not equal:\na = \n{:?}\nb = \n{:?}\n",
15+
a,
16+
b);
17+
}
18+
}
19+
20+
#[test]
21+
fn solve_upper() {
22+
let r_dist = Range::new(0., 1.);
23+
let mut a = Array::<f64, _>::random((3, 3), r_dist);
24+
for ((i, j), val) in a.indexed_iter_mut() {
25+
if i > j {
26+
*val = 0.0;
27+
}
28+
}
29+
println!("a = \n{:?}", &a);
30+
let b = Array::<f64, _>::random(3, r_dist);
31+
println!("b = \n{:?}", &b);
32+
let x = a.solve_upper(b.clone()).unwrap();
33+
println!("x = \n{:?}", &x);
34+
println!("Ax = \n{:?}", a.dot(&x));
35+
all_close(a.dot(&x), b);
36+
}
37+
38+
#[test]
39+
fn solve_upper_t() {
40+
let r_dist = Range::new(0., 1.);
41+
let mut a = Array::<f64, _>::random((3, 3), r_dist).reversed_axes();
42+
for ((i, j), val) in a.indexed_iter_mut() {
43+
if i > j {
44+
*val = 0.0;
45+
}
46+
}
47+
println!("a = \n{:?}", &a);
48+
let b = Array::<f64, _>::random(3, r_dist);
49+
println!("b = \n{:?}", &b);
50+
let x = a.solve_upper(b.clone()).unwrap();
51+
println!("x = \n{:?}", &x);
52+
println!("Ax = \n{:?}", a.dot(&x));
53+
all_close(a.dot(&x), b);
54+
}

0 commit comments

Comments
 (0)