|
2 | 2 | //!
|
3 | 3 |
|
4 | 4 | use std::iter::Sum;
|
5 |
| -use ndarray::{ArrayBase, Data, Dimension, LinalgScalar, IntoDimension}; |
| 5 | +use ndarray::{ArrayBase, Data, Dimension, LinalgScalar}; |
6 | 6 | use num_traits::Float;
|
| 7 | +use super::vector::*; |
7 | 8 |
|
8 |
| -pub fn inf<A, Distance, S1, S2, D>(a: &ArrayBase<S1, D>, b: &ArrayBase<S2, D>) -> Result<Distance, String> |
9 |
| - where A: LinalgScalar + Squared<Output = Distance>, |
10 |
| - Distance: Float + Sum, |
11 |
| - S1: Data<Elem = A>, |
12 |
| - S2: Data<Elem = A>, |
13 |
| - D: Dimension |
14 |
| -{ |
15 |
| - if a.shape() != b.shape() { |
16 |
| - return Err("Shapes are different".into()); |
17 |
| - } |
18 |
| - let mut max_tol = Distance::zero(); |
19 |
| - for (idx, val) in a.indexed_iter() { |
20 |
| - let t = b[idx.into_dimension()]; |
21 |
| - let tol = (*val - t).sq_abs(); |
22 |
| - if tol > max_tol { |
23 |
| - max_tol = tol; |
24 |
| - } |
25 |
| - } |
26 |
| - Ok(max_tol) |
27 |
| -} |
28 |
| - |
29 |
| -pub fn l1<A, Distance, S1, S2, D>(a: &ArrayBase<S1, D>, b: &ArrayBase<S2, D>) -> Result<Distance, String> |
30 |
| - where A: LinalgScalar + Squared<Output = Distance>, |
31 |
| - Distance: Float + Sum, |
32 |
| - S1: Data<Elem = A>, |
33 |
| - S2: Data<Elem = A>, |
34 |
| - D: Dimension |
35 |
| -{ |
36 |
| - if a.shape() != b.shape() { |
37 |
| - return Err("Shapes are different".into()); |
38 |
| - } |
39 |
| - Ok(a.indexed_iter().map(|(idx, val)| (b[idx.into_dimension()] - *val).sq_abs()).sum()) |
40 |
| -} |
41 |
| - |
42 |
| -pub fn l2<A, Distance, S1, S2, D>(a: &ArrayBase<S1, D>, b: &ArrayBase<S2, D>) -> Result<Distance, String> |
43 |
| - where A: LinalgScalar + Squared<Output = Distance>, |
44 |
| - Distance: Float + Sum, |
45 |
| - S1: Data<Elem = A>, |
46 |
| - S2: Data<Elem = A>, |
47 |
| - D: Dimension |
48 |
| -{ |
49 |
| - if a.shape() != b.shape() { |
50 |
| - return Err("Shapes are different".into()); |
51 |
| - } |
52 |
| - Ok(a.indexed_iter().map(|(idx, val)| (b[idx.into_dimension()] - *val).squared()).sum::<Distance>().sqrt()) |
53 |
| -} |
54 |
| - |
55 |
| -#[derive(Debug)] |
56 |
| -pub enum NotCloseError<Tol> { |
57 |
| - ShapeMismatch(String), |
58 |
| - LargeDeviation(Tol), |
59 |
| -} |
60 |
| - |
61 |
| -pub fn all_close_inf<A, Tol, S1, S2, D>(a: &ArrayBase<S1, D>, |
62 |
| - b: &ArrayBase<S2, D>, |
| 9 | +pub fn all_close_max<A, Tol, S1, S2, D>(test: &ArrayBase<S1, D>, |
| 10 | + truth: &ArrayBase<S2, D>, |
63 | 11 | atol: Tol)
|
64 |
| - -> Result<Tol, NotCloseError<Tol>> |
| 12 | + -> Result<Tol, Tol> |
65 | 13 | where A: LinalgScalar + Squared<Output = Tol>,
|
66 | 14 | Tol: Float + Sum,
|
67 | 15 | S1: Data<Elem = A>,
|
68 | 16 | S2: Data<Elem = A>,
|
69 | 17 | D: Dimension
|
70 | 18 | {
|
71 |
| - if a.shape() != b.shape() { |
72 |
| - return Err(NotCloseError::ShapeMismatch("Shapes are different".into())); |
73 |
| - } |
74 |
| - let mut max_tol = Tol::zero(); |
75 |
| - for (idx, val) in a.indexed_iter() { |
76 |
| - let t = b[idx.into_dimension()]; |
77 |
| - let tol = (*val - t).sq_abs(); |
78 |
| - if tol > atol { |
79 |
| - return Err(NotCloseError::LargeDeviation(tol)); |
80 |
| - } |
81 |
| - if tol > max_tol { |
82 |
| - max_tol = tol; |
83 |
| - } |
84 |
| - } |
85 |
| - Ok(max_tol) |
| 19 | + let tol = (test - truth).norm_max(); |
| 20 | + if tol < atol { Ok(tol) } else { Err(tol) } |
86 | 21 | }
|
87 | 22 |
|
88 |
| -pub fn all_close_l1<A, Tol, S1, S2, D>(a: &ArrayBase<S1, D>, |
89 |
| - b: &ArrayBase<S2, D>, |
90 |
| - rtol: Tol) |
91 |
| - -> Result<Tol, NotCloseError<Tol>> |
| 23 | +pub fn all_close_l1<A, Tol, S1, S2, D>(test: &ArrayBase<S1, D>, truth: &ArrayBase<S2, D>, rtol: Tol) -> Result<Tol, Tol> |
92 | 24 | where A: LinalgScalar + Squared<Output = Tol>,
|
93 | 25 | Tol: Float + Sum,
|
94 | 26 | S1: Data<Elem = A>,
|
95 | 27 | S2: Data<Elem = A>,
|
96 | 28 | D: Dimension
|
97 | 29 | {
|
98 |
| - if a.shape() != b.shape() { |
99 |
| - return Err(NotCloseError::ShapeMismatch("Shapes are different".into())); |
100 |
| - } |
101 |
| - let nrm: Tol = b.iter().map(|x| x.sq_abs()).sum(); |
102 |
| - let dev: Tol = a.indexed_iter().map(|(idx, val)| (b[idx.into_dimension()] - *val).sq_abs()).sum(); |
103 |
| - if dev / nrm > rtol { |
104 |
| - Err(NotCloseError::LargeDeviation(dev / nrm)) |
105 |
| - } else { |
106 |
| - Ok(dev / nrm) |
107 |
| - } |
| 30 | + let tol = (test - truth).norm_l1() / truth.norm_l1(); |
| 31 | + if tol < rtol { Ok(tol) } else { Err(tol) } |
108 | 32 | }
|
109 | 33 |
|
110 |
| -pub fn all_close_l2<A, Tol, S1, S2, D>(a: &ArrayBase<S1, D>, |
111 |
| - b: &ArrayBase<S2, D>, |
112 |
| - rtol: Tol) |
113 |
| - -> Result<Tol, NotCloseError<Tol>> |
| 34 | +pub fn all_close_l2<A, Tol, S1, S2, D>(test: &ArrayBase<S1, D>, truth: &ArrayBase<S2, D>, rtol: Tol) -> Result<Tol, Tol> |
114 | 35 | where A: LinalgScalar + Squared<Output = Tol>,
|
115 | 36 | Tol: Float + Sum,
|
116 | 37 | S1: Data<Elem = A>,
|
117 | 38 | S2: Data<Elem = A>,
|
118 | 39 | D: Dimension
|
119 | 40 | {
|
120 |
| - if a.shape() != b.shape() { |
121 |
| - return Err(NotCloseError::ShapeMismatch("Shapes are different".into())); |
122 |
| - } |
123 |
| - let nrm: Tol = b.iter().map(|x| x.squared()).sum(); |
124 |
| - let dev: Tol = a.indexed_iter().map(|(idx, val)| (b[idx.into_dimension()] - *val).squared()).sum(); |
125 |
| - let d = (dev / nrm).sqrt(); |
126 |
| - if d > rtol { |
127 |
| - Err(NotCloseError::LargeDeviation(d)) |
128 |
| - } else { |
129 |
| - Ok(d) |
130 |
| - } |
131 |
| -} |
132 |
| - |
133 |
| -pub trait Squared { |
134 |
| - type Output; |
135 |
| - fn squared(&self) -> Self::Output; |
136 |
| - fn sq_abs(&self) -> Self::Output; |
137 |
| -} |
138 |
| - |
139 |
| -impl<A: Float> Squared for A { |
140 |
| - type Output = A; |
141 |
| - fn squared(&self) -> A { |
142 |
| - *self * *self |
143 |
| - } |
144 |
| - fn sq_abs(&self) -> A { |
145 |
| - self.abs() |
146 |
| - } |
| 41 | + let tol = (test - truth).norm_l2() / truth.norm_l2(); |
| 42 | + if tol < rtol { Ok(tol) } else { Err(tol) } |
147 | 43 | }
|
0 commit comments