Skip to content

Commit 5b018b7

Browse files
authored
performance[vortex-datafusion]: support pushing down filters on struct columns (#5024)
The motivation for this PR is to allow pushing down filters of the form `my_struct.field = 'foo'`. It also includes fixes for two things that broke once these filters were supported: (1) rewriting the filter against a physical struct schema that is only a subset of the full table schema, and (2) handling right-hand-side literals of dictionary type. Please see the individual commits for details. --------- Signed-off-by: Alfonso Subiotto Marques <[email protected]>
1 parent 3e94087 commit 5b018b7

File tree

6 files changed

+295
-6
lines changed

6 files changed

+295
-6
lines changed

Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,7 @@ datafusion-common-runtime = { version = "50" }
102102
datafusion-datasource = { version = "50", default-features = false }
103103
datafusion-execution = { version = "50" }
104104
datafusion-expr = { version = "50" }
105+
datafusion-functions = { version = "50" }
105106
datafusion-physical-expr = { version = "50" }
106107
datafusion-physical-expr-adapter = { version = "50" }
107108
datafusion-physical-expr-common = { version = "50" }

vortex-datafusion/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ datafusion-common-runtime = { workspace = true }
2323
datafusion-datasource = { workspace = true, default-features = false }
2424
datafusion-execution = { workspace = true }
2525
datafusion-expr = { workspace = true }
26+
datafusion-functions = { workspace = true }
2627
datafusion-physical-expr = { workspace = true }
2728
datafusion-physical-expr-adapter = { workspace = true }
2829
datafusion-physical-expr-common = { workspace = true }

vortex-datafusion/src/convert/exprs.rs

Lines changed: 117 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,8 @@ use std::sync::Arc;
55

66
use arrow_schema::{DataType, Schema};
77
use datafusion_expr::Operator as DFOperator;
8-
use datafusion_physical_expr::{PhysicalExpr, PhysicalExprRef};
8+
use datafusion_functions::core::getfield::GetFieldFunc;
9+
use datafusion_physical_expr::{PhysicalExpr, PhysicalExprRef, ScalarFunctionExpr};
910
use datafusion_physical_expr_common::physical_expr::is_dynamic_physical_expr;
1011
use datafusion_physical_plan::expressions as df_expr;
1112
use itertools::Itertools;
@@ -104,10 +105,47 @@ impl TryFromDataFusion<dyn PhysicalExpr> for ExprRef {
104105
return Ok(if in_list.negated() { not(expr) } else { expr });
105106
}
106107

108+
if let Some(scalar_fn) = df.as_any().downcast_ref::<ScalarFunctionExpr>() {
109+
return try_convert_scalar_function(scalar_fn);
110+
}
111+
107112
vortex_bail!("Couldn't convert DataFusion physical {df} expression to a vortex expression")
108113
}
109114
}
110115

