@@ -229,8 +229,7 @@ pub(crate) fn can_be_pushed_down(df_expr: &PhysicalExprRef, schema: &Schema) ->
229229 can_be_pushed_down ( in_list. expr ( ) , schema)
230230 && in_list. list ( ) . iter ( ) . all ( |e| can_be_pushed_down ( e, schema) )
231231 } else if let Some ( scalar_fn) = expr. downcast_ref :: < ScalarFunctionExpr > ( ) {
232- // Only get_field pushdown is supported.
233- ScalarFunctionExpr :: try_downcast_func :: < GetFieldFunc > ( scalar_fn) . is_some ( )
232+ can_scalar_fn_be_pushed_down ( scalar_fn, schema)
234233 } else {
235234 tracing:: debug!( %df_expr, "DataFusion expression can't be pushed down" ) ;
236235 false
@@ -277,6 +276,53 @@ fn supported_data_types(dt: &DataType) -> bool {
277276 is_supported
278277}
279278
279+ /// Checks if a GetField scalar function can be pushed down.
280+ fn can_scalar_fn_be_pushed_down ( scalar_fn : & ScalarFunctionExpr , schema : & Schema ) -> bool {
281+ let Some ( get_field_fn) = ScalarFunctionExpr :: try_downcast_func :: < GetFieldFunc > ( scalar_fn)
282+ else {
283+ // Only get_field pushdown is supported.
284+ return false ;
285+ } ;
286+
287+ let args = get_field_fn. args ( ) ;
288+ if args. len ( ) != 2 {
289+ tracing:: debug!(
290+ "Expected 2 arguments for GetField, not pushing down {} arguments" ,
291+ args. len( )
292+ ) ;
293+ return false ;
294+ }
295+ let source_expr = & args[ 0 ] ;
296+ let field_name_expr = & args[ 1 ] ;
297+ let Some ( field_name) = field_name_expr
298+ . as_any ( )
299+ . downcast_ref :: < df_expr:: Literal > ( )
300+ . and_then ( |lit| lit. value ( ) . try_as_str ( ) . flatten ( ) )
301+ else {
302+ return false ;
303+ } ;
304+
305+ let Ok ( source_dt) = source_expr. data_type ( schema) else {
306+ tracing:: debug!(
307+ field_name = field_name,
308+ schema = ?schema,
309+ source_expr = ?source_expr,
310+ "Failed to get source type for GetField, not pushing down"
311+ ) ;
312+ return false ;
313+ } ;
314+ let DataType :: Struct ( fields) = source_dt else {
315+ tracing:: debug!(
316+ field_name = field_name,
317+ schema = ?schema,
318+ source_expr = ?source_expr,
319+ "Failed to get source type as struct for GetField, not pushing down"
320+ ) ;
321+ return false ;
322+ } ;
323+ fields. find ( field_name) . is_some ( )
324+ }
325+
280326#[ cfg( test) ]
281327mod tests {
282328 use std:: sync:: Arc ;
@@ -606,8 +652,10 @@ mod tests {
606652 "# ) ;
607653 }
608654
609- #[ test]
610- fn test_can_be_pushed_down_get_field ( ) {
655+ #[ rstest]
656+ #[ case:: valid_field( "field1" , true ) ]
657+ #[ case:: missing_field( "nonexistent_field" , false ) ]
658+ fn test_can_be_pushed_down_get_field ( #[ case] field_name : & str , #[ case] expected : bool ) {
611659 let struct_fields = Fields :: from ( vec ! [
612660 Field :: new( "field1" , DataType :: Utf8 , true ) ,
613661 Field :: new( "field2" , DataType :: Int32 , true ) ,
@@ -619,18 +667,18 @@ mod tests {
619667 ) ] ) ;
620668
621669 let struct_col = Arc :: new ( df_expr:: Column :: new ( "my_struct" , 0 ) ) as Arc < dyn PhysicalExpr > ;
622- let field_name = Arc :: new ( df_expr:: Literal :: new ( ScalarValue :: Utf8 ( Some (
623- "field1" . to_string ( ) ,
670+ let field_name_lit = Arc :: new ( df_expr:: Literal :: new ( ScalarValue :: Utf8 ( Some (
671+ field_name . to_string ( ) ,
624672 ) ) ) ) as Arc < dyn PhysicalExpr > ;
625673
626674 let get_field_expr = Arc :: new ( ScalarFunctionExpr :: new (
627675 "get_field" ,
628676 Arc :: new ( ScalarUDF :: from ( GetFieldFunc :: new ( ) ) ) ,
629- vec ! [ struct_col, field_name ] ,
630- Arc :: new ( Field :: new ( "field1" , DataType :: Utf8 , true ) ) ,
677+ vec ! [ struct_col, field_name_lit ] ,
678+ Arc :: new ( Field :: new ( field_name , DataType :: Utf8 , true ) ) ,
631679 Arc :: new ( ConfigOptions :: new ( ) ) ,
632680 ) ) as Arc < dyn PhysicalExpr > ;
633681
634- assert ! ( can_be_pushed_down( & get_field_expr, & schema) ) ;
682+ assert_eq ! ( can_be_pushed_down( & get_field_expr, & schema) , expected ) ;
635683 }
636684}
0 commit comments