Skip to content

Commit b09146d

Browse files
committed
Implement normalize_{rows,columns}
1 parent bc0ed7a commit b09146d

File tree

1 file changed

+37
-2
lines changed

1 file changed

+37
-2
lines changed

src/util.rs

Lines changed: 37 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,11 @@ use std::iter::Sum;
44
use ndarray::*;
55
use num_traits::Float;
66
use super::vector::*;
7+
use std::ops::Div;
78

89
/// stack vectors into matrix horizontally
910
pub fn hstack<A, S>(xs: &[ArrayBase<S, Ix1>]) -> Result<Array<A, Ix2>, ShapeError>
10-
where A: NdFloat,
11+
where A: LinalgScalar,
1112
S: Data<Elem = A>
1213
{
1314
let views: Vec<_> = xs.iter()
@@ -21,7 +22,7 @@ pub fn hstack<A, S>(xs: &[ArrayBase<S, Ix1>]) -> Result<Array<A, Ix2>, ShapeErro
2122

2223
/// stack vectors into matrix vertically
2324
pub fn vstack<A, S>(xs: &[ArrayBase<S, Ix1>]) -> Result<Array<A, Ix2>, ShapeError>
24-
where A: NdFloat,
25+
where A: LinalgScalar,
2526
S: Data<Elem = A>
2627
{
2728
let views: Vec<_> = xs.iter()
@@ -33,6 +34,40 @@ pub fn vstack<A, S>(xs: &[ArrayBase<S, Ix1>]) -> Result<Array<A, Ix2>, ShapeErro
3334
stack(Axis(0), &views)
3435
}
3536

37+
/// normalize columns in L2 norm
38+
pub fn normalize_columns<A, S, T>(m: &ArrayBase<S, Ix2>) -> (Array2<A>, Vec<T>)
39+
where S: Data<Elem = A>,
40+
A: LinalgScalar + NormedField<Output = T> + Div<T, Output = A>,
41+
T: Float + Sum
42+
{
43+
let mut ms = Vec::new();
44+
let vs = m.axis_iter(Axis(1))
45+
.map(|v| {
46+
let n = v.norm();
47+
ms.push(n);
48+
v.mapv(|x| x / n)
49+
})
50+
.collect::<Vec<_>>();
51+
(hstack(&vs).unwrap(), ms)
52+
}
53+
54+
/// normalize rows in L2 norm
55+
pub fn normalize_rows<A, S, T>(m: &ArrayBase<S, Ix2>) -> (Vec<T>, Array2<A>)
56+
where S: Data<Elem = A>,
57+
A: LinalgScalar + NormedField<Output = T> + Div<T, Output = A>,
58+
T: Float + Sum
59+
{
60+
let mut ms = Vec::new();
61+
let vs = m.axis_iter(Axis(0))
62+
.map(|v| {
63+
let n = v.norm();
64+
ms.push(n);
65+
v.mapv(|x| x / n)
66+
})
67+
.collect::<Vec<_>>();
68+
(ms, vstack(&vs).unwrap())
69+
}
70+
3671
/// check two arrays are close in maximum norm
3772
pub fn all_close_max<A, Tol, S1, S2, D>(test: &ArrayBase<S1, D>,
3873
truth: &ArrayBase<S2, D>,

0 commit comments

Comments
 (0)