Skip to content

Commit e8cd434

Browse files
authored
Fix regression in search_sorted when Patches replaced SparseArray (#1624)
1 parent 4edfc74 commit e8cd434

File tree

4 files changed

+113
-20
lines changed

4 files changed

+113
-20
lines changed

encodings/fastlanes/src/bitpacking/compute/search_sorted.rs

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -273,6 +273,23 @@ mod test {
273273
assert_eq!(found, SearchResult::Found(0));
274274
}
275275

276+
#[test]
277+
fn test_search_sorted_nulls_not_found() {
278+
let bitpacked = BitPackedArray::encode(
279+
PrimitiveArray::from_nullable_vec(vec![Some(0u8), Some(107u8), None, None]).as_ref(),
280+
0,
281+
)
282+
.unwrap();
283+
284+
let found = search_sorted(
285+
bitpacked.as_ref(),
286+
Scalar::primitive(127u8, Nullability::Nullable),
287+
SearchSortedSide::Left,
288+
)
289+
.unwrap();
290+
assert_eq!(found, SearchResult::NotFound(2));
291+
}
292+
276293
#[test]
277294
fn test_search_sorted_many() {
278295
// Test search_sorted_many with an array that contains several null values.

vortex-array/src/patches.rs

Lines changed: 79 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -137,17 +137,21 @@ impl Patches {
137137
target: T,
138138
side: SearchSortedSide,
139139
) -> VortexResult<SearchResult> {
140-
Ok(match search_sorted(self.values(), target.into(), side)? {
141-
SearchResult::Found(idx) => SearchResult::Found(if idx == self.indices().len() {
142-
self.array_len()
143-
} else {
144-
usize::try_from(&scalar_at(self.indices(), idx)?)?
145-
}),
146-
SearchResult::NotFound(idx) => SearchResult::NotFound(if idx == self.indices().len() {
147-
self.array_len()
148-
} else {
149-
usize::try_from(&scalar_at(self.indices(), idx)?)?
150-
}),
140+
search_sorted(self.values(), target.into(), side).and_then(|sr| {
141+
let sidx = sr.to_offsets_index(self.indices().len());
142+
let index = usize::try_from(&scalar_at(self.indices(), sidx)?)?;
143+
Ok(match sr {
144+
// If we reached the end of patched values when searching then the result is one after the last patch index
145+
SearchResult::Found(i) => SearchResult::Found(if i == self.indices().len() {
146+
index + 1
147+
} else {
148+
index
149+
}),
150+
// If the result is NotFound we should return index that's one after the nearest not found index for the corresponding value
151+
SearchResult::NotFound(i) => {
152+
SearchResult::NotFound(if i == 0 { index } else { index + 1 })
153+
}
154+
})
151155
})
152156
}
153157

@@ -260,9 +264,12 @@ impl Patches {
260264

261265
#[cfg(test)]
262266
mod test {
267+
use rstest::{fixture, rstest};
268+
263269
use crate::array::PrimitiveArray;
264-
use crate::compute::FilterMask;
270+
use crate::compute::{FilterMask, SearchResult, SearchSortedSide};
265271
use crate::patches::Patches;
272+
use crate::validity::Validity;
266273
use crate::{IntoArrayData, IntoArrayVariant};
267274

268275
#[test]
@@ -283,4 +290,64 @@ mod test {
283290
assert_eq!(indices.maybe_null_slice::<u64>(), &[0, 1]);
284291
assert_eq!(values.maybe_null_slice::<i32>(), &[100, 200]);
285292
}
293+
294+
#[fixture]
295+
fn patches() -> Patches {
296+
Patches::new(
297+
20,
298+
PrimitiveArray::from(vec![2u64, 9, 15]).into_array(),
299+
PrimitiveArray::from_vec(vec![33_i32, 44, 55], Validity::AllValid).into_array(),
300+
)
301+
}
302+
303+
#[rstest]
304+
fn search_larger_than(patches: Patches) {
305+
let res = patches.search_sorted(66, SearchSortedSide::Left).unwrap();
306+
assert_eq!(res, SearchResult::NotFound(16));
307+
}
308+
309+
#[rstest]
310+
fn search_less_than(patches: Patches) {
311+
let res = patches.search_sorted(22, SearchSortedSide::Left).unwrap();
312+
assert_eq!(res, SearchResult::NotFound(2));
313+
}
314+
315+
#[rstest]
316+
fn search_found(patches: Patches) {
317+
let res = patches.search_sorted(44, SearchSortedSide::Left).unwrap();
318+
assert_eq!(res, SearchResult::Found(9));
319+
}
320+
321+
#[rstest]
322+
fn search_not_found_right(patches: Patches) {
323+
let res = patches.search_sorted(56, SearchSortedSide::Right).unwrap();
324+
assert_eq!(res, SearchResult::NotFound(16));
325+
}
326+
327+
#[rstest]
328+
fn search_sliced(patches: Patches) {
329+
let sliced = patches.slice(7, 20).unwrap().unwrap();
330+
assert_eq!(
331+
sliced.search_sorted(22, SearchSortedSide::Left).unwrap(),
332+
SearchResult::NotFound(2)
333+
);
334+
}
335+
336+
#[test]
337+
fn search_right() {
338+
let patches = Patches::new(
339+
2,
340+
PrimitiveArray::from(vec![0u64]).into_array(),
341+
PrimitiveArray::from_vec(vec![0u8], Validity::AllValid).into_array(),
342+
);
343+
344+
assert_eq!(
345+
patches.search_sorted(0, SearchSortedSide::Right).unwrap(),
346+
SearchResult::Found(1)
347+
);
348+
assert_eq!(
349+
patches.search_sorted(1, SearchSortedSide::Right).unwrap(),
350+
SearchResult::NotFound(1)
351+
);
352+
}
286353
}

vortex-sampling-compressor/src/compressors/alp.rs

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,14 @@
1-
use vortex_alp::{alp_encode_components, match_each_alp_float_ptype, ALPArray, ALPEncoding};
1+
use vortex_alp::{
2+
alp_encode_components, match_each_alp_float_ptype, ALPArray, ALPEncoding, ALPRDEncoding,
3+
};
24
use vortex_array::aliases::hash_set::HashSet;
35
use vortex_array::array::PrimitiveArray;
46
use vortex_array::encoding::{Encoding, EncodingRef};
57
use vortex_array::variants::PrimitiveArrayTrait;
68
use vortex_array::{ArrayData, IntoArrayData, IntoArrayVariant};
79
use vortex_dtype::PType;
810
use vortex_error::VortexResult;
11+
use vortex_fastlanes::BitPackedEncoding;
912

1013
use super::alp_rd::ALPRDCompressor;
1114
use crate::compressors::{CompressedArray, CompressionTree, EncodingCompressor};
@@ -41,7 +44,6 @@ impl EncodingCompressor for ALPCompressor {
4144
like: Option<CompressionTree<'a>>,
4245
ctx: SamplingCompressor<'a>,
4346
) -> VortexResult<CompressedArray<'a>> {
44-
// TODO(robert): Fill forward nulls?
4547
let parray = array.clone().into_primitive()?;
4648

4749
let (exponents, encoded, patches) = match_each_alp_float_ptype!(
@@ -72,6 +74,10 @@ impl EncodingCompressor for ALPCompressor {
7274
}
7375

7476
fn used_encodings(&self) -> HashSet<EncodingRef> {
75-
HashSet::from([&ALPEncoding as EncodingRef])
77+
HashSet::from([
78+
&ALPEncoding as EncodingRef,
79+
&ALPRDEncoding,
80+
&BitPackedEncoding,
81+
])
7682
}
7783
}

vortex-sampling-compressor/src/lib.rs

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,8 @@ use compressors::varbin::VarBinCompressor;
1414
use compressors::{CompressedArray, CompressorRef};
1515
use vortex_alp::{ALPEncoding, ALPRDEncoding};
1616
use vortex_array::array::{
17-
PrimitiveEncoding, SparseEncoding, StructEncoding, VarBinEncoding, VarBinViewEncoding,
17+
ListEncoding, PrimitiveEncoding, SparseEncoding, StructEncoding, VarBinEncoding,
18+
VarBinViewEncoding,
1819
};
1920
use vortex_array::encoding::EncodingRef;
2021
use vortex_array::Context;
@@ -32,6 +33,7 @@ use vortex_zigzag::ZigZagEncoding;
3233
use crate::compressors::alp::ALPCompressor;
3334
use crate::compressors::date_time_parts::DateTimePartsCompressor;
3435
use crate::compressors::dict::DictCompressor;
36+
use crate::compressors::list::ListCompressor;
3537
use crate::compressors::r#for::FoRCompressor;
3638
use crate::compressors::runend::DEFAULT_RUN_END_COMPRESSOR;
3739
use crate::compressors::runend_bool::RunEndBoolCompressor;
@@ -48,8 +50,6 @@ mod sampling_compressor;
4850

4951
pub use sampling_compressor::*;
5052

51-
use crate::compressors::list::ListCompressor;
52-
5353
pub const DEFAULT_COMPRESSORS: [CompressorRef; 15] = [
5454
&ALPCompressor as CompressorRef,
5555
&BITPACK_WITH_PATCHES,
@@ -72,7 +72,7 @@ pub const DEFAULT_COMPRESSORS: [CompressorRef; 15] = [
7272
];
7373

7474
#[cfg(not(target_arch = "wasm32"))]
75-
pub const ALL_COMPRESSORS: [CompressorRef; 17] = [
75+
pub const ALL_COMPRESSORS: [CompressorRef; 18] = [
7676
&ALPCompressor as CompressorRef,
7777
&BITPACK_WITH_PATCHES,
7878
&DEFAULT_CHUNKED_COMPRESSOR,
@@ -88,12 +88,13 @@ pub const ALL_COMPRESSORS: [CompressorRef; 17] = [
8888
&DEFAULT_RUN_END_COMPRESSOR,
8989
&SparseCompressor,
9090
&StructCompressor,
91+
&ListCompressor,
9192
&VarBinCompressor,
9293
&ZigZagCompressor,
9394
];
9495

9596
#[cfg(target_arch = "wasm32")]
96-
pub const ALL_COMPRESSORS: [CompressorRef; 15] = [
97+
pub const ALL_COMPRESSORS: [CompressorRef; 16] = [
9798
&ALPCompressor as CompressorRef,
9899
&BITPACK_WITH_PATCHES,
99100
&DEFAULT_CHUNKED_COMPRESSOR,
@@ -110,6 +111,7 @@ pub const ALL_COMPRESSORS: [CompressorRef; 15] = [
110111
&DEFAULT_RUN_END_COMPRESSOR,
111112
&SparseCompressor,
112113
&StructCompressor,
114+
&ListCompressor,
113115
&VarBinCompressor,
114116
&ZigZagCompressor,
115117
];
@@ -135,6 +137,7 @@ pub static ALL_ENCODINGS_CONTEXT: LazyLock<Arc<Context>> = LazyLock::new(|| {
135137
&RunEndBoolEncoding,
136138
&SparseEncoding,
137139
&StructEncoding,
140+
&ListEncoding,
138141
&VarBinEncoding,
139142
&VarBinViewEncoding,
140143
&ZigZagEncoding,

0 commit comments

Comments
 (0)