Skip to content

Commit d72618d

Browse files
chore[vortex-dict]: numeric pushdown & clean up the dict ctor (#5438)
Signed-off-by: Joe Isaacs <[email protected]> Co-authored-by: Robert Kruszewski <[email protected]>
1 parent fd5a013 commit d72618d

File tree

4 files changed

+181
-21
lines changed

4 files changed

+181
-21
lines changed

vortex-array/src/arrays/dict/array.rs

Lines changed: 3 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,8 @@ pub struct DictArray {
125125
/// Indicates whether all dictionary values are definitely referenced by at least one code.
126126
/// `true` = all values are referenced (computed during encoding).
127127
/// `false` = unknown/might have unreferenced values.
128+
/// In case this is incorrect never use this to enable memory unsafe behaviour just semantically
129+
/// incorrect behaviour.
128130
all_values_referenced: bool,
129131
}
130132

@@ -192,26 +194,11 @@ impl DictArray {
192194
///
193195
/// It is an error to provide a nullable `codes` with non-nullable `values`.
194196
pub fn try_new(codes: ArrayRef, values: ArrayRef) -> VortexResult<Self> {
195-
Self::try_new_with_metadata(codes, values, false)
196-
}
197-
198-
/// Build a new `DictArray` from its components with explicit metadata.
199-
///
200-
/// Same as [`DictArray::try_new`] but allows specifying whether all values are referenced.
201-
/// This is typically only set to `true` during dictionary encoding when we know for certain
202-
/// that all dictionary values are referenced by at least one code.
203-
pub fn try_new_with_metadata(
204-
codes: ArrayRef,
205-
values: ArrayRef,
206-
all_values_referenced: bool,
207-
) -> VortexResult<Self> {
208197
if !codes.dtype().is_unsigned_int() {
209198
vortex_bail!(MismatchedTypes: "unsigned int", codes.dtype());
210199
}
211200

212-
Ok(unsafe {
213-
Self::new_unchecked(codes, values).set_all_values_referenced(all_values_referenced)
214-
})
201+
Ok(unsafe { Self::new_unchecked(codes, values) })
215202
}
216203

217204
#[inline]
Lines changed: 155 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,155 @@
1+
// SPDX-License-Identifier: Apache-2.0
2+
// SPDX-FileCopyrightText: Copyright the Vortex contributors
3+
4+
use vortex_error::VortexResult;
5+
use vortex_scalar::NumericOperator;
6+
7+
use super::{DictArray, DictVTable};
8+
use crate::arrays::ConstantArray;
9+
use crate::compute::{NumericKernel, NumericKernelAdapter, numeric};
10+
use crate::{Array, ArrayRef, IntoArray, register_kernel};
11+
12+
impl NumericKernel for DictVTable {
13+
fn numeric(
14+
&self,
15+
lhs: &DictArray,
16+
rhs: &dyn Array,
17+
op: NumericOperator,
18+
) -> VortexResult<Option<ArrayRef>> {
19+
// If we have more values than codes, it is faster to canonicalise first.
20+
if lhs.values().len() > lhs.codes().len() {
21+
return Ok(None);
22+
}
23+
24+
// Only push down if all values are referenced to avoid incorrect results
25+
// See: https://github.com/vortex-data/vortex/pull/4560
26+
// Unchecked operation will be fine to pushdown.
27+
if !lhs.has_all_values_referenced() {
28+
return Ok(None);
29+
}
30+
31+
// If the RHS is constant, then we just need to apply the operation to our encoded values.
32+
if let Some(rhs_scalar) = rhs.as_constant() {
33+
let values_result = numeric(
34+
lhs.values(),
35+
ConstantArray::new(rhs_scalar, lhs.values().len()).as_ref(),
36+
op,
37+
)?;
38+
39+
// SAFETY: values len preserved, codes all still point to valid values
40+
// all_values_referenced preserved since operation doesn't change which values are referenced
41+
let result = unsafe {
42+
DictArray::new_unchecked(lhs.codes().clone(), values_result)
43+
.set_all_values_referenced(lhs.has_all_values_referenced())
44+
.into_array()
45+
};
46+
47+
return Ok(Some(result));
48+
}
49+
50+
// It's a little more complex, but we could perform binary operations against the dictionary
51+
// values in the future.
52+
Ok(None)
53+
}
54+
}
55+
56+
register_kernel!(NumericKernelAdapter(DictVTable).lift());
57+
58+
#[cfg(test)]
59+
mod tests {
60+
use vortex_buffer::buffer;
61+
use vortex_scalar::NumericOperator;
62+
63+
use crate::arrays::dict::DictArray;
64+
use crate::arrays::{ConstantArray, PrimitiveArray};
65+
use crate::compute::numeric;
66+
use crate::{IntoArray, assert_arrays_eq};
67+
68+
#[test]
69+
fn test_add_const() {
70+
// Create a dict with all_values_referenced = true
71+
let dict = unsafe {
72+
DictArray::new_unchecked(
73+
buffer![0u32, 1, 2, 0, 1].into_array(),
74+
buffer![10i32, 20, 30].into_array(),
75+
)
76+
.set_all_values_referenced(true)
77+
};
78+
79+
let res = numeric(
80+
dict.as_ref(),
81+
ConstantArray::new(5i32, 5).as_ref(),
82+
NumericOperator::Add,
83+
)
84+
.unwrap();
85+
86+
let expected = PrimitiveArray::from_iter([15i32, 25, 35, 15, 25]);
87+
assert_arrays_eq!(res.to_canonical().into_array(), expected.to_array());
88+
}
89+
90+
#[test]
91+
fn test_mul_const() {
92+
// Create a dict with all_values_referenced = true
93+
let dict = unsafe {
94+
DictArray::new_unchecked(
95+
buffer![0u32, 1, 2, 1, 0].into_array(),
96+
buffer![2i32, 3, 5].into_array(),
97+
)
98+
.set_all_values_referenced(true)
99+
};
100+
101+
let res = numeric(
102+
dict.as_ref(),
103+
ConstantArray::new(10i32, 5).as_ref(),
104+
NumericOperator::Mul,
105+
)
106+
.unwrap();
107+
108+
let expected = PrimitiveArray::from_iter([20i32, 30, 50, 30, 20]);
109+
assert_arrays_eq!(res.to_canonical().into_array(), expected.to_array());
110+
}
111+
112+
#[test]
113+
fn test_no_pushdown_when_not_all_values_referenced() {
114+
// Create a dict with all_values_referenced = false (default)
115+
let dict = DictArray::try_new(
116+
buffer![0u32, 1, 0, 1].into_array(),
117+
buffer![10i32, 20, 30].into_array(), // value at index 2 is not referenced
118+
)
119+
.unwrap();
120+
121+
// Should return None, indicating no pushdown
122+
let res = numeric(
123+
dict.as_ref(),
124+
ConstantArray::new(5i32, 4).as_ref(),
125+
NumericOperator::Add,
126+
)
127+
.unwrap();
128+
129+
// Verify the result by canonicalizing
130+
let expected = PrimitiveArray::from_iter([15i32, 25, 15, 25]);
131+
assert_arrays_eq!(res.to_canonical().into_array(), expected.to_array());
132+
}
133+
134+
#[test]
135+
fn test_sub_const() {
136+
// Create a dict with all_values_referenced = true
137+
let dict = unsafe {
138+
DictArray::new_unchecked(
139+
buffer![0u32, 1, 2].into_array(),
140+
buffer![100i32, 50, 25].into_array(),
141+
)
142+
.set_all_values_referenced(true)
143+
};
144+
145+
let res = numeric(
146+
dict.as_ref(),
147+
ConstantArray::new(10i32, 3).as_ref(),
148+
NumericOperator::Sub,
149+
)
150+
.unwrap();
151+
152+
let expected = PrimitiveArray::from_iter([90i32, 40, 15]);
153+
assert_arrays_eq!(res.to_canonical().into_array(), expected.to_array());
154+
}
155+
}

vortex-array/src/arrays/dict/compute/mod.rs

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
// SPDX-License-Identifier: Apache-2.0
22
// SPDX-FileCopyrightText: Copyright the Vortex contributors
33

4+
mod binary_numeric;
45
mod cast;
56
mod compare;
67
mod fill_null;
@@ -22,7 +23,12 @@ impl TakeKernel for DictVTable {
2223
fn take(&self, array: &DictArray, indices: &dyn Array) -> VortexResult<ArrayRef> {
2324
let codes = take(array.codes(), indices)?;
2425
// SAFETY: selecting codes doesn't change the invariants of DictArray
25-
Ok(unsafe { DictArray::new_unchecked(codes, array.values().clone()) }.into_array())
26+
// Preserve all_values_referenced since taking codes doesn't affect which values are referenced
27+
Ok(unsafe {
28+
DictArray::new_unchecked(codes, array.values().clone())
29+
.set_all_values_referenced(array.has_all_values_referenced())
30+
.into_array()
31+
})
2632
}
2733
}
2834

@@ -33,7 +39,12 @@ impl FilterKernel for DictVTable {
3339
let codes = filter(array.codes(), mask)?;
3440

3541
// SAFETY: filtering codes doesn't change invariants
36-
unsafe { Ok(DictArray::new_unchecked(codes, array.values().clone()).into_array()) }
42+
// Preserve all_values_referenced since filtering codes doesn't affect which values are referenced
43+
unsafe {
44+
Ok(DictArray::new_unchecked(codes, array.values().clone())
45+
.set_all_values_referenced(array.has_all_values_referenced())
46+
.into_array())
47+
}
3748
}
3849
}
3950

vortex-layout/src/layouts/dict/reader.rs

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -200,9 +200,16 @@ impl LayoutReader for DictReader {
200200
Ok(async move {
201201
let (values, codes) = try_join!(values_eval.map_err(VortexError::from), codes_eval)?;
202202

203-
// Validate that codes are valid for the values
204-
let array =
205-
DictArray::try_new_with_metadata(codes, values, all_values_referenced)?.to_array();
203+
// SAFETY: Layout was validated at write time.
204+
// * The codes dtype is guaranteed to be an unsigned integer type from the layout
205+
// * The codes child reader ensures the correct dtype.
206+
// * The layout stores `all_values_referenced` and if this is malicious then it must
207+
// only affect correctness not memory safety.
208+
let array = unsafe {
209+
DictArray::new_unchecked(codes, values)
210+
.set_all_values_referenced(all_values_referenced)
211+
}
212+
.to_array();
206213
expr.evaluate(&array)
207214
}
208215
.boxed())

0 commit comments

Comments
 (0)