Skip to content

Commit 85f628a

Browse files
committed
sync
1 parent 53cb90f commit 85f628a

File tree

2 files changed

+149
-0
lines changed

2 files changed

+149
-0
lines changed

src/lib.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,3 +51,5 @@ pub mod eigh;
5151
pub mod norm;
5252
pub mod solve;
5353
pub mod cholesky;
54+
55+
pub mod topology;

src/topology.rs

Lines changed: 147 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,147 @@
1+
//! module for topologcal vector space
2+
//!
3+
4+
use std::iter::Sum;
5+
use ndarray::{ArrayBase, Data, Dimension, LinalgScalar, IntoDimension};
6+
use num_traits::Float;
7+
8+
pub fn inf<A, Distance, S1, S2, D>(a: &ArrayBase<S1, D>, b: &ArrayBase<S2, D>) -> Result<Distance, String>
9+
where A: LinalgScalar + Squared<Output = Distance>,
10+
Distance: Float + Sum,
11+
S1: Data<Elem = A>,
12+
S2: Data<Elem = A>,
13+
D: Dimension
14+
{
15+
if a.shape() != b.shape() {
16+
return Err("Shapes are different".into());
17+
}
18+
let mut max_tol = Distance::zero();
19+
for (idx, val) in a.indexed_iter() {
20+
let t = b[idx.into_dimension()];
21+
let tol = (*val - t).sq_abs();
22+
if tol > max_tol {
23+
max_tol = tol;
24+
}
25+
}
26+
Ok(max_tol)
27+
}
28+
29+
pub fn l1<A, Distance, S1, S2, D>(a: &ArrayBase<S1, D>, b: &ArrayBase<S2, D>) -> Result<Distance, String>
30+
where A: LinalgScalar + Squared<Output = Distance>,
31+
Distance: Float + Sum,
32+
S1: Data<Elem = A>,
33+
S2: Data<Elem = A>,
34+
D: Dimension
35+
{
36+
if a.shape() != b.shape() {
37+
return Err("Shapes are different".into());
38+
}
39+
Ok(a.indexed_iter().map(|(idx, val)| (b[idx.into_dimension()] - *val).sq_abs()).sum())
40+
}
41+
42+
pub fn l2<A, Distance, S1, S2, D>(a: &ArrayBase<S1, D>, b: &ArrayBase<S2, D>) -> Result<Distance, String>
43+
where A: LinalgScalar + Squared<Output = Distance>,
44+
Distance: Float + Sum,
45+
S1: Data<Elem = A>,
46+
S2: Data<Elem = A>,
47+
D: Dimension
48+
{
49+
if a.shape() != b.shape() {
50+
return Err("Shapes are different".into());
51+
}
52+
Ok(a.indexed_iter().map(|(idx, val)| (b[idx.into_dimension()] - *val).squared()).sum::<Distance>().sqrt())
53+
}
54+
55+
#[derive(Debug)]
56+
pub enum NotCloseError<Tol> {
57+
ShapeMismatch(String),
58+
LargeDeviation(Tol),
59+
}
60+
61+
pub fn all_close_inf<A, Tol, S1, S2, D>(a: &ArrayBase<S1, D>,
62+
b: &ArrayBase<S2, D>,
63+
atol: Tol)
64+
-> Result<Tol, NotCloseError<Tol>>
65+
where A: LinalgScalar + Squared<Output = Tol>,
66+
Tol: Float + Sum,
67+
S1: Data<Elem = A>,
68+
S2: Data<Elem = A>,
69+
D: Dimension
70+
{
71+
if a.shape() != b.shape() {
72+
return Err(NotCloseError::ShapeMismatch("Shapes are different".into()));
73+
}
74+
let mut max_tol = Tol::zero();
75+
for (idx, val) in a.indexed_iter() {
76+
let t = b[idx.into_dimension()];
77+
let tol = (*val - t).sq_abs();
78+
if tol > atol {
79+
return Err(NotCloseError::LargeDeviation(tol));
80+
}
81+
if tol > max_tol {
82+
max_tol = tol;
83+
}
84+
}
85+
Ok(max_tol)
86+
}
87+
88+
pub fn all_close_l1<A, Tol, S1, S2, D>(a: &ArrayBase<S1, D>,
89+
b: &ArrayBase<S2, D>,
90+
rtol: Tol)
91+
-> Result<Tol, NotCloseError<Tol>>
92+
where A: LinalgScalar + Squared<Output = Tol>,
93+
Tol: Float + Sum,
94+
S1: Data<Elem = A>,
95+
S2: Data<Elem = A>,
96+
D: Dimension
97+
{
98+
if a.shape() != b.shape() {
99+
return Err(NotCloseError::ShapeMismatch("Shapes are different".into()));
100+
}
101+
let nrm: Tol = b.iter().map(|x| x.sq_abs()).sum();
102+
let dev: Tol = a.indexed_iter().map(|(idx, val)| (b[idx.into_dimension()] - *val).sq_abs()).sum();
103+
if dev / nrm > rtol {
104+
Err(NotCloseError::LargeDeviation(dev / nrm))
105+
} else {
106+
Ok(dev / nrm)
107+
}
108+
}
109+
110+
pub fn all_close_l2<A, Tol, S1, S2, D>(a: &ArrayBase<S1, D>,
111+
b: &ArrayBase<S2, D>,
112+
rtol: Tol)
113+
-> Result<Tol, NotCloseError<Tol>>
114+
where A: LinalgScalar + Squared<Output = Tol>,
115+
Tol: Float + Sum,
116+
S1: Data<Elem = A>,
117+
S2: Data<Elem = A>,
118+
D: Dimension
119+
{
120+
if a.shape() != b.shape() {
121+
return Err(NotCloseError::ShapeMismatch("Shapes are different".into()));
122+
}
123+
let nrm: Tol = b.iter().map(|x| x.squared()).sum();
124+
let dev: Tol = a.indexed_iter().map(|(idx, val)| (b[idx.into_dimension()] - *val).squared()).sum();
125+
let d = (dev / nrm).sqrt();
126+
if d > rtol {
127+
Err(NotCloseError::LargeDeviation(d))
128+
} else {
129+
Ok(d)
130+
}
131+
}
132+
133+
pub trait Squared {
134+
type Output;
135+
fn squared(&self) -> Self::Output;
136+
fn sq_abs(&self) -> Self::Output;
137+
}
138+
139+
impl<A: Float> Squared for A {
140+
type Output = A;
141+
fn squared(&self) -> A {
142+
*self * *self
143+
}
144+
fn sq_abs(&self) -> A {
145+
self.abs()
146+
}
147+
}

0 commit comments

Comments
 (0)