@@ -3,7 +3,7 @@ use std::fmt::Debug;
33use std:: hash:: Hash ;
44
55use itertools:: Itertools as _;
6- use num_traits:: ToPrimitive ;
6+ use num_traits:: { NumCast , ToPrimitive } ;
77use serde:: { Deserialize , Serialize } ;
88use vortex_buffer:: BufferMut ;
99use vortex_dtype:: Nullability :: NonNullable ;
@@ -15,7 +15,7 @@ use vortex_scalar::Scalar;
1515use crate :: aliases:: hash_map:: HashMap ;
1616use crate :: arrays:: PrimitiveArray ;
1717use crate :: compute:: {
18- SearchResult , SearchSortedSide , cast, filter, search_sorted, search_sorted_many , take,
18+ SearchResult , SearchSorted , SearchSortedSide , cast, filter, search_sorted, take,
1919} ;
2020use crate :: variants:: PrimitiveArrayTrait ;
2121use crate :: { Array , ArrayRef , IntoArray , ToCanonical } ;
@@ -321,10 +321,13 @@ impl Patches {
321321 }
322322
323323 pub fn take_search ( & self , take_indices : PrimitiveArray ) -> VortexResult < Option < Self > > {
324+ let indices = self . indices . to_primitive ( ) ?;
324325 let new_length = take_indices. len ( ) ;
325326
326- let Some ( ( new_indices, values_indices) ) = match_each_integer_ptype ! ( take_indices. ptype( ) , |$I | {
327- take_search:: <$I >( self . indices( ) , take_indices, self . offset( ) ) ?
327+ let Some ( ( new_indices, values_indices) ) = match_each_integer_ptype ! ( indices. ptype( ) , |$INDICES | {
328+ match_each_integer_ptype!( take_indices. ptype( ) , |$TAKE_INDICES | {
329+ take_search:: <_, $TAKE_INDICES >( indices. as_slice:: <$INDICES >( ) , take_indices, self . offset( ) ) ?
330+ } )
328331 } ) else {
329332 return Ok ( None ) ;
330333 } ;
@@ -373,8 +376,8 @@ impl Patches {
373376 }
374377}
375378
376- fn take_search < T : NativePType + TryFrom < usize > > (
377- indices : & dyn Array ,
379+ fn take_search < I : NativePType + NumCast + PartialOrd , T : NativePType + NumCast > (
380+ indices : & [ I ] ,
378381 take_indices : PrimitiveArray ,
379382 indices_offset : usize ,
380383) -> VortexResult < Option < ( ArrayRef , ArrayRef ) > >
@@ -383,24 +386,27 @@ where
383386 VortexError : From < <usize as TryFrom < T > >:: Error > ,
384387{
385388 let take_indices_validity = take_indices. validity ( ) ;
386- let take_indices = take_indices
389+ let indices_offset = I :: from ( indices_offset) . vortex_expect ( "indices_offset out of range" ) ;
390+
391+ let ( values_indices, new_indices) : ( BufferMut < u64 > , BufferMut < u64 > ) = take_indices
387392 . as_slice :: < T > ( )
388393 . iter ( )
389- . copied ( )
390- . map ( usize:: try_from)
391- . map_ok ( |idx| idx + indices_offset)
392- . collect :: < Result < Vec < _ > , _ > > ( ) ?;
393-
394- let ( values_indices, new_indices) : ( BufferMut < u64 > , BufferMut < u64 > ) =
395- search_sorted_many ( indices, & take_indices, SearchSortedSide :: Left ) ?
396- . iter ( )
397- . enumerate ( )
398- . filter_map ( |( idx_in_take, search_result) | {
399- search_result
400- . to_found ( )
401- . map ( |patch_idx| ( patch_idx as u64 , idx_in_take as u64 ) )
402- } )
403- . unzip ( ) ;
394+ . map ( |v| {
395+ match I :: from ( * v) {
396+ None => {
397+ // If the cast failed, then the value is greater than all indices.
398+ SearchResult :: NotFound ( indices. len ( ) )
399+ }
400+ Some ( v) => indices. search_sorted ( & ( v + indices_offset) , SearchSortedSide :: Left ) ,
401+ }
402+ } )
403+ . enumerate ( )
404+ . filter_map ( |( idx_in_take, search_result) | {
405+ search_result
406+ . to_found ( )
407+ . map ( |patch_idx| ( patch_idx as u64 , idx_in_take as u64 ) )
408+ } )
409+ . unzip ( ) ;
404410
405411 if new_indices. is_empty ( ) {
406412 return Ok ( None ) ;
0 commit comments