Skip to content

Commit 4e42c68

Browse files
committed
perf[dict]: all_values_referenced to allow pushdown optimisation.
Signed-off-by: Joe Isaacs <[email protected]>
1 parent 962d840 commit 4e42c68

File tree

15 files changed

+114
-70
lines changed

15 files changed

+114
-70
lines changed

vortex-array/src/arrays/dict/array.rs

Lines changed: 75 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,10 @@ impl VTable for DictVTable {
109109
let values = children.get(1, dtype, metadata.values_len as usize)?;
110110
let all_values_referenced = metadata.all_values_referenced.unwrap_or(false);
111111

112-
DictArray::try_new_with_metadata(codes, values, all_values_referenced)
112+
// SAFETY: We've validated the metadata and children
113+
Ok(unsafe {
114+
DictArray::new_unchecked(codes, values).set_all_values_referenced(all_values_referenced)
115+
})
113116
}
114117
}
115118

@@ -135,11 +138,7 @@ impl DictArray {
135138
/// This should be called only when you can guarantee the invariants checked
136139
/// by the safe [`DictArray::try_new`] constructor are valid, for example when
137140
/// you are filtering or slicing an existing valid `DictArray`.
138-
pub unsafe fn new_unchecked(
139-
codes: ArrayRef,
140-
values: ArrayRef,
141-
all_values_referenced: bool,
142-
) -> Self {
141+
pub unsafe fn new_unchecked(codes: ArrayRef, values: ArrayRef) -> Self {
143142
let dtype = values
144143
.dtype()
145144
.union_nullability(codes.dtype().nullability());
@@ -148,8 +147,35 @@ impl DictArray {
148147
values,
149148
stats_set: Default::default(),
150149
dtype,
151-
all_values_referenced,
150+
all_values_referenced: false,
151+
}
152+
}
153+
154+
/// Set whether all dictionary values are definitely referenced.
155+
///
156+
/// # Safety
157+
/// The caller must ensure that when setting `all_values_referenced = true`, ALL dictionary
158+
/// values are actually referenced by at least one valid code. Setting this incorrectly can
159+
/// lead to incorrect query results in operations like min/max.
160+
///
161+
/// This is typically only set to `true` during dictionary encoding when we know for certain
162+
/// that all values are referenced.
163+
pub unsafe fn set_all_values_referenced(mut self, all_values_referenced: bool) -> Self {
164+
// In debug builds, verify the claim when setting to true
165+
#[cfg(debug_assertions)]
166+
if all_values_referenced {
167+
if let Ok(unreferenced_mask) = self.compute_unreferenced_values_mask(false) {
168+
let has_unreferenced = unreferenced_mask.iter().any(|b| b);
169+
debug_assert!(
170+
!has_unreferenced,
171+
"set_all_values_referenced(true) called but {} unreferenced values found",
172+
unreferenced_mask.iter().filter(|&b| b).count()
173+
);
174+
}
152175
}
176+
177+
self.all_values_referenced = all_values_referenced;
178+
self
153179
}
154180

155181
/// Build a new `DictArray` from its components, `codes` and `values`.
@@ -189,7 +215,9 @@ impl DictArray {
189215
vortex_bail!(MismatchedTypes: "unsigned int", codes.dtype());
190216
}
191217

192-
Ok(unsafe { Self::new_unchecked(codes, values, all_values_referenced) })
218+
Ok(unsafe {
219+
Self::new_unchecked(codes, values).set_all_values_referenced(all_values_referenced)
220+
})
193221
}
194222

195223
#[inline]
@@ -212,24 +240,56 @@ impl DictArray {
212240
self.all_values_referenced
213241
}
214242

243+
/// Validates that the `all_values_referenced` flag matches reality.
244+
///
245+
/// Returns `Ok(())` if the flag is consistent with the actual referenced values,
246+
/// or an error describing the mismatch.
247+
///
248+
/// This is primarily useful for testing and debugging.
249+
#[cfg(debug_assertions)]
250+
pub fn validate_all_values_referenced(&self) -> VortexResult<()> {
251+
let unreferenced_mask = self.compute_unreferenced_values_mask(false)?;
252+
let has_unreferenced = unreferenced_mask.iter().any(|b| b);
253+
let actual_all_referenced = !has_unreferenced;
254+
255+
if self.all_values_referenced && !actual_all_referenced {
256+
let unreferenced_count = unreferenced_mask.iter().filter(|&b| b).count();
257+
vortex_bail!(
258+
"all_values_referenced=true but {} unreferenced values found",
259+
unreferenced_count
260+
);
261+
}
262+
263+
Ok(())
264+
}
265+
215266
/// Compute a mask indicating which values in the dictionary are referenced by at least one code.
216267
///
217-
/// Returns a `BitBuffer` where unset bits (false) correspond to values that are referenced
218-
/// by at least one valid code, and set bits (true) correspond to unreferenced values.
268+
/// When `referenced = true`, returns a `BitBuffer` where set bits (true) correspond to
269+
/// referenced values, and unset bits (false) correspond to unreferenced values.
270+
///
271+
/// When `referenced = false` (default for unreferenced values), returns the inverse:
272+
/// set bits (true) correspond to unreferenced values, and unset bits (false) correspond
273+
/// to referenced values.
219274
///
220275
/// This is useful for operations like min/max that need to ignore unreferenced values.
221-
pub fn compute_unreferenced_values_mask(&self) -> VortexResult<BitBuffer> {
276+
pub fn compute_unreferenced_values_mask(&self, referenced: bool) -> VortexResult<BitBuffer> {
222277
let codes_validity = self.codes().validity_mask();
223278
let codes_primitive = self.codes().to_primitive();
224279
let values_len = self.values().len();
225280

226-
let mut unreferenced_vec = vec![true; values_len];
281+
// Initialize with the starting value: false for referenced, true for unreferenced
282+
let init_value = !referenced;
283+
// Value to set when we find a referenced code: true for referenced, false for unreferenced
284+
let referenced_value = referenced;
285+
286+
let mut values_vec = vec![init_value; values_len];
227287
match codes_validity.bit_buffer() {
228288
AllOr::All => {
229289
match_each_integer_ptype!(codes_primitive.ptype(), |P| {
230290
#[allow(clippy::cast_possible_truncation)]
231291
for &code in codes_primitive.as_slice::<P>().iter() {
232-
unreferenced_vec[code as usize] = false;
292+
values_vec[code as usize] = referenced_value;
233293
}
234294
});
235295
}
@@ -240,15 +300,13 @@ impl DictArray {
240300

241301
#[allow(clippy::cast_possible_truncation)]
242302
buf.set_indices().for_each(|idx| {
243-
unreferenced_vec[codes[idx] as usize] = false;
303+
values_vec[codes[idx] as usize] = referenced_value;
244304
})
245305
});
246306
}
247307
}
248308

249-
Ok(BitBuffer::collect_bool(values_len, |idx| {
250-
unreferenced_vec[idx]
251-
}))
309+
Ok(BitBuffer::collect_bool(values_len, |idx| values_vec[idx]))
252310
}
253311
}
254312

