Skip to content

Commit 1cb7a39

Browse files
authored
fix[vortex-datafusion]: check field exists in get_field pushdown (#5295)
Previously, any get field expression was pushed down, which resulted in incorrectly pushing down get fields on non-existent fields, causing an error at execution time rather than plan time. This was a mistake on my part that should've been included in #5024 Signed-off-by: Alfonso Subiotto Marques <[email protected]>
1 parent e262a87 commit 1cb7a39

File tree

1 file changed

+57
-9
lines changed

1 file changed

+57
-9
lines changed

vortex-datafusion/src/convert/exprs.rs

Lines changed: 57 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -229,8 +229,7 @@ pub(crate) fn can_be_pushed_down(df_expr: &PhysicalExprRef, schema: &Schema) ->
229229
can_be_pushed_down(in_list.expr(), schema)
230230
&& in_list.list().iter().all(|e| can_be_pushed_down(e, schema))
231231
} else if let Some(scalar_fn) = expr.downcast_ref::<ScalarFunctionExpr>() {
232-
// Only get_field pushdown is supported.
233-
ScalarFunctionExpr::try_downcast_func::<GetFieldFunc>(scalar_fn).is_some()
232+
can_scalar_fn_be_pushed_down(scalar_fn, schema)
234233
} else {
235234
tracing::debug!(%df_expr, "DataFusion expression can't be pushed down");
236235
false
@@ -277,6 +276,53 @@ fn supported_data_types(dt: &DataType) -> bool {
277276
is_supported
278277
}
279278

279+
/// Checks if a GetField scalar function can be pushed down.
280+
fn can_scalar_fn_be_pushed_down(scalar_fn: &ScalarFunctionExpr, schema: &Schema) -> bool {
281+
let Some(get_field_fn) = ScalarFunctionExpr::try_downcast_func::<GetFieldFunc>(scalar_fn)
282+
else {
283+
// Only get_field pushdown is supported.
284+
return false;
285+
};
286+
287+
let args = get_field_fn.args();
288+
if args.len() != 2 {
289+
tracing::debug!(
290+
"Expected 2 arguments for GetField, not pushing down {} arguments",
291+
args.len()
292+
);
293+
return false;
294+
}
295+
let source_expr = &args[0];
296+
let field_name_expr = &args[1];
297+
let Some(field_name) = field_name_expr
298+
.as_any()
299+
.downcast_ref::<df_expr::Literal>()
300+
.and_then(|lit| lit.value().try_as_str().flatten())
301+
else {
302+
return false;
303+
};
304+
305+
let Ok(source_dt) = source_expr.data_type(schema) else {
306+
tracing::debug!(
307+
field_name = field_name,
308+
schema = ?schema,
309+
source_expr = ?source_expr,
310+
"Failed to get source type for GetField, not pushing down"
311+
);
312+
return false;
313+
};
314+
let DataType::Struct(fields) = source_dt else {
315+
tracing::debug!(
316+
field_name = field_name,
317+
schema = ?schema,
318+
source_expr = ?source_expr,
319+
"Failed to get source type as struct for GetField, not pushing down"
320+
);
321+
return false;
322+
};
323+
fields.find(field_name).is_some()
324+
}
325+
280326
#[cfg(test)]
281327
mod tests {
282328
use std::sync::Arc;
@@ -606,8 +652,10 @@ mod tests {
606652
"#);
607653
}
608654

609-
#[test]
610-
fn test_can_be_pushed_down_get_field() {
655+
#[rstest]
656+
#[case::valid_field("field1", true)]
657+
#[case::missing_field("nonexistent_field", false)]
658+
fn test_can_be_pushed_down_get_field(#[case] field_name: &str, #[case] expected: bool) {
611659
let struct_fields = Fields::from(vec![
612660
Field::new("field1", DataType::Utf8, true),
613661
Field::new("field2", DataType::Int32, true),
@@ -619,18 +667,18 @@ mod tests {
619667
)]);
620668

621669
let struct_col = Arc::new(df_expr::Column::new("my_struct", 0)) as Arc<dyn PhysicalExpr>;
622-
let field_name = Arc::new(df_expr::Literal::new(ScalarValue::Utf8(Some(
623-
"field1".to_string(),
670+
let field_name_lit = Arc::new(df_expr::Literal::new(ScalarValue::Utf8(Some(
671+
field_name.to_string(),
624672
)))) as Arc<dyn PhysicalExpr>;
625673

626674
let get_field_expr = Arc::new(ScalarFunctionExpr::new(
627675
"get_field",
628676
Arc::new(ScalarUDF::from(GetFieldFunc::new())),
629-
vec![struct_col, field_name],
630-
Arc::new(Field::new("field1", DataType::Utf8, true)),
677+
vec![struct_col, field_name_lit],
678+
Arc::new(Field::new(field_name, DataType::Utf8, true)),
631679
Arc::new(ConfigOptions::new()),
632680
)) as Arc<dyn PhysicalExpr>;
633681

634-
assert!(can_be_pushed_down(&get_field_expr, &schema));
682+
assert_eq!(can_be_pushed_down(&get_field_expr, &schema), expected);
635683
}
636684
}

0 commit comments

Comments
 (0)