116+
/// Attempts to convert a DataFusion ScalarFunctionExpr to a Vortex expression.
117+
fn try_convert_scalar_function(scalar_fn: &ScalarFunctionExpr) -> VortexResult<ExprRef> {
118+
if let Some(get_field_fn) = ScalarFunctionExpr::try_downcast_func::<GetFieldFunc>(scalar_fn) {
119+
let source_expr = get_field_fn
120+
.args()
121+
.first()
122+
.ok_or_else(|| vortex_err!("get_field missing source expression"))?
123+
.as_ref();
124+
let field_name_expr = get_field_fn
125+
.args()
126+
.get(1)
127+
.ok_or_else(|| vortex_err!("get_field missing field name argument"))?;
128+
let field_name = field_name_expr
129+
.as_any()
130+
.downcast_ref::<df_expr::Literal>()
131+
.ok_or_else(|| vortex_err!("get_field field name must be a literal"))?
132+
.value()
133+
.try_as_str()
134+
.flatten()
135+
.ok_or_else(|| vortex_err!("get_field field name must be a UTF-8 string"))?;
136+
return Ok(get_item(
137+
field_name.to_string(),
138+
ExprRef::try_from_df(source_expr)?,
139+
));
140+
}
141+
142+
tracing::debug!(
143+
function_name = scalar_fn.name(),
144+
"Unsupported ScalarFunctionExpr"
145+
);
146+
vortex_bail!("Unsupported ScalarFunctionExpr: {}", scalar_fn.name())
147+
}
148+
111149
impl TryFromDataFusion<DFOperator> for Operator {
112150
fn try_from_df(value: &DFOperator) -> VortexResult<Self> {
113151
match value {
@@ -188,6 +226,9 @@ pub(crate) fn can_be_pushed_down(df_expr: &PhysicalExprRef, schema: &Schema) ->
188226
} else if let Some(in_list) = expr.downcast_ref::<df_expr::InListExpr>() {
189227
can_be_pushed_down(in_list.expr(), schema)
190228
&& in_list.list().iter().all(|e| can_be_pushed_down(e, schema))
229+
} else if let Some(scalar_fn) = expr.downcast_ref::<ScalarFunctionExpr>() {
230+
// Only get_field pushdown is supported.
231+
ScalarFunctionExpr::try_downcast_func::<GetFieldFunc>(scalar_fn).is_some()
191232
} else {
192233
tracing::debug!(%df_expr, "DataFusion expression can't be pushed down");
193234
false
@@ -203,6 +244,12 @@ fn can_binary_be_pushed_down(binary: &df_expr::BinaryExpr, schema: &Schema) -> b
203244

204245
fn supported_data_types(dt: &DataType) -> bool {
205246
use DataType::*;
247+
248+
// For dictionary types, check if the value type is supported.
249+
if let Dictionary(_, value_type) = dt {
250+
return supported_data_types(value_type.as_ref());
251+
}
252+
206253
let is_supported = dt.is_null()
207254
|| dt.is_numeric()
208255
|| matches!(
@@ -232,9 +279,11 @@ fn supported_data_types(dt: &DataType) -> bool {
232279
mod tests {
233280
use std::sync::Arc;
234281

235-
use arrow_schema::{DataType, Field, Schema, TimeUnit as ArrowTimeUnit};
282+
use arrow_schema::{DataType, Field, Fields, Schema, TimeUnit as ArrowTimeUnit};
283+
use datafusion::functions::core::getfield::GetFieldFunc;
236284
use datafusion_common::ScalarValue;
237-
use datafusion_expr::Operator as DFOperator;
285+
use datafusion_common::config::ConfigOptions;
286+
use datafusion_expr::{Operator as DFOperator, ScalarUDF};
238287
use datafusion_physical_expr::PhysicalExpr;
239288
use datafusion_physical_plan::expressions as df_expr;
240289
use insta::assert_snapshot;
@@ -415,6 +464,22 @@ mod tests {
415464
false
416465
)]
417466
#[case::struct_type(DataType::Struct(vec![Field::new("field", DataType::Int32, true)].into()), false)]
467+
// Dictionary types - should be supported if value type is supported
468+
#[case::dict_utf8(
469+
DataType::Dictionary(Box::new(DataType::UInt32), Box::new(DataType::Utf8)),
470+
true
471+
)]
472+
#[case::dict_int32(
473+
DataType::Dictionary(Box::new(DataType::UInt32), Box::new(DataType::Int32)),
474+
true
475+
)]
476+
#[case::dict_unsupported(
477+
DataType::Dictionary(
478+
Box::new(DataType::UInt32),
479+
Box::new(DataType::List(Arc::new(Field::new("item", DataType::Int32, true))))
480+
),
481+
false
482+
)]
418483
fn test_supported_data_types(#[case] data_type: DataType, #[case] expected: bool) {
419484
assert_eq!(supported_data_types(&data_type), expected);
420485
}
@@ -518,4 +583,53 @@ mod tests {
518583

519584
assert!(!can_be_pushed_down(&like_expr, &test_schema));
520585
}
586+
587+
// Builds the physical expression `get_field(my_struct, 'field1')` by hand and checks that
// converting it with `ExprRef::try_from_df` produces a Vortex expression whose display tree
// is `GetItem(field1)` over `GetItem(my_struct)` over the root.
#[test]
588+
fn test_expr_from_df_get_field() {
589+
let struct_col = Arc::new(df_expr::Column::new("my_struct", 0)) as Arc<dyn PhysicalExpr>;
590+
let field_name = Arc::new(df_expr::Literal::new(ScalarValue::Utf8(Some(
591+
"field1".to_string(),
592+
)))) as Arc<dyn PhysicalExpr>;
593+
let get_field_expr = ScalarFunctionExpr::new(
594+
"get_field",
595+
Arc::new(ScalarUDF::from(GetFieldFunc::new())),
596+
vec![struct_col, field_name],
597+
Arc::new(Field::new("field1", DataType::Utf8, true)),
598+
Arc::new(ConfigOptions::new()),
599+
);
600+
// The conversion is expected to succeed, so a failure panics the test via `unwrap`.
let result = ExprRef::try_from_df(&get_field_expr).unwrap();
601+
assert_snapshot!(result.display_tree().to_string(), @r"
602+
GetItem(field1)
603+
└── GetItem(my_struct)
604+
└── Root
605+
");
606+
}
607+
608+
#[test]
609+
fn test_can_be_pushed_down_get_field() {
610+
let struct_fields = Fields::from(vec![
611+
Field::new("field1", DataType::Utf8, true),
612+
Field::new("field2", DataType::Int32, true),
613+
]);
614+
let schema = Schema::new(vec![Field::new(
615+
"my_struct",
616+
DataType::Struct(struct_fields),
617+
true,
618+
)]);
619+
620+
let struct_col = Arc::new(df_expr::Column::new("my_struct", 0)) as Arc<dyn PhysicalExpr>;
621+
let field_name = Arc::new(df_expr::Literal::new(ScalarValue::Utf8(Some(
622+
"field1".to_string(),
623+
)))) as Arc<dyn PhysicalExpr>;
624+
625+
let get_field_expr = Arc::new(ScalarFunctionExpr::new(
626+
"get_field",
627+
Arc::new(ScalarUDF::from(GetFieldFunc::new())),
628+
vec![struct_col, field_name],
629+
Arc::new(Field::new("field1", DataType::Utf8, true)),
630+
Arc::new(ConfigOptions::new()),
631+
)) as Arc<dyn PhysicalExpr>;
632+
633+
assert!(can_be_pushed_down(&get_field_expr, &schema));
634+
}
521635
}

vortex-datafusion/src/convert/scalars.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -236,6 +236,7 @@ impl FromDataFusion<ScalarValue> for Scalar {
236236
Scalar::null(DType::Decimal(decimal_dtype, nullable))
237237
}
238238
}
239+
ScalarValue::Dictionary(_, v) => Scalar::from_df(v.as_ref()),
239240
_ => unimplemented!("Can't convert {value:?} value to a Vortex scalar"),
240241
}
241242
}

0 commit comments

Comments
 (0)