33//! Each operation is implemented as a struct that implements the
44//! [Operation](crate::operation::Operation) trait.
55
6+ use std:: cmp:: min_by;
7+
68use crate :: array;
79use crate :: error:: ActiveStorageError ;
810use crate :: models:: { self , ReductionAxes } ;
@@ -130,6 +132,74 @@ impl NumOperation for Max {
130132/// Return the minimum of selected elements in the array.
131133pub struct Min { }
132134
135+ fn min_element_pairwise < T : Element > ( x : & & T , y : & & T ) -> std:: cmp:: Ordering {
136+ // TODO: How to handle NaN correctly?
137+ // Numpy seems to behave as follows:
138+ //
139+ // np.min([np.nan, 1]) == np.nan
140+ // np.max([np.nan, 1]) == np.nan
141+ // np.nan != np.nan
142+ // np.min([np.nan, 1]) != np.max([np.nan, 1])
143+ //
144+ // There are also separate np.nan{min,max} functions
145+ // which ignore nans instead.
146+ //
147+ // Which behaviour do we want to follow?
148+ //
149+ // Panic is probably the best option for now...
150+ x. partial_cmp ( y)
151+ // .unwrap_or(std::cmp::Ordering::Less)
152+ . unwrap_or_else ( || panic ! ( "unexpected undefined order error for min" ) )
153+ }
154+
155+ /// Finds the minimum value over one or more axes of the provided array
156+ fn min_array_multi_axis < T : Element > (
157+ array : ndarray:: ArrayView < T , ndarray:: IxDyn > ,
158+ axes : & [ usize ] ,
159+ missing : Option < Missing < T > > ,
160+ ) -> ( Vec < T > , Vec < i64 > , Vec < usize > ) {
161+ // Find minimum over first axis and count elements operated on
162+ let init = T :: max_value ( ) ;
163+ let mut result = array
164+ . fold_axis ( Axis ( axes[ 0 ] ) , ( init, 0 ) , |( running_min, count) , val| {
165+ if let Some ( missing) = & missing {
166+ if !missing. is_missing ( val) {
167+ let new_min = std:: cmp:: min_by ( running_min, val, min_element_pairwise) ;
168+ ( * new_min, count + 1 )
169+ } else {
170+ ( * running_min, * count)
171+ }
172+ } else {
173+ let new_min = std:: cmp:: min_by ( running_min, val, min_element_pairwise) ;
174+ ( * new_min, count + 1 )
175+ }
176+ } )
177+ . into_dyn ( ) ;
178+ // Find min over remaining axes (where total count is now sum of counts)
179+ if let Some ( remaining_axes) = axes. get ( 1 ..) {
180+ for ( n, axis) in remaining_axes. iter ( ) . enumerate ( ) {
181+ result = result
182+ . fold_axis (
183+ Axis ( axis - n - 1 ) ,
184+ ( init, 0 ) ,
185+ |( global_min, total_count) , ( running_min, count) | {
186+ // (*global_min.min(running_min), total_count + count)
187+ let new_min =
188+ std:: cmp:: min_by ( global_min, running_min, min_element_pairwise) ;
189+ ( * new_min, total_count + count)
190+ } ,
191+ )
192+ . into_dyn ( ) ;
193+ }
194+ }
195+
196+ // Result is array of (mins, count) tuples so separate them here
197+ let mins = result. iter ( ) . map ( |( min, _) | * min) . collect :: < Vec < T > > ( ) ;
198+ let counts = result. iter ( ) . map ( |( _, count) | * count) . collect :: < Vec < i64 > > ( ) ;
199+
200+ ( mins, counts, result. shape ( ) . into ( ) )
201+ }
202+
133203impl NumOperation for Min {
134204 fn execute_t < T : Element > (
135205 request_data : & models:: RequestData ,
@@ -138,44 +208,80 @@ impl NumOperation for Min {
138208 let array = array:: build_array :: < T > ( request_data, & mut data) ?;
139209 let slice_info = array:: build_slice_info :: < T > ( & request_data. selection , array. shape ( ) ) ;
140210 let sliced = array. slice ( slice_info) ;
141- let ( min, count) = if let Some ( missing) = & request_data. missing {
142- let missing = Missing :: < T > :: try_from ( missing) ?;
143- // Use a fold to simultaneously min and count the non-missing data.
144- // TODO: separate float impl?
145- // TODO: inifinite/NaN
146- let ( min, count) = sliced
147- . iter ( )
148- . copied ( )
149- . filter ( missing_filter ( & missing) )
150- . fold ( ( None , 0 ) , |( a, count) , b| {
151- let min = match ( a, b) {
152- ( None , b) => Some ( b) , //FIXME: if b.is_finite() { Some(b) } else { None },
153- ( Some ( a) , b) => Some ( std:: cmp:: min_by ( a, b, |x, y| {
154- x. partial_cmp ( y) . unwrap_or ( std:: cmp:: Ordering :: Less )
155- } ) ) ,
156- } ;
157- ( min, count + 1 )
158- } ) ;
159- let min = min. ok_or ( ActiveStorageError :: EmptyArray { operation : "min" } ) ?;
160- ( min, count)
211+
212+ // Convert Missing<Dtype> to Missing<T: Element>
213+ let typed_missing: Option < Missing < T > > = if let Some ( missing) = & request_data. missing {
214+ let m = Missing :: try_from ( missing) ?;
215+ Some ( m)
161216 } else {
162- let min = * sliced. min ( ) . map_err ( |err| match err {
163- MinMaxError :: EmptyInput => ActiveStorageError :: EmptyArray { operation : "min" } ,
164- MinMaxError :: UndefinedOrder => panic ! ( "unexpected undefined order error for min" ) ,
165- } ) ?;
166- let count = sliced. len ( ) ;
167- ( min, count)
217+ None
168218 } ;
169- let count = i64:: try_from ( count) ?;
170- let body = min. as_bytes ( ) ;
171- // Need to copy to provide ownership to caller.
172- let body = Bytes :: copy_from_slice ( body) ;
173- Ok ( models:: Response :: new (
174- body,
175- request_data. dtype ,
176- vec ! [ ] ,
177- vec ! [ count] ,
178- ) )
219+
220+ // Use ndarray::fold, ndarray::fold_axis or dispatch to specialised
221+ // multi-axis function depending on whether we're performing reduction
222+ // over all axes or only a subset
223+ match & request_data. axes {
224+ ReductionAxes :: One ( axis) => {
225+ let init = T :: max_value ( ) ;
226+ let result =
227+ sliced. fold_axis ( Axis ( * axis) , ( init, 0 ) , |( running_min, count) , val| {
228+ if let Some ( missing) = & typed_missing {
229+ if !missing. is_missing ( val) {
230+ ( * min_by ( running_min, val, min_element_pairwise) , count + 1 )
231+ } else {
232+ ( * running_min, * count)
233+ }
234+ } else {
235+ ( * min_by ( running_min, val, min_element_pairwise) , count + 1 )
236+ }
237+ } ) ;
238+ // Unpack the result tuples into separate vectors
239+ let mins = result. iter ( ) . map ( |( min, _) | * min) . collect :: < Vec < T > > ( ) ;
240+ let counts = result. iter ( ) . map ( |( _, count) | * count) . collect :: < Vec < i64 > > ( ) ;
241+ let body = mins. as_bytes ( ) ;
242+ let body = Bytes :: copy_from_slice ( body) ;
243+ Ok ( models:: Response :: new (
244+ body,
245+ request_data. dtype ,
246+ result. shape ( ) . into ( ) ,
247+ counts,
248+ ) )
249+ }
250+ ReductionAxes :: Multi ( axes) => {
251+ let ( mins, counts, shape) = min_array_multi_axis ( sliced, axes, typed_missing) ;
252+ let body = Bytes :: copy_from_slice ( mins. as_bytes ( ) ) ;
253+ Ok ( models:: Response :: new (
254+ body,
255+ request_data. dtype ,
256+ shape,
257+ counts,
258+ ) )
259+ }
260+ ReductionAxes :: All => {
261+ let init = T :: max_value ( ) ;
262+ let ( min, count) = sliced. fold ( ( init, 0_i64 ) , |( running_min, count) , val| {
263+ if let Some ( missing) = & typed_missing {
264+ if !missing. is_missing ( val) {
265+ ( * min_by ( & running_min, val, min_element_pairwise) , count + 1 )
266+ } else {
267+ ( running_min, count)
268+ }
269+ } else {
270+ ( * min_by ( & running_min, val, min_element_pairwise) , count + 1 )
271+ }
272+ } ) ;
273+
274+ let body = min. as_bytes ( ) ;
275+ // Need to copy to provide ownership to caller.
276+ let body = Bytes :: copy_from_slice ( body) ;
277+ Ok ( models:: Response :: new (
278+ body,
279+ request_data. dtype ,
280+ vec ! [ ] ,
281+ vec ! [ count] ,
282+ ) )
283+ }
284+ }
179285 }
180286}
181287
@@ -220,6 +326,7 @@ impl NumOperation for Select {
220326/// Return the sum of selected elements in the array.
221327pub struct Sum { }
222328
329+ /// Performs a sum over one or more axes of the provided array
223330fn sum_array_multi_axis < T : Element > (
224331 array : ndarray:: ArrayView < T , ndarray:: IxDyn > ,
225332 axes : & [ usize ] ,
@@ -551,6 +658,7 @@ mod tests {
551658 }
552659
553660 #[ test]
661+ #[ should_panic( expected = "unexpected undefined order error for min" ) ]
554662 fn min_f32_1d_nan_missing_value ( ) {
555663 let mut request_data = test_utils:: get_test_request_data ( ) ;
556664 request_data. dtype = models:: DType :: Float32 ;
@@ -567,6 +675,7 @@ mod tests {
567675 }
568676
569677 #[ test]
678+ #[ should_panic( expected = "unexpected undefined order error for min" ) ]
570679 fn min_f32_1d_nan_first_missing_value ( ) {
571680 let mut request_data = test_utils:: get_test_request_data ( ) ;
572681 request_data. dtype = models:: DType :: Float32 ;
@@ -783,7 +892,7 @@ mod tests {
783892
784893 #[ test]
785894 #[ should_panic( expected = "assertion failed: axis.index() < self.ndim()" ) ]
786- fn test_sum_multi_axis_2d_wrong_axis ( ) {
895+ fn sum_multi_axis_2d_wrong_axis ( ) {
787896 let array = ndarray:: Array :: from_shape_vec ( ( 2 , 2 ) , ( 0 ..4 ) . collect ( ) )
788897 . unwrap ( )
789898 . into_dyn ( ) ;
@@ -792,7 +901,7 @@ mod tests {
792901 }
793902
794903 #[ test]
795- fn test_sum_multi_axis_2d_2ax ( ) {
904+ fn sum_multi_axis_2d_2ax ( ) {
796905 let array = ndarray:: Array :: from_shape_vec ( ( 2 , 2 ) , ( 0 ..4 ) . collect ( ) )
797906 . unwrap ( )
798907 . into_dyn ( ) ;
@@ -804,7 +913,7 @@ mod tests {
804913 }
805914
806915 #[ test]
807- fn test_sum_multi_axis_2d_2ax_missing ( ) {
916+ fn sum_multi_axis_2d_2ax_missing ( ) {
808917 let array = ndarray:: Array :: from_shape_vec ( ( 2 , 2 ) , ( 0 ..4 ) . collect ( ) )
809918 . unwrap ( )
810919 . into_dyn ( ) ;
@@ -817,7 +926,7 @@ mod tests {
817926 }
818927
819928 #[ test]
820- fn test_sum_multi_axis_4d_1ax ( ) {
929+ fn sum_multi_axis_4d_1ax ( ) {
821930 let array = ndarray:: Array :: from_shape_vec ( ( 2 , 3 , 2 , 1 ) , ( 0 ..12 ) . collect ( ) )
822931 . unwrap ( )
823932 . into_dyn ( ) ;
@@ -829,7 +938,7 @@ mod tests {
829938 }
830939
831940 #[ test]
832- fn test_sum_multi_axis_4d_3ax ( ) {
941+ fn sum_multi_axis_4d_3ax ( ) {
833942 let array = ndarray:: Array :: from_shape_vec ( ( 2 , 3 , 2 , 1 ) , ( 0 ..12 ) . collect ( ) )
834943 . unwrap ( )
835944 . into_dyn ( ) ;
@@ -839,4 +948,60 @@ mod tests {
839948 assert_eq ! ( count, vec![ 6 , 6 ] ) ;
840949 assert_eq ! ( shape, vec![ 2 ] ) ;
841950 }
951+
/// Asking for axis 2 of a two-dimensional array must trigger the
/// `axis.index() < self.ndim()` assertion.
#[test]
#[should_panic(expected = "assertion failed: axis.index() < self.ndim()")]
fn min_multi_axis_2d_wrong_axis() {
    // Arrange: a 2x2 array only has axes 0 and 1.
    let values: Vec<_> = (0..4).collect();
    let grid = ndarray::Array::from_shape_vec((2, 2), values)
        .unwrap()
        .into_dyn();
    let out_of_range = vec![2];
    // Act: the call is expected to panic, so the result is discarded.
    let _ = min_array_multi_axis(grid.view(), &out_of_range, None);
}
961+
/// Reducing a 2x3 array over both of its axes collapses it to one minimum.
#[test]
fn min_multi_axis_2d_2ax() {
    // Arrange
    let values: Vec<_> = (0..6).collect();
    let grid = ndarray::Array::from_shape_vec((2, 3), values)
        .unwrap()
        .into_dyn();
    let both_axes = vec![0, 1];
    // Act
    let (mins, counts, shape) = min_array_multi_axis(grid.view(), &both_axes, None);
    // Assert: all six elements contribute and no dimensions remain.
    assert_eq!(mins, vec![0]);
    assert_eq!(counts, vec![6]);
    assert_eq!(shape, Vec::<usize>::new());
}
977+
/// Reducing a 2x3 array over axis 1 while treating `0` as missing data.
#[test]
fn min_multi_axis_2d_1ax_missing() {
    // Arrange
    let values: Vec<_> = (0..6).collect();
    let grid = ndarray::Array::from_shape_vec((2, 3), values)
        .unwrap()
        .into_dyn();
    let reduce_over = vec![1];
    let skip_zero = Some(Missing::MissingValue(0));
    // Act
    let (mins, counts, shape) = min_array_multi_axis(grid.view(), &reduce_over, skip_zero);
    // Assert: the first row loses its `0`, so only two elements are counted.
    assert_eq!(mins, vec![1, 3]);
    assert_eq!(counts, vec![2, 3]);
    assert_eq!(shape, vec![2]);
}
993+
/// Reducing a 2x3x2x1 array over axes 0, 1 and 3 while treating `1` as
/// missing data leaves only the original axis 2.
#[test]
fn min_multi_axis_4d_3ax_missing() {
    // Arrange
    let values: Vec<_> = (0..12).collect();
    let grid = ndarray::Array::from_shape_vec((2, 3, 2, 1), values)
        .unwrap()
        .into_dyn();
    let reduce_over = vec![0, 1, 3];
    let skip_one = Some(Missing::MissingValue(1));
    // Act
    let (mins, counts, shape) = min_array_multi_axis(grid.view(), &reduce_over, skip_one);
    // Assert: the lane that contained `1` counts one element fewer.
    assert_eq!(mins, vec![0, 3]);
    assert_eq!(counts, vec![6, 5]);
    assert_eq!(shape, vec![2]);
}
8421007}
0 commit comments