Skip to content

Commit c9085b2

Browse files
committed
Merge branch 'normalize'
2 parents bc0ed7a + d14e44b commit c9085b2

File tree

2 files changed

+50
-2
lines changed

2 files changed

+50
-2
lines changed

src/util.rs

Lines changed: 35 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,23 @@ use std::iter::Sum;
44
use ndarray::*;
55
use num_traits::Float;
66
use super::vector::*;
7+
use std::ops::Div;
8+
9+
/// construct matrix from diag
10+
pub fn from_diag<A>(d: &[A]) -> Array2<A>
11+
where A: LinalgScalar
12+
{
13+
let n = d.len();
14+
let mut e = Array::zeros((n, n));
15+
for i in 0..n {
16+
e[(i, i)] = d[i];
17+
}
18+
e
19+
}
720

821
/// stack vectors into matrix horizontally
922
pub fn hstack<A, S>(xs: &[ArrayBase<S, Ix1>]) -> Result<Array<A, Ix2>, ShapeError>
10-
where A: NdFloat,
23+
where A: LinalgScalar,
1124
S: Data<Elem = A>
1225
{
1326
let views: Vec<_> = xs.iter()
@@ -21,7 +34,7 @@ pub fn hstack<A, S>(xs: &[ArrayBase<S, Ix1>]) -> Result<Array<A, Ix2>, ShapeErro
2134

2235
/// stack vectors into matrix vertically
2336
pub fn vstack<A, S>(xs: &[ArrayBase<S, Ix1>]) -> Result<Array<A, Ix2>, ShapeError>
24-
where A: NdFloat,
37+
where A: LinalgScalar,
2538
S: Data<Elem = A>
2639
{
2740
let views: Vec<_> = xs.iter()
@@ -33,6 +46,26 @@ pub fn vstack<A, S>(xs: &[ArrayBase<S, Ix1>]) -> Result<Array<A, Ix2>, ShapeErro
3346
stack(Axis(0), &views)
3447
}
3548

49+
pub enum NormalizeAxis {
50+
Row = 0,
51+
Column = 1,
52+
}
53+
54+
/// normalize in L2 norm
55+
pub fn normalize<A, S, T>(mut m: ArrayBase<S, Ix2>, axis: NormalizeAxis) -> (ArrayBase<S, Ix2>, Vec<T>)
56+
where A: LinalgScalar + NormedField<Output = T> + Div<T, Output = A>,
57+
S: DataMut<Elem = A>,
58+
T: Float + Sum
59+
{
60+
let mut ms = Vec::new();
61+
for mut v in m.axis_iter_mut(Axis(axis as usize)) {
62+
let n = v.norm();
63+
ms.push(n);
64+
v.map_inplace(|x| *x = *x / n)
65+
}
66+
(m, ms)
67+
}
68+
3669
/// check two arrays are close in maximum norm
3770
pub fn all_close_max<A, Tol, S1, S2, D>(test: &ArrayBase<S1, D>,
3871
truth: &ArrayBase<S2, D>,

tests/normalize.rs

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
include!("header.rs");
2+
3+
#[test]
4+
fn n_columns() {
5+
let a = random_owned(3, 2, true);
6+
let (n, v) = normalize(a.clone(), NormalizeAxis::Column);
7+
all_close_l2(&n.dot(&from_diag(&v)), &a, 1e-7).unwrap();
8+
}
9+
10+
#[test]
11+
fn n_rows() {
12+
let a = random_owned(3, 2, true);
13+
let (n, v) = normalize(a.clone(), NormalizeAxis::Row);
14+
all_close_l2(&from_diag(&v).dot(&n), &a, 1e-7).unwrap();
15+
}

0 commit comments

Comments
 (0)