Skip to content

Commit d14e44b

Browse files
committed
Merge normalize_{rows,columns}
1 parent 5568b80 commit d14e44b

File tree

2 files changed

+15
-29
lines changed

2 files changed

+15
-29
lines changed

src/util.rs

Lines changed: 13 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -46,38 +46,24 @@ pub fn vstack<A, S>(xs: &[ArrayBase<S, Ix1>]) -> Result<Array<A, Ix2>, ShapeErro
4646
stack(Axis(0), &views)
4747
}
4848

49-
/// normalize columns in L2 norm
50-
pub fn normalize_columns<A, S, T>(m: &ArrayBase<S, Ix2>) -> (Array2<A>, Vec<T>)
51-
where S: Data<Elem = A>,
52-
A: LinalgScalar + NormedField<Output = T> + Div<T, Output = A>,
53-
T: Float + Sum
54-
{
55-
let mut ms = Vec::new();
56-
let vs = m.axis_iter(Axis(1))
57-
.map(|v| {
58-
let n = v.norm();
59-
ms.push(n);
60-
v.mapv(|x| x / n)
61-
})
62-
.collect::<Vec<_>>();
63-
(hstack(&vs).unwrap(), ms)
49+
pub enum NormalizeAxis {
50+
Row = 0,
51+
Column = 1,
6452
}
6553

66-
/// normalize rows in L2 norm
67-
pub fn normalize_rows<A, S, T>(m: &ArrayBase<S, Ix2>) -> (Vec<T>, Array2<A>)
68-
where S: Data<Elem = A>,
69-
A: LinalgScalar + NormedField<Output = T> + Div<T, Output = A>,
54+
/// normalize in L2 norm
55+
pub fn normalize<A, S, T>(mut m: ArrayBase<S, Ix2>, axis: NormalizeAxis) -> (ArrayBase<S, Ix2>, Vec<T>)
56+
where A: LinalgScalar + NormedField<Output = T> + Div<T, Output = A>,
57+
S: DataMut<Elem = A>,
7058
T: Float + Sum
7159
{
7260
let mut ms = Vec::new();
73-
let vs = m.axis_iter(Axis(0))
74-
.map(|v| {
75-
let n = v.norm();
76-
ms.push(n);
77-
v.mapv(|x| x / n)
78-
})
79-
.collect::<Vec<_>>();
80-
(ms, vstack(&vs).unwrap())
61+
for mut v in m.axis_iter_mut(Axis(axis as usize)) {
62+
let n = v.norm();
63+
ms.push(n);
64+
v.map_inplace(|x| *x = *x / n)
65+
}
66+
(m, ms)
8167
}
8268

8369
/// check two arrays are close in maximum norm

tests/normalize.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,13 @@ include!("header.rs");
33
#[test]
44
fn n_columns() {
55
let a = random_owned(3, 2, true);
6-
let (n, v) = normalize_columns(&a);
6+
let (n, v) = normalize(a.clone(), NormalizeAxis::Column);
77
all_close_l2(&n.dot(&from_diag(&v)), &a, 1e-7).unwrap();
88
}
99

1010
#[test]
1111
fn n_rows() {
1212
let a = random_owned(3, 2, true);
13-
let (v, n) = normalize_rows(&a);
13+
let (n, v) = normalize(a.clone(), NormalizeAxis::Row);
1414
all_close_l2(&from_diag(&v).dot(&n), &a, 1e-7).unwrap();
1515
}

0 commit comments

Comments
 (0)