Skip to content

Commit 7529766

Browse files
authored
Merge pull request #69 from termoshtt/revice_interfaces
Revise interface of `solve()`
2 parents 190313c + e785cdf commit 7529766

File tree

2 files changed

+106
-4
lines changed

2 files changed

+106
-4
lines changed

examples/solve.rs

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,17 +5,26 @@ extern crate ndarray_linalg;
55
use ndarray::*;
66
use ndarray_linalg::*;
77

8+
// Solve `Ax=b`
9+
fn solve() -> Result<(), error::LinalgError> {
10+
let a: Array2<f64> = random((3, 3));
11+
let b: Array1<f64> = random(3);
12+
let _x = a.solve(&b)?;
13+
Ok(())
14+
}
15+
816
// Solve `Ax=b` for many b with fixed A
917
fn factorize() -> Result<(), error::LinalgError> {
1018
let a: Array2<f64> = random((3, 3));
1119
let f = a.factorize_into()?; // LU factorize A (A is consumed)
1220
for _ in 0..10 {
1321
let b: Array1<f64> = random(3);
14-
let x = f.solve(Transpose::No, b)?; // solve Ax=b using factorized L, U
22+
let _x = f.solve_into(b)?; // solve Ax=b using factorized L, U
1523
}
1624
Ok(())
1725
}
1826

1927
fn main() {
28+
solve().unwrap();
2029
factorize().unwrap();
2130
}

src/solve.rs

Lines changed: 96 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,24 +9,89 @@ use super::types::*;
99

1010
pub use lapack_traits::{Pivot, Transpose};
1111

12+
pub trait Solve<A: Scalar> {
13+
fn solve<S: Data<Elem = A>>(&self, a: &ArrayBase<S, Ix1>) -> Result<Array1<A>> {
14+
let mut a = replicate(a);
15+
self.solve_mut(&mut a)?;
16+
Ok(a)
17+
}
18+
fn solve_into<S: DataMut<Elem = A>>(&self, mut a: ArrayBase<S, Ix1>) -> Result<ArrayBase<S, Ix1>> {
19+
self.solve_mut(&mut a)?;
20+
Ok(a)
21+
}
22+
fn solve_mut<'a, S: DataMut<Elem = A>>(&self, &'a mut ArrayBase<S, Ix1>) -> Result<&'a mut ArrayBase<S, Ix1>>;
23+
24+
fn solve_t<S: Data<Elem = A>>(&self, a: &ArrayBase<S, Ix1>) -> Result<Array1<A>> {
25+
let mut a = replicate(a);
26+
self.solve_t_mut(&mut a)?;
27+
Ok(a)
28+
}
29+
fn solve_t_into<S: DataMut<Elem = A>>(&self, mut a: ArrayBase<S, Ix1>) -> Result<ArrayBase<S, Ix1>> {
30+
self.solve_t_mut(&mut a)?;
31+
Ok(a)
32+
}
33+
fn solve_t_mut<'a, S: DataMut<Elem = A>>(&self, &'a mut ArrayBase<S, Ix1>) -> Result<&'a mut ArrayBase<S, Ix1>>;
34+
35+
fn solve_h<S: Data<Elem = A>>(&self, a: &ArrayBase<S, Ix1>) -> Result<Array1<A>> {
36+
let mut a = replicate(a);
37+
self.solve_h_mut(&mut a)?;
38+
Ok(a)
39+
}
40+
fn solve_h_into<S: DataMut<Elem = A>>(&self, mut a: ArrayBase<S, Ix1>) -> Result<ArrayBase<S, Ix1>> {
41+
self.solve_h_mut(&mut a)?;
42+
Ok(a)
43+
}
44+
fn solve_h_mut<'a, S: DataMut<Elem = A>>(&self, &'a mut ArrayBase<S, Ix1>) -> Result<&'a mut ArrayBase<S, Ix1>>;
45+
}
46+
1247
pub struct Factorized<S: Data> {
1348
pub a: ArrayBase<S, Ix2>,
1449
pub ipiv: Pivot,
1550
}
1651

17-
impl<A, S> Factorized<S>
52+
impl<A, S> Solve<A> for Factorized<S>
1853
where
1954
A: Scalar,
2055
S: Data<Elem = A>,
2156
{
22-
pub fn solve<Sb>(&self, t: Transpose, mut rhs: ArrayBase<Sb, Ix1>) -> Result<ArrayBase<Sb, Ix1>>
57+
fn solve_mut<'a, Sb>(&self, mut rhs: &'a mut ArrayBase<Sb, Ix1>) -> Result<&'a mut ArrayBase<Sb, Ix1>>
58+
where
59+
Sb: DataMut<Elem = A>,
60+
{
61+
unsafe {
62+
A::solve(
63+
self.a.square_layout()?,
64+
Transpose::No,
65+
self.a.as_allocated()?,
66+
&self.ipiv,
67+
rhs.as_slice_mut().unwrap(),
68+
)?
69+
};
70+
Ok(rhs)
71+
}
72+
fn solve_t_mut<'a, Sb>(&self, mut rhs: &'a mut ArrayBase<Sb, Ix1>) -> Result<&'a mut ArrayBase<Sb, Ix1>>
73+
where
74+
Sb: DataMut<Elem = A>,
75+
{
76+
unsafe {
77+
A::solve(
78+
self.a.square_layout()?,
79+
Transpose::Transpose,
80+
self.a.as_allocated()?,
81+
&self.ipiv,
82+
rhs.as_slice_mut().unwrap(),
83+
)?
84+
};
85+
Ok(rhs)
86+
}
87+
fn solve_h_mut<'a, Sb>(&self, mut rhs: &'a mut ArrayBase<Sb, Ix1>) -> Result<&'a mut ArrayBase<Sb, Ix1>>
2388
where
2489
Sb: DataMut<Elem = A>,
2590
{
2691
unsafe {
2792
A::solve(
2893
self.a.square_layout()?,
29-
t,
94+
Transpose::Hermite,
3095
self.a.as_allocated()?,
3196
&self.ipiv,
3297
rhs.as_slice_mut().unwrap(),
@@ -36,6 +101,34 @@ where
36101
}
37102
}
38103

104+
impl<A, S> Solve<A> for ArrayBase<S, Ix2>
105+
where
106+
A: Scalar,
107+
S: Data<Elem = A>,
108+
{
109+
fn solve_mut<'a, Sb>(&self, mut rhs: &'a mut ArrayBase<Sb, Ix1>) -> Result<&'a mut ArrayBase<Sb, Ix1>>
110+
where
111+
Sb: DataMut<Elem = A>,
112+
{
113+
let f = self.factorize()?;
114+
f.solve_mut(rhs)
115+
}
116+
fn solve_t_mut<'a, Sb>(&self, mut rhs: &'a mut ArrayBase<Sb, Ix1>) -> Result<&'a mut ArrayBase<Sb, Ix1>>
117+
where
118+
Sb: DataMut<Elem = A>,
119+
{
120+
let f = self.factorize()?;
121+
f.solve_t_mut(rhs)
122+
}
123+
fn solve_h_mut<'a, Sb>(&self, mut rhs: &'a mut ArrayBase<Sb, Ix1>) -> Result<&'a mut ArrayBase<Sb, Ix1>>
124+
where
125+
Sb: DataMut<Elem = A>,
126+
{
127+
let f = self.factorize()?;
128+
f.solve_h_mut(rhs)
129+
}
130+
}
131+
39132
impl<A, S> Factorized<S>
40133
where
41134
A: Scalar,

0 commit comments

Comments
 (0)