Skip to content

Commit 9cfcd6d

Browse files
authored
Patches to use native search_sorted (#3245)
Instead of dispatching through compute function
1 parent 417e06d commit 9cfcd6d

File tree

1 file changed

+28
-22
lines changed

1 file changed

+28
-22
lines changed

vortex-array/src/patches.rs

Lines changed: 28 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ use std::fmt::Debug;
33
use std::hash::Hash;
44

55
use itertools::Itertools as _;
6-
use num_traits::ToPrimitive;
6+
use num_traits::{NumCast, ToPrimitive};
77
use serde::{Deserialize, Serialize};
88
use vortex_buffer::BufferMut;
99
use vortex_dtype::Nullability::NonNullable;
@@ -15,7 +15,7 @@ use vortex_scalar::Scalar;
1515
use crate::aliases::hash_map::HashMap;
1616
use crate::arrays::PrimitiveArray;
1717
use crate::compute::{
18-
SearchResult, SearchSortedSide, cast, filter, search_sorted, search_sorted_many, take,
18+
SearchResult, SearchSorted, SearchSortedSide, cast, filter, search_sorted, take,
1919
};
2020
use crate::variants::PrimitiveArrayTrait;
2121
use crate::{Array, ArrayRef, IntoArray, ToCanonical};
@@ -321,10 +321,13 @@ impl Patches {
321321
}
322322

323323
pub fn take_search(&self, take_indices: PrimitiveArray) -> VortexResult<Option<Self>> {
324+
let indices = self.indices.to_primitive()?;
324325
let new_length = take_indices.len();
325326

326-
let Some((new_indices, values_indices)) = match_each_integer_ptype!(take_indices.ptype(), |$I| {
327-
take_search::<$I>(self.indices(), take_indices, self.offset())?
327+
let Some((new_indices, values_indices)) = match_each_integer_ptype!(indices.ptype(), |$INDICES| {
328+
match_each_integer_ptype!(take_indices.ptype(), |$TAKE_INDICES| {
329+
take_search::<_, $TAKE_INDICES>(indices.as_slice::<$INDICES>(), take_indices, self.offset())?
330+
})
328331
}) else {
329332
return Ok(None);
330333
};
@@ -373,8 +376,8 @@ impl Patches {
373376
}
374377
}
375378

376-
fn take_search<T: NativePType + TryFrom<usize>>(
377-
indices: &dyn Array,
379+
fn take_search<I: NativePType + NumCast + PartialOrd, T: NativePType + NumCast>(
380+
indices: &[I],
378381
take_indices: PrimitiveArray,
379382
indices_offset: usize,
380383
) -> VortexResult<Option<(ArrayRef, ArrayRef)>>
@@ -383,24 +386,27 @@ where
383386
VortexError: From<<usize as TryFrom<T>>::Error>,
384387
{
385388
let take_indices_validity = take_indices.validity();
386-
let take_indices = take_indices
389+
let indices_offset = I::from(indices_offset).vortex_expect("indices_offset out of range");
390+
391+
let (values_indices, new_indices): (BufferMut<u64>, BufferMut<u64>) = take_indices
387392
.as_slice::<T>()
388393
.iter()
389-
.copied()
390-
.map(usize::try_from)
391-
.map_ok(|idx| idx + indices_offset)
392-
.collect::<Result<Vec<_>, _>>()?;
393-
394-
let (values_indices, new_indices): (BufferMut<u64>, BufferMut<u64>) =
395-
search_sorted_many(indices, &take_indices, SearchSortedSide::Left)?
396-
.iter()
397-
.enumerate()
398-
.filter_map(|(idx_in_take, search_result)| {
399-
search_result
400-
.to_found()
401-
.map(|patch_idx| (patch_idx as u64, idx_in_take as u64))
402-
})
403-
.unzip();
394+
.map(|v| {
395+
match I::from(*v) {
396+
None => {
397+
// If the cast failed, then the value is greater than all indices.
398+
SearchResult::NotFound(indices.len())
399+
}
400+
Some(v) => indices.search_sorted(&(v + indices_offset), SearchSortedSide::Left),
401+
}
402+
})
403+
.enumerate()
404+
.filter_map(|(idx_in_take, search_result)| {
405+
search_result
406+
.to_found()
407+
.map(|patch_idx| (patch_idx as u64, idx_in_take as u64))
408+
})
409+
.unzip();
404410

405411
if new_indices.is_empty() {
406412
return Ok(None);

0 commit comments

Comments
 (0)