Skip to content

Commit 8bd2cbb

Browse files
authored
fix[vortex-expr]: delegate AnalysisExpr to CastExpr child (#5179)
checked_pruning_expr would otherwise return None which would result in no pruning using expressions that contained a CastExpr somewhere. It's relatively common for e.g. struct fields to be cast to the right type for comparison with literals by datafusion, and this expression translates to get_item(cast(struct), "field") = "value". This cast within a get_item would return None when field_path was called before this change. Signed-off-by: Alfonso Subiotto Marques <[email protected]>
1 parent 42a65ff commit 8bd2cbb

File tree

2 files changed

+57
-5
lines changed

2 files changed

+57
-5
lines changed

vortex-expr/src/exprs/cast.rs

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,14 @@
33

44
use vortex_array::compute::cast as compute_cast;
55
use vortex_array::{ArrayRef, DeserializeMetadata, ProstMetadata};
6-
use vortex_dtype::DType;
6+
use vortex_dtype::{DType, FieldPath};
77
use vortex_error::{VortexResult, vortex_bail, vortex_err};
88
use vortex_proto::expr as pb;
99

1010
use crate::display::{DisplayAs, DisplayFormat};
11-
use crate::{AnalysisExpr, ExprEncodingRef, ExprId, ExprRef, IntoExpr, Scope, VTable, vtable};
11+
use crate::{
12+
AnalysisExpr, ExprEncodingRef, ExprId, ExprRef, IntoExpr, Scope, StatsCatalog, VTable, vtable,
13+
};
1214

1315
vtable!(Cast);
1416

@@ -118,7 +120,23 @@ impl DisplayAs for CastExpr {
118120
}
119121
}
120122

121-
impl AnalysisExpr for CastExpr {}
123+
impl AnalysisExpr for CastExpr {
124+
fn max(&self, catalog: &mut dyn StatsCatalog) -> Option<ExprRef> {
125+
self.child.max(catalog)
126+
}
127+
128+
fn min(&self, catalog: &mut dyn StatsCatalog) -> Option<ExprRef> {
129+
self.child.min(catalog)
130+
}
131+
132+
fn nan_count(&self, catalog: &mut dyn StatsCatalog) -> Option<ExprRef> {
133+
self.child.nan_count(catalog)
134+
}
135+
136+
fn field_path(&self) -> Option<FieldPath> {
137+
self.child.field_path()
138+
}
139+
}
122140

123141
/// Creates an expression that casts values to a target data type.
124142
///

vortex-expr/src/pruning/pruning_expr.rs

Lines changed: 36 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -104,12 +104,14 @@ mod tests {
104104
use rstest::{fixture, rstest};
105105
use vortex_array::compute::{BetweenOptions, StrictComparison};
106106
use vortex_array::stats::Stat;
107-
use vortex_dtype::{FieldName, FieldPath, FieldPathSet};
107+
use vortex_dtype::{
108+
DType, FieldName, FieldNames, FieldPath, FieldPathSet, Nullability, StructFields,
109+
};
108110

109111
use crate::pruning::pruning_expr::HashMap;
110112
use crate::pruning::{checked_pruning_expr, field_path_stat_field_name};
111113
use crate::{
112-
HashSet, and, between, col, eq, get_item, gt, gt_eq, lit, lt, lt_eq, not_eq, or, root,
114+
HashSet, and, between, cast, col, eq, get_item, gt, gt_eq, lit, lt, lt_eq, not_eq, or, root,
113115
};
114116

115117
// Implement some checked pruning expressions.
@@ -491,4 +493,36 @@ mod tests {
491493
&or(gt(lit(10), col("a_max")), gt(col("a_min"), lit(50)))
492494
);
493495
}
496+
497+
#[rstest]
498+
fn pruning_cast_get_item_eq(available_stats: FieldPathSet) {
499+
// This test verifies that cast properly forwards analysis methods to
500+
// enable pruning.
501+
let struct_dtype = DType::Struct(
502+
StructFields::new(
503+
FieldNames::from([FieldName::from("a"), FieldName::from("b")]),
504+
vec![
505+
DType::Utf8(Nullability::Nullable),
506+
DType::Utf8(Nullability::Nullable),
507+
],
508+
),
509+
Nullability::NonNullable,
510+
);
511+
let expr = eq(get_item("a", cast(root(), struct_dtype)), lit("value"));
512+
let (converted, refs) = checked_pruning_expr(&expr, &available_stats).unwrap();
513+
assert_eq!(
514+
refs.map(),
515+
&HashMap::from_iter([(
516+
FieldPath::from_name("a"),
517+
HashSet::from_iter([Stat::Min, Stat::Max])
518+
)])
519+
);
520+
assert_eq!(
521+
&converted,
522+
&or(
523+
gt(col("a_min"), lit("value")),
524+
gt(lit("value"), col("a_max"))
525+
)
526+
);
527+
}
494528
}

0 commit comments

Comments
 (0)