Skip to content

Commit ad171af

Browse files
committed
Merge branch 'master' into ci_windows
2 parents 549e6d6 + 03fe1ca commit ad171af

File tree

7 files changed

+426
-41
lines changed

7 files changed

+426
-41
lines changed

src/assert.rs

Lines changed: 42 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1,87 +1,87 @@
11
//! Assertions for array
22
33
use ndarray::*;
4+
use std::fmt::Debug;
45

56
use super::norm::*;
67
use super::types::*;
78

8-
/// check two values are close in terms of the relative torrence
9-
pub fn rclose<A: Scalar>(test: A, truth: A, rtol: A::Real) -> Result<A::Real, A::Real> {
9+
/// check two values are close in terms of the relative tolerance
10+
pub fn rclose<A: Scalar>(test: A, truth: A, rtol: A::Real) {
1011
let dev = (test - truth).abs() / truth.abs();
11-
if dev < rtol {
12-
Ok(dev)
13-
} else {
14-
Err(dev)
12+
if dev > rtol {
13+
eprintln!("==== Assetion Failed ====");
14+
eprintln!("Expected = {}", truth);
15+
eprintln!("Actual = {}", test);
16+
panic!("Too large deviation in relative tolerance: {}", dev);
1517
}
1618
}
1719

18-
/// check two values are close in terms of the absolute torrence
19-
pub fn aclose<A: Scalar>(test: A, truth: A, atol: A::Real) -> Result<A::Real, A::Real> {
20+
/// check two values are close in terms of the absolute tolerance
21+
pub fn aclose<A: Scalar>(test: A, truth: A, atol: A::Real) {
2022
let dev = (test - truth).abs();
21-
if dev < atol {
22-
Ok(dev)
23-
} else {
24-
Err(dev)
23+
if dev > atol {
24+
eprintln!("==== Assetion Failed ====");
25+
eprintln!("Expected = {}", truth);
26+
eprintln!("Actual = {}", test);
27+
panic!("Too large deviation in absolute tolerance: {}", dev);
2528
}
2629
}
2730

2831
/// check two arrays are close in maximum norm
29-
pub fn close_max<A, S1, S2, D>(
30-
test: &ArrayBase<S1, D>,
31-
truth: &ArrayBase<S2, D>,
32-
atol: A::Real,
33-
) -> Result<A::Real, A::Real>
32+
pub fn close_max<A, S1, S2, D>(test: &ArrayBase<S1, D>, truth: &ArrayBase<S2, D>, atol: A::Real)
3433
where
3534
A: Scalar + Lapack,
3635
S1: Data<Elem = A>,
3736
S2: Data<Elem = A>,
3837
D: Dimension,
38+
D::Pattern: PartialEq + Debug,
3939
{
40+
assert_eq!(test.dim(), truth.dim());
4041
let tol = (test - truth).norm_max();
41-
if tol < atol {
42-
Ok(tol)
43-
} else {
44-
Err(tol)
42+
if tol > atol {
43+
eprintln!("==== Assetion Failed ====");
44+
eprintln!("Expected:\n{}", truth);
45+
eprintln!("Actual:\n{}", test);
46+
panic!("Too large deviation in maximum norm: {} > {}", tol, atol);
4547
}
4648
}
4749

4850
/// check two arrays are close in L1 norm
49-
pub fn close_l1<A, S1, S2, D>(
50-
test: &ArrayBase<S1, D>,
51-
truth: &ArrayBase<S2, D>,
52-
rtol: A::Real,
53-
) -> Result<A::Real, A::Real>
51+
pub fn close_l1<A, S1, S2, D>(test: &ArrayBase<S1, D>, truth: &ArrayBase<S2, D>, rtol: A::Real)
5452
where
5553
A: Scalar + Lapack,
5654
S1: Data<Elem = A>,
5755
S2: Data<Elem = A>,
5856
D: Dimension,
57+
D::Pattern: PartialEq + Debug,
5958
{
59+
assert_eq!(test.dim(), truth.dim());
6060
let tol = (test - truth).norm_l1() / truth.norm_l1();
61-
if tol < rtol {
62-
Ok(tol)
63-
} else {
64-
Err(tol)
61+
if tol > rtol {
62+
eprintln!("==== Assetion Failed ====");
63+
eprintln!("Expected:\n{}", truth);
64+
eprintln!("Actual:\n{}", test);
65+
panic!("Too large deviation in L1-norm: {} > {}", tol, rtol);
6566
}
6667
}
6768

6869
/// check two arrays are close in L2 norm
69-
pub fn close_l2<A, S1, S2, D>(
70-
test: &ArrayBase<S1, D>,
71-
truth: &ArrayBase<S2, D>,
72-
rtol: A::Real,
73-
) -> Result<A::Real, A::Real>
70+
pub fn close_l2<A, S1, S2, D>(test: &ArrayBase<S1, D>, truth: &ArrayBase<S2, D>, rtol: A::Real)
7471
where
7572
A: Scalar + Lapack,
7673
S1: Data<Elem = A>,
7774
S2: Data<Elem = A>,
7875
D: Dimension,
76+
D::Pattern: PartialEq + Debug,
7977
{
78+
assert_eq!(test.dim(), truth.dim());
8079
let tol = (test - truth).norm_l2() / truth.norm_l2();
81-
if tol < rtol {
82-
Ok(tol)
83-
} else {
84-
Err(tol)
80+
if tol > rtol {
81+
eprintln!("==== Assetion Failed ====");
82+
eprintln!("Expected:\n{}", truth);
83+
eprintln!("Actual:\n{}", test);
84+
panic!("Too large deviation in L2-norm: {} > {} ", tol, rtol);
8585
}
8686
}
8787

@@ -90,10 +90,11 @@ macro_rules! generate_assert {
9090
#[macro_export]
9191
macro_rules! $assert {
9292
($test: expr,$truth: expr,$tol: expr) => {
93-
$crate::$close($test, $truth, $tol).unwrap();
93+
$crate::$close($test, $truth, $tol);
9494
};
9595
($test: expr,$truth: expr,$tol: expr; $comment: expr) => {
96-
$crate::$close($test, $truth, $tol).expect($comment);
96+
eprintln!($comment);
97+
$crate::$close($test, $truth, $tol);
9798
};
9899
}
99100
};

src/inner.rs

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
use crate::types::*;
2+
use ndarray::*;
3+
4+
/// Inner Product
5+
///
6+
/// Differenct from `Dot` trait, this take complex conjugate of `self` elements
7+
///
8+
pub trait InnerProduct {
9+
type Elem: Scalar;
10+
11+
/// Inner product `(self.conjugate, rhs)
12+
fn inner<S>(&self, rhs: &ArrayBase<S, Ix1>) -> Self::Elem
13+
where
14+
S: Data<Elem = Self::Elem>;
15+
}
16+
17+
impl<A, S> InnerProduct for ArrayBase<S, Ix1>
18+
where
19+
A: Scalar,
20+
S: Data<Elem = A>,
21+
{
22+
type Elem = A;
23+
fn inner<St: Data<Elem = A>>(&self, rhs: &ArrayBase<St, Ix1>) -> A {
24+
assert_eq!(self.len(), rhs.len());
25+
Zip::from(self)
26+
.and(rhs)
27+
.fold_while(A::zero(), |acc, s, r| FoldWhile::Continue(acc + s.conj() * *r))
28+
.into_inner()
29+
}
30+
}

src/krylov/mgs.rs

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
//! Modified Gram-Schmit orthogonalizer
2+
3+
use super::*;
4+
use crate::{generate::*, inner::*, norm::Norm};
5+
6+
/// Iterative orthogonalizer using modified Gram-Schmit procedure
7+
///
8+
/// ```rust
9+
/// # use ndarray::*;
10+
/// # use ndarray_linalg::{krylov::*, *};
11+
/// let mut mgs = MGS::new(3);
12+
/// let coef = mgs.append(array![0.0, 1.0, 0.0], 1e-9).unwrap();
13+
/// close_l2(&coef, &array![1.0], 1e-9);
14+
///
15+
/// let coef = mgs.append(array![1.0, 1.0, 0.0], 1e-9).unwrap();
16+
/// close_l2(&coef, &array![1.0, 1.0], 1e-9);
17+
///
18+
/// // Fail if the vector is linearly dependent
19+
/// assert!(mgs.append(array![1.0, 2.0, 0.0], 1e-9).is_err());
20+
///
21+
/// // You can get coefficients of dependent vector
22+
/// if let Err(coef) = mgs.append(array![1.0, 2.0, 0.0], 1e-9) {
23+
/// close_l2(&coef, &array![2.0, 1.0, 0.0], 1e-9);
24+
/// }
25+
/// ```
26+
#[derive(Debug, Clone)]
27+
pub struct MGS<A> {
28+
/// Dimension of base space
29+
dimension: usize,
30+
/// Basis of spanned space
31+
q: Vec<Array1<A>>,
32+
}
33+
34+
impl<A: Scalar> MGS<A> {
35+
/// Create an empty orthogonalizer
36+
pub fn new(dimension: usize) -> Self {
37+
Self {
38+
dimension,
39+
q: Vec::new(),
40+
}
41+
}
42+
}
43+
44+
impl<A: Scalar + Lapack> Orthogonalizer for MGS<A> {
45+
type Elem = A;
46+
47+
fn dim(&self) -> usize {
48+
self.dimension
49+
}
50+
51+
fn len(&self) -> usize {
52+
self.q.len()
53+
}
54+
55+
fn orthogonalize<S>(&self, a: &mut ArrayBase<S, Ix1>) -> Array1<A>
56+
where
57+
A: Lapack,
58+
S: DataMut<Elem = A>,
59+
{
60+
assert_eq!(a.len(), self.dim());
61+
let mut coef = Array1::zeros(self.len() + 1);
62+
for i in 0..self.len() {
63+
let q = &self.q[i];
64+
let c = q.inner(&a);
65+
azip!(mut a (&mut *a), q (q) in { *a = *a - c * q } );
66+
coef[i] = c;
67+
}
68+
let nrm = a.norm_l2();
69+
coef[self.len()] = A::from_real(nrm);
70+
coef
71+
}
72+
73+
fn append<S>(&mut self, a: ArrayBase<S, Ix1>, rtol: A::Real) -> Result<Array1<A>, Array1<A>>
74+
where
75+
A: Lapack,
76+
S: Data<Elem = A>,
77+
{
78+
let mut a = a.into_owned();
79+
let coef = self.orthogonalize(&mut a);
80+
let nrm = coef[coef.len() - 1].re();
81+
if nrm < rtol {
82+
// Linearly dependent
83+
return Err(coef);
84+
}
85+
azip!(mut a in { *a = *a / A::from_real(nrm) });
86+
self.q.push(a);
87+
Ok(coef)
88+
}
89+
90+
fn get_q(&self) -> Q<A> {
91+
hstack(&self.q).unwrap()
92+
}
93+
}
94+
95+
/// Online QR decomposition of vectors using modified Gram-Schmit algorithm
96+
pub fn mgs<A, S>(
97+
iter: impl Iterator<Item = ArrayBase<S, Ix1>>,
98+
dim: usize,
99+
rtol: A::Real,
100+
strategy: Strategy,
101+
) -> (Q<A>, R<A>)
102+
where
103+
A: Scalar + Lapack,
104+
S: Data<Elem = A>,
105+
{
106+
let mgs = MGS::new(dim);
107+
qr(iter, mgs, rtol, strategy)
108+
}

0 commit comments

Comments
 (0)