Skip to content

Commit ca6e648

Browse files
authored
SearchSorted Many Side (#1427)
Fixes #1413
1 parent df11488 commit ca6e648

File tree

3 files changed

+20
-39
lines changed

3 files changed

+20
-39
lines changed

encodings/fastlanes/src/bitpacking/compute/search_sorted.rs

Lines changed: 5 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -53,18 +53,16 @@ impl SearchSortedFn<BitPackedArray> for BitPackedEncoding {
5353
&self,
5454
array: &BitPackedArray,
5555
values: &[Scalar],
56-
sides: &[SearchSortedSide],
56+
side: SearchSortedSide,
5757
) -> VortexResult<Vec<SearchResult>> {
5858
match_each_unsigned_integer_ptype!(array.ptype(), |$P| {
5959
let searcher = BitPackedSearch::<'_, $P>::new(array);
6060

6161
values
6262
.iter()
63-
.zip(sides.iter().copied())
64-
.map(|(value, side)| {
63+
.map(|value| {
6564
// Unwrap to native value
6665
let unwrapped_value: $P = value.cast(array.dtype())?.try_into()?;
67-
6866
Ok(searcher.search_sorted(&unwrapped_value, side))
6967
})
7068
.try_collect()
@@ -75,16 +73,14 @@ impl SearchSortedFn<BitPackedArray> for BitPackedEncoding {
7573
&self,
7674
array: &BitPackedArray,
7775
values: &[usize],
78-
sides: &[SearchSortedSide],
76+
side: SearchSortedSide,
7977
) -> VortexResult<Vec<SearchResult>> {
8078
match_each_unsigned_integer_ptype!(array.ptype(), |$P| {
8179
let searcher = BitPackedSearch::<'_, $P>::new(array);
8280

8381
values
8482
.iter()
85-
.copied()
86-
.zip(sides.iter().copied())
87-
.map(|(value, side)| {
83+
.map(|&value| {
8884
// NOTE: truncating cast
8985
let cast_value: $P = value as $P;
9086
Ok(searcher.search_sorted(&cast_value, side))
@@ -299,11 +295,7 @@ mod test {
299295
let results = search_sorted_many(
300296
bitpacked.as_ref(),
301297
&[3u64, 2u64, 1u64],
302-
&[
303-
SearchSortedSide::Left,
304-
SearchSortedSide::Left,
305-
SearchSortedSide::Left,
306-
],
298+
SearchSortedSide::Left,
307299
)
308300
.unwrap();
309301

encodings/runend/src/array.rs

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -123,12 +123,7 @@ impl RunEndArray {
123123
///
124124
/// See: [find_physical_index][Self::find_physical_index].
125125
pub fn find_physical_indices(&self, indices: &[usize]) -> VortexResult<Vec<usize>> {
126-
search_sorted_usize_many(
127-
&self.ends(),
128-
indices,
129-
&vec![SearchSortedSide::Right; indices.len()],
130-
)
131-
.map(|results| {
126+
search_sorted_usize_many(&self.ends(), indices, SearchSortedSide::Right).map(|results| {
132127
results
133128
.iter()
134129
.map(|result| result.to_ends_index(self.ends().len()))

vortex-array/src/compute/search_sorted.rs

Lines changed: 14 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -121,26 +121,23 @@ pub trait SearchSortedFn<Array> {
121121
&self,
122122
array: &Array,
123123
values: &[Scalar],
124-
sides: &[SearchSortedSide],
124+
side: SearchSortedSide,
125125
) -> VortexResult<Vec<SearchResult>> {
126126
values
127127
.iter()
128-
.zip(sides.iter())
129-
.map(|(value, side)| self.search_sorted(array, value, *side))
128+
.map(|value| self.search_sorted(array, value, side))
130129
.try_collect()
131130
}
132131

133132
fn search_sorted_usize_many(
134133
&self,
135134
array: &Array,
136135
values: &[usize],
137-
sides: &[SearchSortedSide],
136+
side: SearchSortedSide,
138137
) -> VortexResult<Vec<SearchResult>> {
139138
values
140139
.iter()
141-
.copied()
142-
.zip(sides.iter().copied())
143-
.map(|(value, side)| self.search_sorted_usize(array, value, side))
140+
.map(|&value| self.search_sorted_usize(array, value, side))
144141
.try_collect()
145142
}
146143
}
@@ -184,30 +181,30 @@ where
184181
&self,
185182
array: &ArrayData,
186183
values: &[Scalar],
187-
sides: &[SearchSortedSide],
184+
side: SearchSortedSide,
188185
) -> VortexResult<Vec<SearchResult>> {
189186
let array_ref = <&E::Array>::try_from(array)?;
190187
let encoding = array
191188
.encoding()
192189
.as_any()
193190
.downcast_ref::<E>()
194191
.ok_or_else(|| vortex_err!("Mismatched encoding"))?;
195-
SearchSortedFn::search_sorted_many(encoding, array_ref, values, sides)
192+
SearchSortedFn::search_sorted_many(encoding, array_ref, values, side)
196193
}
197194

198195
fn search_sorted_usize_many(
199196
&self,
200197
array: &ArrayData,
201198
values: &[usize],
202-
sides: &[SearchSortedSide],
199+
side: SearchSortedSide,
203200
) -> VortexResult<Vec<SearchResult>> {
204201
let array_ref = <&E::Array>::try_from(array)?;
205202
let encoding = array
206203
.encoding()
207204
.as_any()
208205
.downcast_ref::<E>()
209206
.ok_or_else(|| vortex_err!("Mismatched encoding"))?;
210-
SearchSortedFn::search_sorted_usize_many(encoding, array_ref, values, sides)
207+
SearchSortedFn::search_sorted_usize_many(encoding, array_ref, values, side)
211208
}
212209
}
213210

@@ -261,41 +258,38 @@ pub fn search_sorted_usize(
261258
pub fn search_sorted_many<T: Into<Scalar> + Clone>(
262259
array: &ArrayData,
263260
targets: &[T],
264-
sides: &[SearchSortedSide],
261+
side: SearchSortedSide,
265262
) -> VortexResult<Vec<SearchResult>> {
266263
if let Some(f) = array.encoding().search_sorted_fn() {
267264
let values: Vec<Scalar> = targets
268265
.iter()
269266
.map(|t| t.clone().into().cast(array.dtype()))
270267
.try_collect()?;
271268

272-
return f.search_sorted_many(array, &values, sides);
269+
return f.search_sorted_many(array, &values, side);
273270
}
274271

275272
// Call in loop and collect
276273
targets
277274
.iter()
278-
.zip(sides.iter().copied())
279-
.map(|(target, side)| search_sorted(array, target.clone(), side))
275+
.map(|target| search_sorted(array, target.clone(), side))
280276
.try_collect()
281277
}
282278

283279
// Native functions for each of the values, cast up to u64 or down to something lower.
284280
pub fn search_sorted_usize_many(
285281
array: &ArrayData,
286282
targets: &[usize],
287-
sides: &[SearchSortedSide],
283+
side: SearchSortedSide,
288284
) -> VortexResult<Vec<SearchResult>> {
289285
if let Some(f) = array.encoding().search_sorted_fn() {
290-
return f.search_sorted_usize_many(array, targets, sides);
286+
return f.search_sorted_usize_many(array, targets, side);
291287
}
292288

293289
// Call in loop and collect
294290
targets
295291
.iter()
296-
.copied()
297-
.zip(sides.iter().copied())
298-
.map(|(target, side)| search_sorted_usize(array, target, side))
292+
.map(|&target| search_sorted_usize(array, target, side))
299293
.try_collect()
300294
}
301295

0 commit comments

Comments
 (0)