Skip to content

Commit 3abe98f

Browse files
committed
Update impl of SolveTriangular
1 parent c0d6dbb commit 3abe98f

File tree

2 files changed

+57
-15
lines changed

2 files changed

+57
-15
lines changed

src/generate.rs

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,18 @@ pub fn conjugate<A, Si, So>(a: &ArrayBase<Si, Ix2>) -> ArrayBase<So, Ix2>
1919
a
2020
}
2121

22+
/// Random vector
23+
pub fn random_vector<A, S>(n: usize) -> ArrayBase<S, Ix1>
24+
where A: RandNormal,
25+
S: DataOwned<Elem = A>
26+
{
27+
let mut rng = thread_rng();
28+
let v: Vec<A> = (0..n).map(|_| A::randn(&mut rng)).collect();
29+
ArrayBase::from_vec(v)
30+
}
31+
2232
/// Random matrix
23-
pub fn random<A, S>(n: usize, m: usize) -> ArrayBase<S, Ix2>
33+
pub fn random_matrix<A, S>(n: usize, m: usize) -> ArrayBase<S, Ix2>
2434
where A: RandNormal,
2535
S: DataOwned<Elem = A>
2636
{
@@ -34,7 +44,7 @@ pub fn random_square<A, S>(n: usize) -> ArrayBase<S, Ix2>
3444
where A: RandNormal,
3545
S: DataOwned<Elem = A>
3646
{
37-
random(n, n)
47+
random_matrix(n, n)
3848
}
3949

4050
/// Random Hermite matrix

src/triangular.rs

Lines changed: 45 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,31 @@ pub trait SolveTriangular<Rhs> {
1515
fn solve_triangular(&self, UPLO, Diag, Rhs) -> Result<Self::Output>;
1616
}
1717

18-
impl<A, S, V> SolveTriangular<V> for ArrayBase<S, Ix2>
18+
impl<A, Si, So, D> SolveTriangular<ArrayBase<So, D>> for ArrayBase<Si, Ix2>
1919
where A: LapackScalar,
20-
S: Data<Elem = A>,
21-
V: AllocatedArrayMut<Elem = A>
20+
Si: Data<Elem = A>,
21+
So: DataMut<Elem = A>,
22+
D: Dimension,
23+
ArrayBase<So, D>: AllocatedArrayMut<Elem = A>
2224
{
23-
type Output = V;
25+
type Output = ArrayBase<So, D>;
2426

25-
fn solve_triangular(&self, uplo: UPLO, diag: Diag, mut b: V) -> Result<Self::Output> {
27+
fn solve_triangular(&self, uplo: UPLO, diag: Diag, mut b: ArrayBase<So, D>) -> Result<Self::Output> {
28+
self.solve_triangular(uplo, diag, &mut b)?;
29+
Ok(b)
30+
}
31+
}
32+
33+
impl<'a, A, Si, So, D> SolveTriangular<&'a mut ArrayBase<So, D>> for ArrayBase<Si, Ix2>
34+
where A: LapackScalar,
35+
Si: Data<Elem = A>,
36+
So: DataMut<Elem = A>,
37+
D: Dimension,
38+
ArrayBase<So, D>: AllocatedArrayMut<Elem = A>
39+
{
40+
type Output = &'a mut ArrayBase<So, D>;
41+
42+
fn solve_triangular(&self, uplo: UPLO, diag: Diag, mut b: &'a mut ArrayBase<So, D>) -> Result<Self::Output> {
2643
let la = self.layout()?;
2744
let lb = b.layout()?;
2845
let a_ = self.as_allocated()?;
@@ -31,16 +48,19 @@ impl<A, S, V> SolveTriangular<V> for ArrayBase<S, Ix2>
3148
}
3249
}
3350

34-
pub fn drop_upper<A: Zero, S>(a: ArrayBase<S, Ix2>) -> ArrayBase<S, Ix2>
35-
where S: DataMut<Elem = A>
51+
impl<'a, A, Si, So, D> SolveTriangular<&'a ArrayBase<So, D>> for ArrayBase<Si, Ix2>
52+
where A: LapackScalar + Copy,
53+
Si: Data<Elem = A>,
54+
So: DataMut<Elem = A> + DataOwned,
55+
D: Dimension,
56+
ArrayBase<So, D>: AllocatedArrayMut<Elem = A>
3657
{
37-
a.into_triangular(UPLO::Lower)
38-
}
58+
type Output = ArrayBase<So, D>;
3959

40-
pub fn drop_lower<A: Zero, S>(a: ArrayBase<S, Ix2>) -> ArrayBase<S, Ix2>
41-
where S: DataMut<Elem = A>
42-
{
43-
a.into_triangular(UPLO::Upper)
60+
fn solve_triangular(&self, uplo: UPLO, diag: Diag, b: &'a ArrayBase<So, D>) -> Result<Self::Output> {
61+
let b = replicate(b);
62+
self.solve_triangular(uplo, diag, b)
63+
}
4464
}
4565

4666
pub trait IntoTriangular<T> {
@@ -81,3 +101,15 @@ impl<A, S> IntoTriangular<ArrayBase<S, Ix2>> for ArrayBase<S, Ix2>
81101
self
82102
}
83103
}
104+
105+
pub fn drop_upper<A: Zero, S>(a: ArrayBase<S, Ix2>) -> ArrayBase<S, Ix2>
106+
where S: DataMut<Elem = A>
107+
{
108+
a.into_triangular(UPLO::Lower)
109+
}
110+
111+
pub fn drop_lower<A: Zero, S>(a: ArrayBase<S, Ix2>) -> ArrayBase<S, Ix2>
112+
where S: DataMut<Elem = A>
113+
{
114+
a.into_triangular(UPLO::Upper)
115+
}

0 commit comments

Comments
 (0)