Skip to content

Commit 4d65780

Browse files
committed
Add IntoTriangular
1 parent ecd0643 commit 4d65780

File tree

2 files changed

+52
-8
lines changed

2 files changed

+52
-8
lines changed

src/cholesky.rs

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11

22
use ndarray::*;
3+
use num_traits::Zero;
34

45
use super::error::*;
56
use super::layout::*;
7+
use super::triangular::IntoTriangular;
68

79
use impl2::LapackScalar;
810
pub use impl2::UPLO;
@@ -12,33 +14,33 @@ pub trait Cholesky<K> {
1214
}
1315

1416
impl<A, S> Cholesky<ArrayBase<S, Ix2>> for ArrayBase<S, Ix2>
15-
where A: LapackScalar,
17+
where A: LapackScalar + Zero,
1618
S: DataMut<Elem = A>
1719
{
1820
fn cholesky(mut self, uplo: UPLO) -> Result<ArrayBase<S, Ix2>> {
1921
A::cholesky(self.square_layout()?, uplo, self.as_allocated_mut()?)?;
20-
Ok(self)
22+
Ok(self.into_triangular(uplo))
2123
}
2224
}
2325

2426
impl<'a, A, S> Cholesky<&'a mut ArrayBase<S, Ix2>> for &'a mut ArrayBase<S, Ix2>
25-
where A: LapackScalar,
27+
where A: LapackScalar + Zero,
2628
S: DataMut<Elem = A>
2729
{
2830
fn cholesky(mut self, uplo: UPLO) -> Result<&'a mut ArrayBase<S, Ix2>> {
2931
A::cholesky(self.square_layout()?, uplo, self.as_allocated_mut()?)?;
30-
Ok(self)
32+
Ok(self.into_triangular(uplo))
3133
}
3234
}
3335

3436
impl<'a, A, Si, So> Cholesky<ArrayBase<So, Ix2>> for &'a ArrayBase<Si, Ix2>
35-
where A: LapackScalar + Copy,
37+
where A: LapackScalar + Copy + Zero,
3638
Si: Data<Elem = A>,
3739
So: DataMut<Elem = A> + DataOwned
3840
{
3941
fn cholesky(self, uplo: UPLO) -> Result<ArrayBase<So, Ix2>> {
4042
let mut a = replicate(self);
4143
A::cholesky(a.square_layout()?, uplo, a.as_allocated_mut()?)?;
42-
Ok(a)
44+
Ok(a.into_triangular(uplo))
4345
}
4446
}

src/triangular.rs

Lines changed: 44 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
//! Define methods for triangular matrices
22
33
use ndarray::*;
4+
use num_traits::Zero;
5+
use super::impl2::UPLO;
6+
47
use super::matrix::{Matrix, MFloat};
58
use super::square::SquareMatrix;
69
use super::error::LinalgError;
@@ -87,7 +90,7 @@ impl<A: MFloat> SolveTriangular<RcArray<A, Ix2>> for RcArray<A, Ix2> {
8790
}
8891
}
8992

90-
pub fn drop_upper<A: NdFloat, S>(mut a: ArrayBase<S, Ix2>) -> ArrayBase<S, Ix2>
93+
pub fn drop_upper<A: Zero, S>(mut a: ArrayBase<S, Ix2>) -> ArrayBase<S, Ix2>
9194
where S: DataMut<Elem = A>
9295
{
9396
for ((i, j), val) in a.indexed_iter_mut() {
@@ -98,7 +101,7 @@ pub fn drop_upper<A: NdFloat, S>(mut a: ArrayBase<S, Ix2>) -> ArrayBase<S, Ix2>
98101
a
99102
}
100103

101-
pub fn drop_lower<A: NdFloat, S>(mut a: ArrayBase<S, Ix2>) -> ArrayBase<S, Ix2>
104+
pub fn drop_lower<A: Zero, S>(mut a: ArrayBase<S, Ix2>) -> ArrayBase<S, Ix2>
102105
where S: DataMut<Elem = A>
103106
{
104107
for ((i, j), val) in a.indexed_iter_mut() {
@@ -108,3 +111,42 @@ pub fn drop_lower<A: NdFloat, S>(mut a: ArrayBase<S, Ix2>) -> ArrayBase<S, Ix2>
108111
}
109112
a
110113
}
114+
115+
pub trait IntoTriangular<T> {
116+
fn into_triangular(self, UPLO) -> T;
117+
}
118+
119+
impl<'a, A, S> IntoTriangular<&'a mut ArrayBase<S, Ix2>> for &'a mut ArrayBase<S, Ix2>
120+
where A: Zero,
121+
S: DataMut<Elem = A>
122+
{
123+
fn into_triangular(self, uplo: UPLO) -> &'a mut ArrayBase<S, Ix2> {
124+
match uplo {
125+
UPLO::Upper => {
126+
for ((i, j), val) in self.indexed_iter_mut() {
127+
if i > j {
128+
*val = A::zero();
129+
}
130+
}
131+
}
132+
UPLO::Lower => {
133+
for ((i, j), val) in self.indexed_iter_mut() {
134+
if i < j {
135+
*val = A::zero();
136+
}
137+
}
138+
}
139+
}
140+
self
141+
}
142+
}
143+
144+
impl<A, S> IntoTriangular<ArrayBase<S, Ix2>> for ArrayBase<S, Ix2>
145+
where A: Zero,
146+
S: DataMut<Elem = A>
147+
{
148+
fn into_triangular(mut self, uplo: UPLO) -> ArrayBase<S, Ix2> {
149+
(&mut self).into_triangular(uplo);
150+
self
151+
}
152+
}

0 commit comments

Comments
 (0)