Skip to content

Commit a144a54

Browse files
committed
Implement solve_lower
1 parent 28facde commit a144a54

File tree

2 files changed

+46
-2
lines changed

2 files changed

+46
-2
lines changed

src/triangular.rs

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,8 @@ use solve::ImplSolve;
1414
pub trait TriangularMatrix: Matrix + SquareMatrix {
1515
/// solve a triangular system with upper triangular matrix
1616
fn solve_upper(&self, Self::Vector) -> Result<Self::Vector, LinalgError>;
17-
// /// solve a triangular system with lower triangular matrix
18-
// fn solve_lower(&self, Self::Vector) -> Result<Self::Vector, LinalgError>;
17+
/// solve a triangular system with lower triangular matrix
18+
fn solve_lower(&self, Self::Vector) -> Result<Self::Vector, LinalgError>;
1919
}
2020

2121
impl<A> TriangularMatrix for Array<A, Ix2>
@@ -29,4 +29,12 @@ impl<A> TriangularMatrix for Array<A, Ix2>
2929
let x = ImplSolve::solve_triangle(layout, 'U' as u8, n, a, b.into_raw_vec())?;
3030
Ok(Array::from_vec(x))
3131
}
32+
fn solve_lower(&self, b: Self::Vector) -> Result<Self::Vector, LinalgError> {
33+
self.check_square()?;
34+
let (n, _) = self.size();
35+
let layout = self.layout()?;
36+
let a = self.as_slice_memory_order().unwrap();
37+
let x = ImplSolve::solve_triangle(layout, 'L' as u8, n, a, b.into_raw_vec())?;
38+
Ok(Array::from_vec(x))
39+
}
3240
}

tests/triangular.rs

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,3 +52,39 @@ fn solve_upper_t() {
5252
println!("Ax = \n{:?}", a.dot(&x));
5353
all_close(a.dot(&x), b);
5454
}
55+
56+
#[test]
57+
fn solve_lower() {
58+
let r_dist = Range::new(0., 1.);
59+
let mut a = Array::<f64, _>::random((3, 3), r_dist);
60+
for ((i, j), val) in a.indexed_iter_mut() {
61+
if i < j {
62+
*val = 0.0;
63+
}
64+
}
65+
println!("a = \n{:?}", &a);
66+
let b = Array::<f64, _>::random(3, r_dist);
67+
println!("b = \n{:?}", &b);
68+
let x = a.solve_lower(b.clone()).unwrap();
69+
println!("x = \n{:?}", &x);
70+
println!("Ax = \n{:?}", a.dot(&x));
71+
all_close(a.dot(&x), b);
72+
}
73+
74+
#[test]
75+
fn solve_lower_t() {
76+
let r_dist = Range::new(0., 1.);
77+
let mut a = Array::<f64, _>::random((3, 3), r_dist).reversed_axes();
78+
for ((i, j), val) in a.indexed_iter_mut() {
79+
if i < j {
80+
*val = 0.0;
81+
}
82+
}
83+
println!("a = \n{:?}", &a);
84+
let b = Array::<f64, _>::random(3, r_dist);
85+
println!("b = \n{:?}", &b);
86+
let x = a.solve_lower(b.clone()).unwrap();
87+
println!("x = \n{:?}", &x);
88+
println!("Ax = \n{:?}", a.dot(&x));
89+
all_close(a.dot(&x), b);
90+
}

0 commit comments

Comments
 (0)