@@ -7,6 +7,7 @@ mod fill_null;
77mod filter;
88mod mask;
99mod min_max;
10+ mod scalar_at;
1011mod search_sorted;
1112mod slice;
1213mod sort;
@@ -24,6 +25,7 @@ use libfuzzer_sys::arbitrary::Error::EmptyChoose;
2425use libfuzzer_sys:: arbitrary:: { Arbitrary , Unstructured } ;
2526pub ( crate ) use mask:: * ;
2627pub ( crate ) use min_max:: * ;
28+ pub ( crate ) use scalar_at:: * ;
2729pub ( crate ) use search_sorted:: * ;
2830pub ( crate ) use slice:: * ;
2931pub use sort:: sort_canonical_array;
@@ -80,6 +82,8 @@ pub enum Action {
8082 MinMax ,
8183 FillNull ( Scalar ) ,
8284 Mask ( Mask ) ,
85+ // Here we want to try multiple values.
86+ ScalarAt ( Vec < usize > ) ,
8387}
8488
8589#[ derive( Debug ) ]
@@ -88,6 +92,7 @@ pub enum ExpectedValue {
8892 Search ( SearchResult ) ,
8993 Scalar ( Scalar ) ,
9094 MinMax ( Option < MinMaxResult > ) ,
95+ ScalarVec ( Vec < Scalar > ) ,
9196}
9297
9398impl ExpectedValue {
@@ -118,6 +123,13 @@ impl ExpectedValue {
118123 _ => vortex_panic ! ( "expected min_max" ) ,
119124 }
120125 }
126+
127+ pub fn scalar_vec ( self ) -> Vec < Scalar > {
128+ match self {
129+ ExpectedValue :: ScalarVec ( v) => v,
130+ _ => vortex_panic ! ( "expected scalar_vec" ) ,
131+ }
132+ }
121133}
122134
123135impl < ' a > Arbitrary < ' a > for FuzzArrayAction {
@@ -131,7 +143,7 @@ impl<'a> Arbitrary<'a> for FuzzArrayAction {
131143 valid_actions. sort_unstable_by_key ( |a| * a as usize ) ;
132144
133145 let mut actions = Vec :: new ( ) ;
134- let action_count = u. int_in_range ( 1 ..=4 ) ?;
146+ let action_count = u. int_in_range ( 1 ..=4 . min ( valid_actions . len ( ) ) ) ?;
135147 for _ in 0 ..action_count {
136148 let action_type = random_action_from_list ( u, valid_actions. as_slice ( ) ) ?;
137149
@@ -313,6 +325,35 @@ impl<'a> Arbitrary<'a> for FuzzArrayAction {
313325 ExpectedValue :: Array ( expected_result) ,
314326 )
315327 }
328+ ActionType :: ScalarAt => {
329+ if current_array. is_empty ( ) {
330+ return Err ( EmptyChoose ) ;
331+ }
332+
333+ let num_indices = u. int_in_range ( 1 ..=5 . min ( current_array. len ( ) ) ) ?;
334+ let mut indices = HashSet :: with_capacity ( num_indices) ;
335+
336+ while indices. len ( ) < num_indices {
337+ let idx = u. choose_index ( current_array. len ( ) ) ?;
338+ indices. insert ( idx) ;
339+ }
340+
341+ let indices_vec: Vec < usize > = indices. into_iter ( ) . collect ( ) ;
342+
343+ // Compute expected scalars using the baseline implementation
344+ let expected_scalars: Vec < Scalar > = indices_vec
345+ . iter ( )
346+ . map ( |& idx| {
347+ scalar_at_canonical_array ( current_array. to_canonical ( ) , idx)
348+ . vortex_unwrap ( )
349+ } )
350+ . collect ( ) ;
351+
352+ (
353+ Action :: ScalarAt ( indices_vec) ,
354+ ExpectedValue :: ScalarVec ( expected_scalars) ,
355+ )
356+ }
316357 } )
317358 }
318359
@@ -325,23 +366,23 @@ fn actions_for_dtype(dtype: &DType) -> HashSet<ActionType> {
325366
326367 match dtype {
327368 DType :: Struct ( sdt, _) => {
328- // Struct supports: Compress, Slice, Take, Filter, MinMax, Mask
369+ // Struct supports: Compress, Slice, Take, Filter, MinMax, Mask, ScalarAt
329370 // Does NOT support: SearchSorted (requires scalar comparison), Compare, Cast, Sum, FillNull
330- let struct_actions = [ Compress , Slice , Take , Filter , MinMax , Mask ] ;
371+ let struct_actions = [ Compress , Slice , Take , Filter , MinMax , Mask , ScalarAt ] ;
331372 sdt. fields ( )
332373 . map ( |child| actions_for_dtype ( & child) )
333374 . fold ( struct_actions. into ( ) , |acc, actions| {
334375 acc. intersection ( & actions) . copied ( ) . collect ( )
335376 } )
336377 }
337378 DType :: List ( ..) | DType :: FixedSizeList ( ..) => {
338- // List supports: Compress, Slice, Take, Filter, MinMax, Mask
379+ // List supports: Compress, Slice, Take, Filter, MinMax, Mask, ScalarAt
339380 // Does NOT support: SearchSorted, Compare, Cast, Sum, FillNull
340- [ Compress , Slice , Take , Filter , MinMax , Mask ] . into ( )
381+ [ Compress , Slice , Take , Filter , MinMax , Mask , ScalarAt ] . into ( )
341382 }
342383 DType :: Utf8 ( _) | DType :: Binary ( _) => {
343384 // Utf8/Binary supports everything except Sum
344- // Actions: Compress, Slice, Take, SearchSorted, Filter, Compare, Cast, MinMax, FillNull, Mask
385+ // Actions: Compress, Slice, Take, SearchSorted, Filter, Compare, Cast, MinMax, FillNull, Mask, ScalarAt
345386 [
346387 Compress ,
347388 Slice ,
@@ -353,6 +394,7 @@ fn actions_for_dtype(dtype: &DType) -> HashSet<ActionType> {
353394 MinMax ,
354395 FillNull ,
355396 Mask ,
397+ ScalarAt ,
356398 ]
357399 . into ( )
358400 }
@@ -372,6 +414,7 @@ fn actions_for_dtype(dtype: &DType) -> HashSet<ActionType> {
372414 Cast ,
373415 FillNull ,
374416 Mask ,
417+ ScalarAt ,
375418 ]
376419 . into ( )
377420 }
0 commit comments