vortex-array/src/arrays/dict/arrow.rs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@ impl<K: ArrowDictionaryKeyType> FromArrowArray<&DictionaryArray<K>> for DictArra
1414
let keys = ArrayRef::from_arrow(keys, keys.is_nullable());
1515
let values = ArrayRef::from_arrow(array.values().as_ref(), nullable);
1616
// SAFETY: we assume that Arrow has checked the invariants on construction
17-
// We conservatively set all_values_referenced to false since Arrow doesn't provide this info
18-
unsafe { DictArray::new_unchecked(keys, values, false) }
17+
unsafe { DictArray::new_unchecked(keys, values) }
1918
}
2019
}

vortex-array/src/arrays/dict/compute/cast.rs

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,13 +24,11 @@ impl CastKernel for DictVTable {
2424
};
2525

2626
// SAFETY: casting does not alter invariants of the codes
27+
// Preserve all_values_referenced since casting only changes values, not which are referenced
2728
Ok(Some(
2829
unsafe {
29-
DictArray::new_unchecked(
30-
casted_codes,
31-
casted_values,
32-
array.has_all_values_referenced(),
33-
)
30+
DictArray::new_unchecked(casted_codes, casted_values)
31+
.set_all_values_referenced(array.has_all_values_referenced())
3432
}
3533
.into_array(),
3634
))

