@@ -46,38 +46,24 @@ pub fn vstack<A, S>(xs: &[ArrayBase<S, Ix1>]) -> Result<Array<A, Ix2>, ShapeErro
46
46
stack ( Axis ( 0 ) , & views)
47
47
}
48
48
49
- /// normalize columns in L2 norm
50
- pub fn normalize_columns < A , S , T > ( m : & ArrayBase < S , Ix2 > ) -> ( Array2 < A > , Vec < T > )
51
- where S : Data < Elem = A > ,
52
- A : LinalgScalar + NormedField < Output = T > + Div < T , Output = A > ,
53
- T : Float + Sum
54
- {
55
- let mut ms = Vec :: new ( ) ;
56
- let vs = m. axis_iter ( Axis ( 1 ) )
57
- . map ( |v| {
58
- let n = v. norm ( ) ;
59
- ms. push ( n) ;
60
- v. mapv ( |x| x / n)
61
- } )
62
- . collect :: < Vec < _ > > ( ) ;
63
- ( hstack ( & vs) . unwrap ( ) , ms)
49
+ pub enum NormalizeAxis {
50
+ Row = 0 ,
51
+ Column = 1 ,
64
52
}
65
53
66
- /// normalize rows in L2 norm
67
- pub fn normalize_rows < A , S , T > ( m : & ArrayBase < S , Ix2 > ) -> ( Vec < T > , Array2 < A > )
68
- where S : Data < Elem = A > ,
69
- A : LinalgScalar + NormedField < Output = T > + Div < T , Output = A > ,
54
+ /// normalize in L2 norm
55
+ pub fn normalize < A , S , T > ( mut m : ArrayBase < S , Ix2 > , axis : NormalizeAxis ) -> ( ArrayBase < S , Ix2 > , Vec < T > )
56
+ where A : LinalgScalar + NormedField < Output = T > + Div < T , Output = A > ,
57
+ S : DataMut < Elem = A > ,
70
58
T : Float + Sum
71
59
{
72
60
let mut ms = Vec :: new ( ) ;
73
- let vs = m. axis_iter ( Axis ( 0 ) )
74
- . map ( |v| {
75
- let n = v. norm ( ) ;
76
- ms. push ( n) ;
77
- v. mapv ( |x| x / n)
78
- } )
79
- . collect :: < Vec < _ > > ( ) ;
80
- ( ms, vstack ( & vs) . unwrap ( ) )
61
+ for mut v in m. axis_iter_mut ( Axis ( axis as usize ) ) {
62
+ let n = v. norm ( ) ;
63
+ ms. push ( n) ;
64
+ v. map_inplace ( |x| * x = * x / n)
65
+ }
66
+ ( m, ms)
81
67
}
82
68
83
69
/// check two arrays are close in maximum norm
0 commit comments