Skip to content

Commit 8b7fce4

Browse files
committed
fix[vortex-dict]: avoid full materialization in min_max
Prior to this commit, min_max on dictionaries would materialize the full dictionary. This commit iterates over the codes and checks whether they fully cover the values slice, in which case min_max can be run on the values slice. Even if the codes do not fully cover the values, we can materialize without duplicates, by filtering the values slice to only the covered indices. Signed-off-by: Alfonso Subiotto Marques <[email protected]>
1 parent fcf2d98 commit 8b7fce4

File tree

4 files changed

+129
-9
lines changed

4 files changed

+129
-9
lines changed

Cargo.lock

Lines changed: 0 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

encodings/dict/Cargo.toml

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,6 @@ arrow = ["dep:arrow-array"]
2020

2121
[dependencies]
2222
arrow-array = { workspace = true, optional = true }
23-
arrow-buffer = { workspace = true }
24-
num-traits = { workspace = true }
2523
prost = { workspace = true }
2624
# test-harness
2725
rand = { workspace = true, optional = true }
Lines changed: 127 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,140 @@
11
// SPDX-License-Identifier: Apache-2.0
22
// SPDX-FileCopyrightText: Copyright the Vortex contributors
33

4-
use vortex_array::compute::{MinMaxKernel, MinMaxKernelAdapter, MinMaxResult, min_max, take};
5-
use vortex_array::register_kernel;
4+
use vortex_array::compute::{MinMaxKernel, MinMaxKernelAdapter, MinMaxResult, mask, min_max};
5+
use vortex_array::{Array as _, ToCanonical, register_kernel};
6+
use vortex_buffer::BitBufferMut;
7+
use vortex_dtype::match_each_unsigned_integer_ptype;
68
use vortex_error::VortexResult;
9+
use vortex_mask::Mask;
710

811
use crate::{DictArray, DictVTable};
912

1013
impl MinMaxKernel for DictVTable {
1114
fn min_max(&self, array: &DictArray) -> VortexResult<Option<MinMaxResult>> {
12-
min_max(&take(array.values(), array.codes())?)
15+
let codes_validity = array.codes().validity_mask();
16+
if codes_validity.all_false() {
17+
return Ok(None);
18+
}
19+
20+
let codes_primitive = array.codes().to_primitive();
21+
let values_len = array.values().len();
22+
match_each_unsigned_integer_ptype!(codes_primitive.ptype(), |P| {
23+
codes_validity.iter_bools(|validity_iter| {
24+
// mask() sets values to null where the mask is true, so we start
25+
// with a fully-set bit buffer.
26+
let mut unreferenced = BitBufferMut::new_set(values_len);
27+
for (&code, is_valid) in codes_primitive.as_slice::<P>().iter().zip(validity_iter) {
28+
if is_valid {
29+
// SAFETY: code is valid, so it must be in range.
30+
#[allow(clippy::cast_possible_truncation)]
31+
unsafe {
32+
unreferenced.unset_unchecked(code as usize);
33+
}
34+
}
35+
}
36+
37+
let unreferenced_mask = Mask::from_buffer(unreferenced.freeze());
38+
min_max(&mask(array.values(), &unreferenced_mask)?)
39+
})
40+
})
1341
}
1442
}
1543

