Skip to content

Commit eb2b0c8

Browse files
committed
Reimplement all_close_{l1,l2,max} using Vector trait
1 parent 6ac551c commit eb2b0c8

File tree

1 file changed

+13
-117
lines changed

1 file changed

+13
-117
lines changed

src/topology.rs

Lines changed: 13 additions & 117 deletions
Original file line number · Diff line number · Diff line change
@@ -2,146 +2,42 @@
22
//!
33
44
use std::iter::Sum;
5-
use ndarray::{ArrayBase, Data, Dimension, LinalgScalar, IntoDimension};
5+
use ndarray::{ArrayBase, Data, Dimension, LinalgScalar};
66
use num_traits::Float;
7+
use super::vector::*;
78

8-
pub fn inf<A, Distance, S1, S2, D>(a: &ArrayBase<S1, D>, b: &ArrayBase<S2, D>) -> Result<Distance, String>
9-
where A: LinalgScalar + Squared<Output = Distance>,
10-
Distance: Float + Sum,
11-
S1: Data<Elem = A>,
12-
S2: Data<Elem = A>,
13-
D: Dimension
14-
{
15-
if a.shape() != b.shape() {
16-
return Err("Shapes are different".into());
17-
}
18-
let mut max_tol = Distance::zero();
19-
for (idx, val) in a.indexed_iter() {
20-
let t = b[idx.into_dimension()];
21-
let tol = (*val - t).sq_abs();
22-
if tol > max_tol {
23-
max_tol = tol;
24-
}
25-
}
26-
Ok(max_tol)
27-
}
28-
29-
pub fn l1<A, Distance, S1, S2, D>(a: &ArrayBase<S1, D>, b: &ArrayBase<S2, D>) -> Result<Distance, String>
30-
where A: LinalgScalar + Squared<Output = Distance>,
31-
Distance: Float + Sum,
32-
S1: Data<Elem = A>,
33-
S2: Data<Elem = A>,
34-
D: Dimension
35-
{
36-
if a.shape() != b.shape() {
37-
return Err("Shapes are different".into());
38-
}
39-
Ok(a.indexed_iter().map(|(idx, val)| (b[idx.into_dimension()] - *val).sq_abs()).sum())
40-
}
41-
42-
pub fn l2<A, Distance, S1, S2, D>(a: &ArrayBase<S1, D>, b: &ArrayBase<S2, D>) -> Result<Distance, String>
43-
where A: LinalgScalar + Squared<Output = Distance>,
44-
Distance: Float + Sum,
45-
S1: Data<Elem = A>,
46-
S2: Data<Elem = A>,
47-
D: Dimension
48-
{
49-
if a.shape() != b.shape() {
50-
return Err("Shapes are different".into());
51-
}
52-
Ok(a.indexed_iter().map(|(idx, val)| (b[idx.into_dimension()] - *val).squared()).sum::<Distance>().sqrt())
53-
}
54-
55-
#[derive(Debug)]
56-
pub enum NotCloseError<Tol> {
57-
ShapeMismatch(String),
58-
LargeDeviation(Tol),
59-
}
60-
61-
pub fn all_close_inf<A, Tol, S1, S2, D>(a: &ArrayBase<S1, D>,
62-
b: &ArrayBase<S2, D>,
9+
pub fn all_close_max<A, Tol, S1, S2, D>(test: &ArrayBase<S1, D>,
10+
truth: &ArrayBase<S2, D>,
6311
atol: Tol)
64-
-> Result<Tol, NotCloseError<Tol>>
12+
-> Result<Tol, Tol>
6513
where A: LinalgScalar + Squared<Output = Tol>,
6614
Tol: Float + Sum,
6715
S1: Data<Elem = A>,
6816
S2: Data<Elem = A>,
6917
D: Dimension
7018
{
71-
if a.shape() != b.shape() {
72-
return Err(NotCloseError::ShapeMismatch("Shapes are different".into()));
73-
}
74-
let mut max_tol = Tol::zero();
75-
for (idx, val) in a.indexed_iter() {
76-
let t = b[idx.into_dimension()];
77-
let tol = (*val - t).sq_abs();
78-
if tol > atol {
79-
return Err(NotCloseError::LargeDeviation(tol));
80-
}
81-
if tol > max_tol {
82-
max_tol = tol;
83-
}
84-
}
85-
Ok(max_tol)
19+
let tol = (test - truth).norm_max();
20+
if tol < atol { Ok(tol) } else { Err(tol) }
8621
}
8722

88-
pub fn all_close_l1<A, Tol, S1, S2, D>(a: &ArrayBase<S1, D>,
89-
b: &ArrayBase<S2, D>,
90-
rtol: Tol)
91-
-> Result<Tol, NotCloseError<Tol>>
23+
pub fn all_close_l1<A, Tol, S1, S2, D>(test: &ArrayBase<S1, D>, truth: &ArrayBase<S2, D>, rtol: Tol) -> Result<Tol, Tol>
9224
where A: LinalgScalar + Squared<Output = Tol>,
9325
Tol: Float + Sum,
9426
S1: Data<Elem = A>,
9527
S2: Data<Elem = A>,
9628
D: Dimension
9729
{
98-
if a.shape() != b.shape() {
99-
return Err(NotCloseError::ShapeMismatch("Shapes are different".into()));
100-
}
101-
let nrm: Tol = b.iter().map(|x| x.sq_abs()).sum();
102-
let dev: Tol = a.indexed_iter().map(|(idx, val)| (b[idx.into_dimension()] - *val).sq_abs()).sum();
103-
if dev / nrm > rtol {
104-
Err(NotCloseError::LargeDeviation(dev / nrm))
105-
} else {
106-
Ok(dev / nrm)
107-
}
30+
let tol = (test - truth).norm_l1() / truth.norm_l1();
31+
if tol < rtol { Ok(tol) } else { Err(tol) }
10832
}
10933

110-
pub fn all_close_l2<A, Tol, S1, S2, D>(a: &ArrayBase<S1, D>,
111-
b: &ArrayBase<S2, D>,
112-
rtol: Tol)
113-
-> Result<Tol, NotCloseError<Tol>>
34+
pub fn all_close_l2<A, Tol, S1, S2, D>(test: &ArrayBase<S1, D>, truth: &ArrayBase<S2, D>, rtol: Tol) -> Result<Tol, Tol>
11435
where A: LinalgScalar + Squared<Output = Tol>,
11536
Tol: Float + Sum,
11637
S1: Data<Elem = A>,
11738
S2: Data<Elem = A>,
11839
D: Dimension
11940
{
120-
if a.shape() != b.shape() {
121-
return Err(NotCloseError::ShapeMismatch("Shapes are different".into()));
122-
}
123-
let nrm: Tol = b.iter().map(|x| x.squared()).sum();
124-
let dev: Tol = a.indexed_iter().map(|(idx, val)| (b[idx.into_dimension()] - *val).squared()).sum();
125-
let d = (dev / nrm).sqrt();
126-
if d > rtol {
127-
Err(NotCloseError::LargeDeviation(d))
128-
} else {
129-
Ok(d)
130-
}
131-
}
132-
133-
pub trait Squared {
134-
type Output;
135-
fn squared(&self) -> Self::Output;
136-
fn sq_abs(&self) -> Self::Output;
137-
}
138-
139-
impl<A: Float> Squared for A {
140-
type Output = A;
141-
fn squared(&self) -> A {
142-
*self * *self
143-
}
144-
fn sq_abs(&self) -> A {
145-
self.abs()
146-
}
41+
let tol = (test - truth).norm_l2() / truth.norm_l2();
42+
if tol < rtol { Ok(tol) } else { Err(tol) }
14743
}

0 commit comments

Comments (0)