@@ -4,10 +4,11 @@ use std::iter::Sum;
4
4
use ndarray:: * ;
5
5
use num_traits:: Float ;
6
6
use super :: vector:: * ;
7
+ use std:: ops:: Div ;
7
8
8
9
/// stack vectors into matrix horizontally
9
10
pub fn hstack < A , S > ( xs : & [ ArrayBase < S , Ix1 > ] ) -> Result < Array < A , Ix2 > , ShapeError >
10
- where A : NdFloat ,
11
+ where A : LinalgScalar ,
11
12
S : Data < Elem = A >
12
13
{
13
14
let views: Vec < _ > = xs. iter ( )
@@ -21,7 +22,7 @@ pub fn hstack<A, S>(xs: &[ArrayBase<S, Ix1>]) -> Result<Array<A, Ix2>, ShapeErro
21
22
22
23
/// stack vectors into matrix vertically
23
24
pub fn vstack < A , S > ( xs : & [ ArrayBase < S , Ix1 > ] ) -> Result < Array < A , Ix2 > , ShapeError >
24
- where A : NdFloat ,
25
+ where A : LinalgScalar ,
25
26
S : Data < Elem = A >
26
27
{
27
28
let views: Vec < _ > = xs. iter ( )
@@ -33,6 +34,40 @@ pub fn vstack<A, S>(xs: &[ArrayBase<S, Ix1>]) -> Result<Array<A, Ix2>, ShapeErro
33
34
stack ( Axis ( 0 ) , & views)
34
35
}
35
36
37
+ /// normalize columns in L2 norm
38
+ pub fn normalize_columns < A , S , T > ( m : & ArrayBase < S , Ix2 > ) -> ( Array2 < A > , Vec < T > )
39
+ where S : Data < Elem = A > ,
40
+ A : LinalgScalar + NormedField < Output = T > + Div < T , Output = A > ,
41
+ T : Float + Sum
42
+ {
43
+ let mut ms = Vec :: new ( ) ;
44
+ let vs = m. axis_iter ( Axis ( 1 ) )
45
+ . map ( |v| {
46
+ let n = v. norm ( ) ;
47
+ ms. push ( n) ;
48
+ v. mapv ( |x| x / n)
49
+ } )
50
+ . collect :: < Vec < _ > > ( ) ;
51
+ ( hstack ( & vs) . unwrap ( ) , ms)
52
+ }
53
+
54
+ /// normalize rows in L2 norm
55
+ pub fn normalize_rows < A , S , T > ( m : & ArrayBase < S , Ix2 > ) -> ( Vec < T > , Array2 < A > )
56
+ where S : Data < Elem = A > ,
57
+ A : LinalgScalar + NormedField < Output = T > + Div < T , Output = A > ,
58
+ T : Float + Sum
59
+ {
60
+ let mut ms = Vec :: new ( ) ;
61
+ let vs = m. axis_iter ( Axis ( 0 ) )
62
+ . map ( |v| {
63
+ let n = v. norm ( ) ;
64
+ ms. push ( n) ;
65
+ v. mapv ( |x| x / n)
66
+ } )
67
+ . collect :: < Vec < _ > > ( ) ;
68
+ ( ms, vstack ( & vs) . unwrap ( ) )
69
+ }
70
+
36
71
/// check two arrays are close in maximum norm
37
72
pub fn all_close_max < A , Tol , S1 , S2 , D > ( test : & ArrayBase < S1 , D > ,
38
73
truth : & ArrayBase < S2 , D > ,
0 commit comments