1644
register_kernel!(MinMaxKernelAdapter(DictVTable).lift());
45+
46+
#[cfg(test)]
47+
mod tests {
48+
use rstest::rstest;
49+
use vortex_array::arrays::PrimitiveArray;
50+
use vortex_array::compute::min_max;
51+
use vortex_array::{Array, IntoArray};
52+
use vortex_buffer::buffer;
53+
54+
use crate::DictArray;
55+
use crate::builders::dict_encode;
56+
57+
fn assert_min_max(array: &dyn Array, expected: Option<(i32, i32)>) {
58+
match (min_max(array).unwrap(), expected) {
59+
(Some(result), Some((expected_min, expected_max))) => {
60+
assert_eq!(i32::try_from(result.min).unwrap(), expected_min);
61+
assert_eq!(i32::try_from(result.max).unwrap(), expected_max);
62+
}
63+
(None, None) => {}
64+
(got, expected) => panic!(
65+
"min_max mismatch: expected {:?}, got {:?}",
66+
expected,
67+
got.as_ref().map(|r| (
68+
i32::try_from(r.min.clone()).ok(),
69+
i32::try_from(r.max.clone()).ok()
70+
))
71+
),
72+
}
73+
}
74+
75+
#[rstest]
76+
#[case::covering(
77+
DictArray::try_new(
78+
buffer![0u32, 1, 2, 3, 0, 1].into_array(),
79+
buffer![10i32, 20, 30, 40].into_array(),
80+
).unwrap(),
81+
(10, 40)
82+
)]
83+
#[case::non_covering_duplicates(
84+
DictArray::try_new(
85+
buffer![1u32, 1, 1, 3, 3].into_array(),
86+
buffer![1i32, 2, 3, 4, 5].into_array(),
87+
).unwrap(),
88+
(2, 4)
89+
)]
90+
// Non-covering: codes with gaps
91+
#[case::non_covering_gaps(
92+
DictArray::try_new(
93+
buffer![0u32, 2, 4].into_array(),
94+
buffer![1i32, 2, 3, 4, 5].into_array(),
95+
).unwrap(),
96+
(1, 5)
97+
)]
98+
#[case::single(dict_encode(&buffer![42i32].into_array()).unwrap(), (42, 42))]
99+
#[case::nullable_codes(
100+
DictArray::try_new(
101+
PrimitiveArray::from_option_iter([Some(0u32), None, Some(1), Some(2)]).into_array(),
102+
buffer![10i32, 20, 30].into_array(),
103+
).unwrap(),
104+
(10, 30)
105+
)]
106+
#[case::nullable_values(
107+
dict_encode(
108+
PrimitiveArray::from_option_iter([Some(1i32), None, Some(2), Some(1), None]).as_ref()
109+
).unwrap(),
110+
(1, 2)
111+
)]
112+
fn test_min_max(#[case] dict: DictArray, #[case] expected: (i32, i32)) {
113+
assert_min_max(dict.as_ref(), Some(expected));
114+
}
115+
116+
#[test]
117+
fn test_sliced_dict() {
118+
let reference = PrimitiveArray::from_iter([1, 5, 10, 50, 100]);
119+
let dict = dict_encode(reference.as_ref()).unwrap();
120+
let sliced = dict.slice(1..3);
121+
assert_min_max(&sliced, Some((5, 10)));
122+
}
123+
124+
#[rstest]
125+
#[case::empty(
126+
DictArray::try_new(
127+
PrimitiveArray::from_iter(Vec::<u32>::new()).into_array(),
128+
buffer![10i32, 20, 30].into_array(),
129+
).unwrap()
130+
)]
131+
#[case::all_null_codes(
132+
DictArray::try_new(
133+
PrimitiveArray::from_option_iter([Option::<u32>::None, None, None]).into_array(),
134+
buffer![10i32, 20, 30].into_array(),
135+
).unwrap()
136+
)]
137+
fn test_min_max_none(#[case] dict: DictArray) {
138+
assert_min_max(dict.as_ref(), None);
139+
}
140+
}

vortex-buffer/src/bit/buf_mut.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -234,7 +234,7 @@ impl BitBufferMut {
234234
/// # Safety
235235
///
236236
/// The caller must ensure that `index` does not exceed the largest bit index in the backing buffer.
237-
unsafe fn set_unchecked(&mut self, index: usize) {
237+
pub unsafe fn set_unchecked(&mut self, index: usize) {
238238
// SAFETY: checked by caller
239239
unsafe { set_bit_unchecked(self.buffer.as_mut_ptr(), self.offset + index) }
240240
}
@@ -246,7 +246,7 @@ impl BitBufferMut {
246246
/// # Safety
247247
///
248248
/// The caller must ensure that `index` does not exceed the largest bit index in the backing buffer.
249-
unsafe fn unset_unchecked(&mut self, index: usize) {
249+
pub unsafe fn unset_unchecked(&mut self, index: usize) {
250250
// SAFETY: checked by caller
251251
unsafe { unset_bit_unchecked(self.buffer.as_mut_ptr(), self.offset + index) }
252252
}

0 commit comments

Comments
 (0)