Skip to content

Commit f97c0cd

Browse files
authored
feat: teach VortexExpr to dtype (#1811)
Not entirely clear what to do about non-nullable extension types. The value should never be observed, but it also seems bad to create a Scalar whose value might violate the assumptions of the extension type.
1 parent e5deba0 commit f97c0cd

File tree

14 files changed

+340
-17
lines changed

14 files changed

+340
-17
lines changed

Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

vortex-array/src/data/mod.rs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -425,6 +425,15 @@ impl<T: AsRef<ArrayData>> ArrayDType for T {
425425
}
426426
}
427427

428+
impl ArrayData {
429+
pub fn into_dtype(self) -> DType {
430+
match self.0 {
431+
InnerArrayData::Owned(d) => d.dtype,
432+
InnerArrayData::Viewed(v) => v.dtype,
433+
}
434+
}
435+
}
436+
428437
impl<T: AsRef<ArrayData>> ArrayLen for T {
429438
fn len(&self) -> usize {
430439
self.as_ref().len()

vortex-expr/Cargo.toml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,10 @@ vortex-error = { workspace = true }
3535
vortex-proto = { workspace = true, optional = true }
3636
vortex-scalar = { workspace = true }
3737

38+
[dev-dependencies]
39+
vortex-expr = { path = ".", features = ["test-harness"] }
40+
41+
3842
[features]
3943
datafusion = [
4044
"dep:datafusion-expr",
@@ -48,3 +52,4 @@ proto = [
4852
"vortex-proto/expr",
4953
]
5054
serde = ["dep:serde", "vortex-dtype/serde", "vortex-scalar/serde"]
55+
test-harness = []

vortex-expr/src/binary.rs

Lines changed: 73 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ impl VortexExpr for BinaryExpr {
4646
self
4747
}
4848

49-
fn evaluate(&self, batch: &ArrayData) -> VortexResult<ArrayData> {
49+
fn unchecked_evaluate(&self, batch: &ArrayData) -> VortexResult<ArrayData> {
5050
let lhs = self.lhs.evaluate(batch)?;
5151
let rhs = self.rhs.evaluate(batch)?;
5252

@@ -257,3 +257,75 @@ pub fn or(lhs: ExprRef, rhs: ExprRef) -> ExprRef {
257257
/// Builds the binary expression that combines `lhs` and `rhs` with [`Operator::And`].
pub fn and(lhs: ExprRef, rhs: ExprRef) -> ExprRef {
258258
BinaryExpr::new_expr(lhs, Operator::And, rhs)
259259
}
260+
261+
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use vortex_dtype::{DType, Nullability};

    use crate::{and, col, eq, gt, gt_eq, lt, lt_eq, not_eq, or, test_harness, VortexExpr};

    /// Checks the dtype reported for boolean connectives and comparisons evaluated against the
    /// shared test-harness struct scope.
    #[test]
    fn dtype() {
        let scope = test_harness::struct_dtype();
        let result_dtype = |expr: Arc<dyn VortexExpr>| expr.return_dtype(&scope).unwrap();
        let non_nullable_bool = DType::Bool(Nullability::NonNullable);
        let nullable_bool = DType::Bool(Nullability::Nullable);

        // Connectives over the two non-nullable boolean columns.
        let bool1: Arc<dyn VortexExpr> = col("bool1");
        let bool2: Arc<dyn VortexExpr> = col("bool2");
        assert_eq!(
            result_dtype(and(bool1.clone(), bool2.clone())),
            non_nullable_bool
        );
        assert_eq!(
            result_dtype(or(bool1.clone(), bool2.clone())),
            non_nullable_bool
        );

        // Comparisons over `col1` / `col2` are expected to report a nullable boolean.
        let col1: Arc<dyn VortexExpr> = col("col1");
        let col2: Arc<dyn VortexExpr> = col("col2");

        assert_eq!(result_dtype(eq(col1.clone(), col2.clone())), nullable_bool);
        assert_eq!(
            result_dtype(not_eq(col1.clone(), col2.clone())),
            nullable_bool
        );
        assert_eq!(result_dtype(gt(col1.clone(), col2.clone())), nullable_bool);
        assert_eq!(
            result_dtype(gt_eq(col1.clone(), col2.clone())),
            nullable_bool
        );
        assert_eq!(result_dtype(lt(col1.clone(), col2.clone())), nullable_bool);
        assert_eq!(
            result_dtype(lt_eq(col1.clone(), col2.clone())),
            nullable_bool
        );

        // A connective over two comparison results.
        let combined = or(
            lt(col1.clone(), col2.clone()),
            not_eq(col1.clone(), col2.clone()),
        );
        assert_eq!(result_dtype(combined), nullable_bool);
    }
}

vortex-expr/src/column.rs

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,8 +58,10 @@ impl VortexExpr for Column {
5858
fn as_any(&self) -> &dyn Any {
5959
self
6060
}
61-
fn evaluate(&self, batch: &ArrayData) -> VortexResult<ArrayData> {
61+
62+
fn unchecked_evaluate(&self, batch: &ArrayData) -> VortexResult<ArrayData> {
6263
batch
64+
.clone()
6365
.as_struct_array()
6466
.ok_or_else(|| {
6567
vortex_err!(
@@ -80,3 +82,23 @@ impl VortexExpr for Column {
8082
self
8183
}
8284
}
85+
86+
#[cfg(test)]
mod tests {
    use vortex_dtype::{DType, Nullability, PType};

    use crate::{col, test_harness};

    /// Column lookups by name and by position report the dtype of the selected struct field.
    #[test]
    fn dtype() {
        let scope = test_harness::struct_dtype();

        // Lookup by field name.
        let by_name = col("a").return_dtype(&scope).unwrap();
        assert_eq!(
            by_name,
            DType::Primitive(PType::I32, Nullability::NonNullable)
        );

        // Lookup by positional index.
        let by_index = col(1).return_dtype(&scope).unwrap();
        assert_eq!(
            by_index,
            DType::Primitive(PType::U16, Nullability::Nullable)
        );
    }
}

vortex-expr/src/get_item.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,8 @@ impl VortexExpr for GetItem {
4747
fn as_any(&self) -> &dyn Any {
4848
self
4949
}
50-
fn evaluate(&self, batch: &ArrayData) -> VortexResult<ArrayData> {
50+
51+
fn unchecked_evaluate(&self, batch: &ArrayData) -> VortexResult<ArrayData> {
5152
let child = self.child.evaluate(batch)?;
5253
child
5354
.as_struct_array()

vortex-expr/src/identity.rs

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ impl VortexExpr for Identity {
2727
self
2828
}
2929

30-
fn evaluate(&self, batch: &ArrayData) -> VortexResult<ArrayData> {
30+
// Identity evaluation: the result is simply a clone of the input batch.
fn unchecked_evaluate(&self, batch: &ArrayData) -> VortexResult<ArrayData> {
3131
Ok(batch.clone())
3232
}
3333

@@ -45,3 +45,15 @@ impl VortexExpr for Identity {
4545
/// Convenience constructor: returns the expression produced by [`Identity::new_expr`].
pub fn ident() -> ExprRef {
4646
Identity::new_expr()
4747
}
48+
49+
#[cfg(test)]
mod tests {
    use vortex_dtype::{DType, Nullability};

    use crate::{ident, test_harness};

    /// The identity expression must report exactly the dtype of the scope it is evaluated in.
    #[test]
    fn dtype() {
        // Struct scope from the shared test harness.
        let struct_scope = test_harness::struct_dtype();
        assert_eq!(
            ident().return_dtype(&struct_scope).unwrap(),
            struct_scope
        );

        // Non-struct scope. This replaces a verbatim duplicate of the assertion above, which
        // added no coverage.
        let utf8_scope = DType::Utf8(Nullability::NonNullable);
        assert_eq!(ident().return_dtype(&utf8_scope).unwrap(), utf8_scope);
    }
}

vortex-expr/src/lib.rs

Lines changed: 50 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,8 @@ pub use project::*;
3535
pub use row_filter::*;
3636
pub use select::*;
3737
use vortex_array::aliases::hash_set::HashSet;
38-
use vortex_array::ArrayData;
39-
use vortex_dtype::Field;
38+
use vortex_array::{ArrayDType as _, ArrayData, Canonical, IntoArrayData as _};
39+
use vortex_dtype::{DType, Field};
4040
use vortex_error::{VortexResult, VortexUnwrap};
4141

4242
use crate::traversal::{Node, ReferenceCollector};
@@ -49,11 +49,30 @@ pub trait VortexExpr: Debug + Send + Sync + DynEq + DynHash + Display {
4949
fn as_any(&self) -> &dyn Any;
5050

5151
/// Compute result of expression on given batch producing a new batch
52-
fn evaluate(&self, batch: &ArrayData) -> VortexResult<ArrayData>;
52+
///
53+
fn evaluate(&self, batch: &ArrayData) -> VortexResult<ArrayData> {
54+
let result = self.unchecked_evaluate(batch)?;
55+
debug_assert_eq!(result.dtype(), &self.return_dtype(batch.dtype())?);
56+
Ok(result)
57+
}
58+
59+
/// Compute result of expression on given batch producing a new batch
60+
///
61+
/// "Unchecked" means that this method does not include the debug assertion that the returned
62+
/// array's dtype matches [VortexExpr::return_dtype]. Prefer [VortexExpr::evaluate], which
63+
/// includes that assertion.
64+
fn unchecked_evaluate(&self, batch: &ArrayData) -> VortexResult<ArrayData>;
5365

5466
fn children(&self) -> Vec<&ExprRef>;
5567

5668
fn replacing_children(self: Arc<Self>, children: Vec<ExprRef>) -> ExprRef;
69+
70+
/// Compute the type of the array returned by [VortexExpr::evaluate].
71+
fn return_dtype(&self, scope_dtype: &DType) -> VortexResult<DType> {
    // Default strategy: evaluate against an empty canonical array of the scope dtype and read
    // the dtype off the result. This goes through `unchecked_evaluate`, not `evaluate`, whose
    // debug assertion itself calls `return_dtype`.
    let empty_scope = Canonical::empty(scope_dtype)?.into_array();
    let evaluated = self.unchecked_evaluate(&empty_scope)?;
    Ok(evaluated.into_dtype())
}
5776
}
5877

5978
pub trait VortexExprExt {
@@ -112,6 +131,34 @@ impl Eq for dyn VortexExpr {}
112131

113132
dyn_hash::hash_trait_object!(VortexExpr);
114133

134+
#[cfg(feature = "test-harness")]
pub mod test_harness {
    use vortex_dtype::{DType, Nullability, PType, StructDType};

    /// Builds the non-nullable struct dtype used as the evaluation scope in expression tests.
    ///
    /// Fields, in order:
    /// - `a`: non-nullable `i32`
    /// - `col1`, `col2`: nullable `u16`
    /// - `bool1`, `bool2`: non-nullable `bool`
    pub fn struct_dtype() -> DType {
        let nullable_u16 = DType::Primitive(PType::U16, Nullability::Nullable);
        let non_nullable_bool = DType::Bool(Nullability::NonNullable);
        let field_dtypes = vec![
            DType::Primitive(PType::I32, Nullability::NonNullable),
            nullable_u16.clone(),
            nullable_u16,
            non_nullable_bool.clone(),
            non_nullable_bool,
        ];

        let fields = StructDType::new(
            [
                "a".into(),
                "col1".into(),
                "col2".into(),
                "bool1".into(),
                "bool2".into(),
            ]
            .into(),
            field_dtypes,
        );
        DType::Struct(fields, Nullability::NonNullable)
    }
}
161+
115162
#[cfg(test)]
116163
mod tests {
117164
use vortex_dtype::{DType, Field, Nullability, PType, StructDType};

vortex-expr/src/like.rs

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ impl VortexExpr for Like {
6161
self
6262
}
6363

64-
fn evaluate(&self, batch: &ArrayData) -> VortexResult<ArrayData> {
64+
fn unchecked_evaluate(&self, batch: &ArrayData) -> VortexResult<ArrayData> {
6565
let child = self.child().evaluate(batch)?;
6666
let pattern = self.pattern().evaluate(batch)?;
6767
like(
@@ -102,8 +102,9 @@ impl PartialEq for Like {
102102
mod tests {
103103
use vortex_array::array::BoolArray;
104104
use vortex_array::IntoArrayVariant;
105+
use vortex_dtype::{DType, Nullability};
105106

106-
use crate::{ident, not};
107+
use crate::{ident, lit, not, Like};
107108

108109
#[test]
109110
fn invert_booleans() {
@@ -121,4 +122,14 @@ mod tests {
121122
vec![true, false, true, true, false, false]
122123
);
123124
}
125+
126+
#[test]
fn dtype() {
    // LIKE with a literal pattern, evaluated in a non-nullable utf8 scope, is expected to
    // report a non-nullable boolean.
    let scope = DType::Utf8(Nullability::NonNullable);
    let expr = Like::new_expr(ident(), lit("%test%"), false, false);
    let actual = expr.return_dtype(&scope).unwrap();
    assert_eq!(actual, DType::Bool(Nullability::NonNullable));
}
124135
}

vortex-expr/src/literal.rs

Lines changed: 63 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ impl VortexExpr for Literal {
3737
self
3838
}
3939

40-
fn evaluate(&self, batch: &ArrayData) -> VortexResult<ArrayData> {
40+
// Produces a constant array holding `self.value`, with the same length as `batch`.
fn unchecked_evaluate(&self, batch: &ArrayData) -> VortexResult<ArrayData> {
4141
Ok(ConstantArray::new(self.value.clone(), batch.len()).into_array())
4242
}
4343

@@ -72,3 +72,65 @@ impl VortexExpr for Literal {
7272
/// Convenience constructor: converts `value` into a [`Scalar`] and wraps it via
/// [`Literal::new_expr`].
pub fn lit(value: impl Into<Scalar>) -> ExprRef {
7373
Literal::new_expr(value.into())
7474
}
75+
76+
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use vortex_dtype::{DType, Nullability, PType, StructDType};
    use vortex_scalar::Scalar;

    use crate::{lit, test_harness};

    /// A literal is expected to report the dtype of its scalar, whatever the scope dtype is.
    #[test]
    fn dtype() {
        let scope = test_harness::struct_dtype();

        // (literal expression, expected dtype) pairs for primitive, boolean and null scalars.
        let cases = vec![
            (
                lit(10),
                DType::Primitive(PType::I32, Nullability::NonNullable),
            ),
            (
                lit(0_u8),
                DType::Primitive(PType::U8, Nullability::NonNullable),
            ),
            (
                lit(0.0_f32),
                DType::Primitive(PType::F32, Nullability::NonNullable),
            ),
            (
                lit(i64::MAX),
                DType::Primitive(PType::I64, Nullability::NonNullable),
            ),
            (lit(true), DType::Bool(Nullability::NonNullable)),
            (
                lit(Scalar::null(DType::Bool(Nullability::Nullable))),
                DType::Bool(Nullability::Nullable),
            ),
        ];
        for (literal, expected) in cases {
            assert_eq!(literal.return_dtype(&scope).unwrap(), expected);
        }

        // Struct-valued literal.
        let pet_dtype = DType::Struct(
            StructDType::new(
                Arc::from([Arc::from("dog"), Arc::from("cat")]),
                vec![
                    DType::Primitive(PType::U32, Nullability::NonNullable),
                    DType::Utf8(Nullability::NonNullable),
                ],
            ),
            Nullability::NonNullable,
        );
        let pet = Scalar::struct_(
            pet_dtype.clone(),
            vec![Scalar::from(32_u32), Scalar::from("rufus".to_string())],
        );
        assert_eq!(lit(pet).return_dtype(&scope).unwrap(), pet_dtype);
    }
}

0 commit comments

Comments
 (0)