33//! Each operation is implemented as a struct that implements the
44//! [Operation](crate::operation::Operation) trait.
55
6- use std:: cmp:: min_by;
6+ use std:: cmp:: { max_by , min_by} ;
77
88use crate :: array;
99use crate :: error:: ActiveStorageError ;
@@ -13,7 +13,6 @@ use crate::types::Missing;
1313
1414use axum:: body:: Bytes ;
1515use ndarray:: { ArrayView , Axis } ;
16- use ndarray_stats:: { errors:: MinMaxError , QuantileExt } ;
1716// Bring trait into scope to use as_bytes method.
1817use zerocopy:: AsBytes ;
1918
@@ -80,6 +79,72 @@ impl NumOperation for Count {
8079/// Return the maximum of selected elements in the array.
8180pub struct Max { }
8281
82+ fn max_element_pairwise < T : Element > ( x : & & T , y : & & T ) -> std:: cmp:: Ordering {
83+ // TODO: How to handle NaN correctly?
84+ // Numpy seems to behave as follows:
85+ //
86+ // np.min([np.nan, 1]) == np.nan
87+ // np.max([np.nan, 1]) == np.nan
88+ // np.nan != np.nan
89+ // np.min([np.nan, 1]) != np.max([np.nan, 1])
90+ //
91+ // There are also separate np.nan{min,max} functions
92+ // which ignore nans instead.
93+ //
94+ // Which behaviour do we want to follow?
95+ //
96+ // Panic for now (TODO: Make this a user-facing error response instead)
97+ x. partial_cmp ( y)
98+ // .unwrap_or(std::cmp::Ordering::Less)
99+ . unwrap_or_else ( || panic ! ( "unexpected undefined order error for min" ) )
100+ }
101+
102+ /// Performs a max over one or more axes of the provided array
103+ fn max_array_multi_axis < T : Element > (
104+ array : ndarray:: ArrayView < T , ndarray:: IxDyn > ,
105+ axes : & [ usize ] ,
106+ missing : Option < Missing < T > > ,
107+ ) -> ( Vec < T > , Vec < i64 > , Vec < usize > ) {
108+ // Find maximum over first axis and count elements operated on
109+ let init = T :: min_value ( ) ;
110+ let mut result = array
111+ . fold_axis ( Axis ( axes[ 0 ] ) , ( init, 0 ) , |( running_max, count) , val| {
112+ if let Some ( missing) = & missing {
113+ if !missing. is_missing ( val) {
114+ let new_max = max_by ( running_max, val, max_element_pairwise) ;
115+ ( * new_max, count + 1 )
116+ } else {
117+ ( * running_max, * count)
118+ }
119+ } else {
120+ let new_max = max_by ( running_max, val, max_element_pairwise) ;
121+ ( * new_max, count + 1 )
122+ }
123+ } )
124+ . into_dyn ( ) ;
125+ // Find max over remaining axes (where total count is now sum of counts)
126+ if let Some ( remaining_axes) = axes. get ( 1 ..) {
127+ for ( n, axis) in remaining_axes. iter ( ) . enumerate ( ) {
128+ result = result
129+ . fold_axis (
130+ Axis ( axis - n - 1 ) ,
131+ ( init, 0 ) ,
132+ |( global_max, total_count) , ( running_max, count) | {
133+ let new_max = max_by ( global_max, running_max, max_element_pairwise) ;
134+ ( * new_max, total_count + count)
135+ } ,
136+ )
137+ . into_dyn ( ) ;
138+ }
139+ }
140+
141+ // Result is array of (max, count) tuples so separate them here
142+ let maxes = result. iter ( ) . map ( |( max, _) | * max) . collect :: < Vec < T > > ( ) ;
143+ let counts = result. iter ( ) . map ( |( _, count) | * count) . collect :: < Vec < i64 > > ( ) ;
144+
145+ ( maxes, counts, result. shape ( ) . into ( ) )
146+ }
147+
83148impl NumOperation for Max {
84149 fn execute_t < T : Element > (
85150 request_data : & models:: RequestData ,
@@ -88,44 +153,74 @@ impl NumOperation for Max {
88153 let array = array:: build_array :: < T > ( request_data, & mut data) ?;
89154 let slice_info = array:: build_slice_info :: < T > ( & request_data. selection , array. shape ( ) ) ;
90155 let sliced = array. slice ( slice_info) ;
91- let ( max, count) = if let Some ( missing) = & request_data. missing {
92- let missing = Missing :: < T > :: try_from ( missing) ?;
93- // Use a fold to simultaneously max and count the non-missing data.
94- // TODO: separate float impl?
95- // TODO: inifinite/NaN
96- let ( max, count) = sliced
97- . iter ( )
98- . copied ( )
99- . filter ( missing_filter ( & missing) )
100- . fold ( ( None , 0 ) , |( a, count) , b| {
101- let max = match ( a, b) {
102- ( None , b) => Some ( b) , //FIXME: if b.is_finite() { Some(b) } else { None },
103- ( Some ( a) , b) => Some ( std:: cmp:: max_by ( a, b, |x, y| {
104- x. partial_cmp ( y) . unwrap_or ( std:: cmp:: Ordering :: Greater )
105- } ) ) ,
106- } ;
107- ( max, count + 1 )
108- } ) ;
109- let max = max. ok_or ( ActiveStorageError :: EmptyArray { operation : "max" } ) ?;
110- ( max, count)
156+
157+ let typed_missing: Option < Missing < T > > = if let Some ( missing) = & request_data. missing {
158+ let m = Missing :: try_from ( missing) ?;
159+ Some ( m)
111160 } else {
112- let max = * sliced. max ( ) . map_err ( |err| match err {
113- MinMaxError :: EmptyInput => ActiveStorageError :: EmptyArray { operation : "max" } ,
114- MinMaxError :: UndefinedOrder => panic ! ( "unexpected undefined order error for max" ) ,
115- } ) ?;
116- let count = sliced. len ( ) ;
117- ( max, count)
161+ None
118162 } ;
119- let count = i64:: try_from ( count) ?;
120- let body = max. as_bytes ( ) ;
121- // Need to copy to provide ownership to caller.
122- let body = Bytes :: copy_from_slice ( body) ;
123- Ok ( models:: Response :: new (
124- body,
125- request_data. dtype ,
126- vec ! [ ] ,
127- vec ! [ count] ,
128- ) )
163+
164+ match & request_data. axis {
165+ ReductionAxes :: One ( axis) => {
166+ let init = T :: min_value ( ) ;
167+ let result =
168+ sliced. fold_axis ( Axis ( * axis) , ( init, 0 ) , |( running_max, count) , val| {
169+ if let Some ( missing) = & typed_missing {
170+ if !missing. is_missing ( val) {
171+ ( * max_by ( running_max, val, max_element_pairwise) , count + 1 )
172+ } else {
173+ ( * running_max, * count)
174+ }
175+ } else {
176+ ( * max_by ( running_max, val, max_element_pairwise) , count + 1 )
177+ }
178+ } ) ;
179+ let maxes = result. iter ( ) . map ( |( max, _) | * max) . collect :: < Vec < T > > ( ) ;
180+ let counts = result. iter ( ) . map ( |( _, count) | * count) . collect :: < Vec < i64 > > ( ) ;
181+ let body = maxes. as_bytes ( ) ;
182+ let body = Bytes :: copy_from_slice ( body) ;
183+ Ok ( models:: Response :: new (
184+ body,
185+ request_data. dtype ,
186+ result. shape ( ) . into ( ) ,
187+ counts,
188+ ) )
189+ }
190+ ReductionAxes :: Multi ( axes) => {
191+ let ( maxes, counts, shape) = max_array_multi_axis ( sliced, axes, typed_missing) ;
192+ let body = Bytes :: copy_from_slice ( maxes. as_bytes ( ) ) ;
193+ Ok ( models:: Response :: new (
194+ body,
195+ request_data. dtype ,
196+ shape,
197+ counts,
198+ ) )
199+ }
200+ ReductionAxes :: All => {
201+ let init = T :: min_value ( ) ;
202+ let ( max, count) = sliced. fold ( ( init, 0_i64 ) , |( running_max, count) , val| {
203+ if let Some ( missing) = & typed_missing {
204+ if !missing. is_missing ( val) {
205+ ( * max_by ( & running_max, val, max_element_pairwise) , count + 1 )
206+ } else {
207+ ( running_max, count)
208+ }
209+ } else {
210+ ( * max_by ( & running_max, val, max_element_pairwise) , count + 1 )
211+ }
212+ } ) ;
213+
214+ let body = max. as_bytes ( ) ;
215+ let body = Bytes :: copy_from_slice ( body) ;
216+ Ok ( models:: Response :: new (
217+ body,
218+ request_data. dtype ,
219+ vec ! [ ] ,
220+ vec ! [ count] ,
221+ ) )
222+ }
223+ }
129224 }
130225}
131226
@@ -146,7 +241,7 @@ fn min_element_pairwise<T: Element>(x: &&T, y: &&T) -> std::cmp::Ordering {
146241 //
147242 // Which behaviour do we want to follow?
148243 //
149- // Panic is probably the best option for now...
244+ // Panic for now (TODO: Make this a user-facing error response instead)
150245 x. partial_cmp ( y)
151246 // .unwrap_or(std::cmp::Ordering::Less)
152247 . unwrap_or_else ( || panic ! ( "unexpected undefined order error for min" ) )
@@ -164,13 +259,13 @@ fn min_array_multi_axis<T: Element>(
164259 . fold_axis ( Axis ( axes[ 0 ] ) , ( init, 0 ) , |( running_min, count) , val| {
165260 if let Some ( missing) = & missing {
166261 if !missing. is_missing ( val) {
167- let new_min = std :: cmp :: min_by ( running_min, val, min_element_pairwise) ;
262+ let new_min = min_by ( running_min, val, min_element_pairwise) ;
168263 ( * new_min, count + 1 )
169264 } else {
170265 ( * running_min, * count)
171266 }
172267 } else {
173- let new_min = std :: cmp :: min_by ( running_min, val, min_element_pairwise) ;
268+ let new_min = min_by ( running_min, val, min_element_pairwise) ;
174269 ( * new_min, count + 1 )
175270 }
176271 } )
@@ -184,8 +279,7 @@ fn min_array_multi_axis<T: Element>(
184279 ( init, 0 ) ,
185280 |( global_min, total_count) , ( running_min, count) | {
186281 // (*global_min.min(running_min), total_count + count)
187- let new_min =
188- std:: cmp:: min_by ( global_min, running_min, min_element_pairwise) ;
282+ let new_min = min_by ( global_min, running_min, min_element_pairwise) ;
189283 ( * new_min, total_count + count)
190284 } ,
191285 )
@@ -1004,4 +1098,64 @@ mod tests {
10041098 assert_eq ! ( counts, vec![ 6 , 5 ] ) ;
10051099 assert_eq ! ( shape, vec![ 2 ] ) ;
10061100 }
1101+
#[test]
#[should_panic(expected = "assertion failed: axis.index() < self.ndim()")]
fn max_multi_axis_2d_wrong_axis() {
    // A 2x2 array has no axis 2, so the reduction is expected to panic.
    let data = ndarray::Array::from_shape_vec((2, 2), (0..4).collect())
        .unwrap()
        .into_dyn();
    // The call never returns normally, so its result is discarded.
    let _ = max_array_multi_axis(data.view(), &[2], None);
}
1113+
#[test]
fn max_multi_axis_2d_2ax() {
    // Values 0..6 in a 2x3 array, reduced over both axes with no missing-data spec.
    let data = ndarray::Array::from_shape_vec((2, 3), (0..6).collect())
        .unwrap()
        .into_dyn();

    let (maxes, counts, shape) = max_array_multi_axis(data.view(), &[0, 1], None);

    // Reducing over every axis leaves a single value and an empty shape.
    assert_eq!(maxes, vec![5]);
    assert_eq!(counts, vec![6]);
    assert_eq!(shape, Vec::<usize>::new());
}
1129+
#[test]
fn max_multi_axis_2d_1ax_missing() {
    // Values 0..6 in a 2x3 array, reduced over axis 1 with the value 0 marked missing.
    let data = ndarray::Array::from_shape_vec((2, 3), (0..6).collect())
        .unwrap()
        .into_dyn();
    let skip_zero = Some(Missing::MissingValue(0));

    let (maxes, counts, shape) = max_array_multi_axis(data.view(), &[1], skip_zero);

    // The first row loses its missing element, so it contributes two values, not three.
    assert_eq!(maxes, vec![2, 5]);
    assert_eq!(counts, vec![2, 3]);
    assert_eq!(shape, vec![2]);
}
1145+
#[test]
fn max_multi_axis_4d_3ax_missing() {
    // Values 0..12 in a 2x3x2x1 array, reduced over axes 0, 1 and 3 with the value 10
    // marked missing; only axis 2 (length 2) survives.
    let data = ndarray::Array::from_shape_vec((2, 3, 2, 1), (0..12).collect())
        .unwrap()
        .into_dyn();
    let skip_ten = Some(Missing::MissingValue(10));

    let (maxes, counts, shape) = max_array_multi_axis(data.view(), &[0, 1, 3], skip_ten);

    // The lane that contained the missing value counts five elements instead of six.
    assert_eq!(maxes, vec![8, 11]);
    assert_eq!(counts, vec![5, 6]);
    assert_eq!(shape, vec![2]);
}
10071161}
0 commit comments