Skip to content

Commit 50b796c

Browse files
committed
Add Solve trait and drop Transpose from solve (split to functions)
1 parent 190313c commit 50b796c

File tree

2 files changed

+69
-4
lines changed

2 files changed

+69
-4
lines changed

examples/solve.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ fn factorize() -> Result<(), error::LinalgError> {
1111
let f = a.factorize_into()?; // LU factorize A (A is consumed)
1212
for _ in 0..10 {
1313
let b: Array1<f64> = random(3);
14-
let x = f.solve(Transpose::No, b)?; // solve Ax=b using factorized L, U
14+
let _x = f.solve(&b)?; // solve Ax=b using factorized L, U
1515
}
1616
Ok(())
1717
}

src/solve.rs

Lines changed: 68 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,24 +9,89 @@ use super::types::*;
99

1010
pub use lapack_traits::{Pivot, Transpose};
1111

12+
pub trait Solve<A: Scalar> {
13+
fn solve<S: Data<Elem = A>>(&self, a: &ArrayBase<S, Ix1>) -> Result<Array1<A>> {
14+
let mut a = replicate(a);
15+
self.solve_mut(&mut a)?;
16+
Ok(a)
17+
}
18+
fn solve_into<S: DataMut<Elem = A>>(&self, mut a: ArrayBase<S, Ix1>) -> Result<ArrayBase<S, Ix1>> {
19+
self.solve_mut(&mut a)?;
20+
Ok(a)
21+
}
22+
fn solve_mut<'a, S: DataMut<Elem = A>>(&self, &'a mut ArrayBase<S, Ix1>) -> Result<&'a mut ArrayBase<S, Ix1>>;
23+
24+
fn solve_t<S: Data<Elem = A>>(&self, a: &ArrayBase<S, Ix1>) -> Result<Array1<A>> {
25+
let mut a = replicate(a);
26+
self.solve_t_mut(&mut a)?;
27+
Ok(a)
28+
}
29+
fn solve_t_into<S: DataMut<Elem = A>>(&self, mut a: ArrayBase<S, Ix1>) -> Result<ArrayBase<S, Ix1>> {
30+
self.solve_t_mut(&mut a)?;
31+
Ok(a)
32+
}
33+
fn solve_t_mut<'a, S: DataMut<Elem = A>>(&self, &'a mut ArrayBase<S, Ix1>) -> Result<&'a mut ArrayBase<S, Ix1>>;
34+
35+
fn solve_h<S: Data<Elem = A>>(&self, a: &ArrayBase<S, Ix1>) -> Result<Array1<A>> {
36+
let mut a = replicate(a);
37+
self.solve_h_mut(&mut a)?;
38+
Ok(a)
39+
}
40+
fn solve_h_into<S: DataMut<Elem = A>>(&self, mut a: ArrayBase<S, Ix1>) -> Result<ArrayBase<S, Ix1>> {
41+
self.solve_h_mut(&mut a)?;
42+
Ok(a)
43+
}
44+
fn solve_h_mut<'a, S: DataMut<Elem = A>>(&self, &'a mut ArrayBase<S, Ix1>) -> Result<&'a mut ArrayBase<S, Ix1>>;
45+
}
46+
1247
pub struct Factorized<S: Data> {
1348
pub a: ArrayBase<S, Ix2>,
1449
pub ipiv: Pivot,
1550
}
1651

17-
impl<A, S> Factorized<S>
52+
impl<A, S> Solve<A> for Factorized<S>
1853
where
1954
A: Scalar,
2055
S: Data<Elem = A>,
2156
{
22-
pub fn solve<Sb>(&self, t: Transpose, mut rhs: ArrayBase<Sb, Ix1>) -> Result<ArrayBase<Sb, Ix1>>
57+
fn solve_mut<'a, Sb>(&self, mut rhs: &'a mut ArrayBase<Sb, Ix1>) -> Result<&'a mut ArrayBase<Sb, Ix1>>
58+
where
59+
Sb: DataMut<Elem = A>,
60+
{
61+
unsafe {
62+
A::solve(
63+
self.a.square_layout()?,
64+
Transpose::No,
65+
self.a.as_allocated()?,
66+
&self.ipiv,
67+
rhs.as_slice_mut().unwrap(),
68+
)?
69+
};
70+
Ok(rhs)
71+
}
72+
fn solve_t_mut<'a, Sb>(&self, mut rhs: &'a mut ArrayBase<Sb, Ix1>) -> Result<&'a mut ArrayBase<Sb, Ix1>>
73+
where
74+
Sb: DataMut<Elem = A>,
75+
{
76+
unsafe {
77+
A::solve(
78+
self.a.square_layout()?,
79+
Transpose::Transpose,
80+
self.a.as_allocated()?,
81+
&self.ipiv,
82+
rhs.as_slice_mut().unwrap(),
83+
)?
84+
};
85+
Ok(rhs)
86+
}
87+
fn solve_h_mut<'a, Sb>(&self, mut rhs: &'a mut ArrayBase<Sb, Ix1>) -> Result<&'a mut ArrayBase<Sb, Ix1>>
2388
where
2489
Sb: DataMut<Elem = A>,
2590
{
2691
unsafe {
2792
A::solve(
2893
self.a.square_layout()?,
29-
t,
94+
Transpose::Hermite,
3095
self.a.as_allocated()?,
3196
&self.ipiv,
3297
rhs.as_slice_mut().unwrap(),

0 commit comments

Comments
 (0)