vortex-array/src/arrays/dict/compute/compare.rs

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -30,12 +30,9 @@ impl CompareKernel for DictVTable {
3030

3131
// SAFETY: values len preserved, codes all still point to valid values
3232
let result = unsafe {
33-
DictArray::new_unchecked(
34-
lhs.codes().clone(),
35-
compare_result,
36-
lhs.has_all_values_referenced(),
37-
)
38-
.into_array()
33+
DictArray::new_unchecked(lhs.codes().clone(), compare_result)
34+
.set_all_values_referenced(lhs.has_all_values_referenced())
35+
.into_array()
3936
};
4037

4138
// We canonicalize the result because dictionary-encoded bools is dumb.

vortex-array/src/arrays/dict/compute/fill_null.rs

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -43,10 +43,11 @@ impl FillNullKernel for DictVTable {
4343

4444
// SAFETY: invariants are still satisfied after patching nulls
4545
unsafe {
46-
Ok(
47-
DictArray::new_unchecked(codes, values, array.has_all_values_referenced())
48-
.into_array(),
49-
)
46+
Ok(DictArray::new_unchecked(codes, values)
47+
// Preserve all_values_referenced since filling nulls cannot make values
48+
// unreferenced.
49+
.set_all_values_referenced(array.has_all_values_referenced())
50+
.into_array())
5051
}
5152
}
5253
}

vortex-array/src/arrays/dict/compute/like.rs

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -28,12 +28,9 @@ impl LikeKernel for DictVTable {
2828
// Preserve all_values_referenced since codes are unchanged
2929
unsafe {
3030
Ok(Some(
31-
DictArray::new_unchecked(
32-
array.codes().clone(),
33-
values,
34-
array.has_all_values_referenced(),
35-
)
36-
.into_array(),
31+
DictArray::new_unchecked(array.codes().clone(), values)
32+
.set_all_values_referenced(array.has_all_values_referenced())
33+
.into_array(),
3734
))
3835
}
3936
} else {

vortex-array/src/arrays/dict/compute/min_max.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ impl MinMaxKernel for DictVTable {
2121
}
2222

2323
// Slow path: compute which values are unreferenced and mask them out
24-
let unreferenced_mask = Mask::from_buffer(array.compute_unreferenced_values_mask()?);
24+
let unreferenced_mask = Mask::from_buffer(array.compute_unreferenced_values_mask(false)?);
2525
min_max(&mask(array.values(), &unreferenced_mask)?)
2626
}
2727
}

vortex-array/src/arrays/dict/compute/mod.rs

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,7 @@ impl TakeKernel for DictVTable {
2222
fn take(&self, array: &DictArray, indices: &dyn Array) -> VortexResult<ArrayRef> {
2323
let codes = take(array.codes(), indices)?;
2424
// SAFETY: selecting codes doesn't change the invariants of DictArray
25-
// We conservatively set all_values_referenced to false since taking may leave values unreferenced
26-
Ok(unsafe { DictArray::new_unchecked(codes, array.values().clone(), false) }.into_array())
25+
Ok(unsafe { DictArray::new_unchecked(codes, array.values().clone()) }.into_array())
2726
}
2827
}
2928

@@ -34,8 +33,7 @@ impl FilterKernel for DictVTable {
3433
let codes = filter(array.codes(), mask)?;
3534

3635
// SAFETY: filtering codes doesn't change invariants
37-
// We conservatively set all_values_referenced to false since filtering may leave values unreferenced
38-
unsafe { Ok(DictArray::new_unchecked(codes, array.values().clone(), false).into_array()) }
36+
unsafe { Ok(DictArray::new_unchecked(codes, array.values().clone()).into_array()) }
3937
}
4038
}
4139

vortex-array/src/arrays/dict/ops.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ impl OperationsVTable<DictVTable> for DictVTable {
2424
};
2525
}
2626
// SAFETY: slicing the codes preserves invariants
27-
unsafe { DictArray::new_unchecked(sliced_code, array.values().clone(), false).into_array() }
27+
unsafe { DictArray::new_unchecked(sliced_code, array.values().clone()).into_array() }
2828
}
2929

3030
fn scalar_at(array: &DictArray, index: usize) -> Scalar {

vortex-array/src/builders/dict/mod.rs

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -53,14 +53,13 @@ pub fn dict_encode_with_constraints(
5353
let mut encoder = dict_encoder(array, constraints);
5454
let codes = encoder.encode(array).to_primitive().narrow()?;
5555
// SAFETY: The encoding process will produce a value set of codes and values
56+
// All values in the dictionary are guaranteed to be referenced by at least one code
57+
// since we build the dictionary from the codes we observe during encoding
5658
unsafe {
57-
Ok(DictArray::new_unchecked(
58-
codes.into_array(),
59-
encoder.reset(),
60-
// All values in the dictionary are guaranteed to be referenced by at least one code
61-
// since we build the dictionary from the codes we observe during encoding
62-
true,
63-
))
59+
Ok(
60+
DictArray::new_unchecked(codes.into_array(), encoder.reset())
61+
.set_all_values_referenced(true),
62+
)
6463
}
6564
}
6665

0 commit comments

Comments
 (0)