@@ -5,7 +5,8 @@ use std::sync::Arc;
55
66use arrow_schema:: { DataType , Schema } ;
77use datafusion_expr:: Operator as DFOperator ;
8- use datafusion_physical_expr:: { PhysicalExpr , PhysicalExprRef } ;
8+ use datafusion_functions:: core:: getfield:: GetFieldFunc ;
9+ use datafusion_physical_expr:: { PhysicalExpr , PhysicalExprRef , ScalarFunctionExpr } ;
910use datafusion_physical_expr_common:: physical_expr:: is_dynamic_physical_expr;
1011use datafusion_physical_plan:: expressions as df_expr;
1112use itertools:: Itertools ;
@@ -104,10 +105,47 @@ impl TryFromDataFusion<dyn PhysicalExpr> for ExprRef {
104105 return Ok ( if in_list. negated ( ) { not ( expr) } else { expr } ) ;
105106 }
106107
108+ if let Some ( scalar_fn) = df. as_any ( ) . downcast_ref :: < ScalarFunctionExpr > ( ) {
109+ return try_convert_scalar_function ( scalar_fn) ;
110+ }
111+
107112 vortex_bail ! ( "Couldn't convert DataFusion physical {df} expression to a vortex expression" )
108113 }
109114}
110115
116+ /// Attempts to convert a DataFusion ScalarFunctionExpr to a Vortex expression.
117+ fn try_convert_scalar_function ( scalar_fn : & ScalarFunctionExpr ) -> VortexResult < ExprRef > {
118+ if let Some ( get_field_fn) = ScalarFunctionExpr :: try_downcast_func :: < GetFieldFunc > ( scalar_fn) {
119+ let source_expr = get_field_fn
120+ . args ( )
121+ . first ( )
122+ . ok_or_else ( || vortex_err ! ( "get_field missing source expression" ) ) ?
123+ . as_ref ( ) ;
124+ let field_name_expr = get_field_fn
125+ . args ( )
126+ . get ( 1 )
127+ . ok_or_else ( || vortex_err ! ( "get_field missing field name argument" ) ) ?;
128+ let field_name = field_name_expr
129+ . as_any ( )
130+ . downcast_ref :: < df_expr:: Literal > ( )
131+ . ok_or_else ( || vortex_err ! ( "get_field field name must be a literal" ) ) ?
132+ . value ( )
133+ . try_as_str ( )
134+ . flatten ( )
135+ . ok_or_else ( || vortex_err ! ( "get_field field name must be a UTF-8 string" ) ) ?;
136+ return Ok ( get_item (
137+ field_name. to_string ( ) ,
138+ ExprRef :: try_from_df ( source_expr) ?,
139+ ) ) ;
140+ }
141+
142+ tracing:: debug!(
143+ function_name = scalar_fn. name( ) ,
144+ "Unsupported ScalarFunctionExpr"
145+ ) ;
146+ vortex_bail ! ( "Unsupported ScalarFunctionExpr: {}" , scalar_fn. name( ) )
147+ }
148+
111149impl TryFromDataFusion < DFOperator > for Operator {
112150 fn try_from_df ( value : & DFOperator ) -> VortexResult < Self > {
113151 match value {
@@ -188,6 +226,9 @@ pub(crate) fn can_be_pushed_down(df_expr: &PhysicalExprRef, schema: &Schema) ->
188226 } else if let Some ( in_list) = expr. downcast_ref :: < df_expr:: InListExpr > ( ) {
189227 can_be_pushed_down ( in_list. expr ( ) , schema)
190228 && in_list. list ( ) . iter ( ) . all ( |e| can_be_pushed_down ( e, schema) )
229+ } else if let Some ( scalar_fn) = expr. downcast_ref :: < ScalarFunctionExpr > ( ) {
230+ // Only get_field pushdown is supported.
231+ ScalarFunctionExpr :: try_downcast_func :: < GetFieldFunc > ( scalar_fn) . is_some ( )
191232 } else {
192233 tracing:: debug!( %df_expr, "DataFusion expression can't be pushed down" ) ;
193234 false
@@ -203,6 +244,12 @@ fn can_binary_be_pushed_down(binary: &df_expr::BinaryExpr, schema: &Schema) -> b
203244
204245fn supported_data_types ( dt : & DataType ) -> bool {
205246 use DataType :: * ;
247+
248+ // For dictionary types, check if the value type is supported.
249+ if let Dictionary ( _, value_type) = dt {
250+ return supported_data_types ( value_type. as_ref ( ) ) ;
251+ }
252+
206253 let is_supported = dt. is_null ( )
207254 || dt. is_numeric ( )
208255 || matches ! (
@@ -232,9 +279,11 @@ fn supported_data_types(dt: &DataType) -> bool {
232279mod tests {
233280 use std:: sync:: Arc ;
234281
235- use arrow_schema:: { DataType , Field , Schema , TimeUnit as ArrowTimeUnit } ;
282+ use arrow_schema:: { DataType , Field , Fields , Schema , TimeUnit as ArrowTimeUnit } ;
283+ use datafusion:: functions:: core:: getfield:: GetFieldFunc ;
236284 use datafusion_common:: ScalarValue ;
237- use datafusion_expr:: Operator as DFOperator ;
285+ use datafusion_common:: config:: ConfigOptions ;
286+ use datafusion_expr:: { Operator as DFOperator , ScalarUDF } ;
238287 use datafusion_physical_expr:: PhysicalExpr ;
239288 use datafusion_physical_plan:: expressions as df_expr;
240289 use insta:: assert_snapshot;
@@ -415,6 +464,22 @@ mod tests {
415464 false
416465 ) ]
417466 #[ case:: struct_type( DataType :: Struct ( vec![ Field :: new( "field" , DataType :: Int32 , true ) ] . into( ) ) , false ) ]
467+ // Dictionary types - should be supported if value type is supported
468+ #[ case:: dict_utf8(
469+ DataType :: Dictionary ( Box :: new( DataType :: UInt32 ) , Box :: new( DataType :: Utf8 ) ) ,
470+ true
471+ ) ]
472+ #[ case:: dict_int32(
473+ DataType :: Dictionary ( Box :: new( DataType :: UInt32 ) , Box :: new( DataType :: Int32 ) ) ,
474+ true
475+ ) ]
476+ #[ case:: dict_unsupported(
477+ DataType :: Dictionary (
478+ Box :: new( DataType :: UInt32 ) ,
479+ Box :: new( DataType :: List ( Arc :: new( Field :: new( "item" , DataType :: Int32 , true ) ) ) )
480+ ) ,
481+ false
482+ ) ]
418483 fn test_supported_data_types ( #[ case] data_type : DataType , #[ case] expected : bool ) {
419484 assert_eq ! ( supported_data_types( & data_type) , expected) ;
420485 }
@@ -518,4 +583,53 @@ mod tests {
518583
519584 assert ! ( !can_be_pushed_down( & like_expr, & test_schema) ) ;
520585 }
586+
// Builds a `get_field` ScalarFunctionExpr over column `my_struct` (index 0) with the
// Utf8 literal "field1" as the field name, converts it with `ExprRef::try_from_df`,
// and compares the converted expression's `display_tree()` against the inline snapshot.
587+ #[ test]
588+ fn test_expr_from_df_get_field ( ) {
// First argument: the struct-valued source expression.
589+ let struct_col = Arc :: new ( df_expr:: Column :: new ( "my_struct" , 0 ) ) as Arc < dyn PhysicalExpr > ;
// Second argument: the field name, supplied as a Utf8 literal.
590+ let field_name = Arc :: new ( df_expr:: Literal :: new ( ScalarValue :: Utf8 ( Some (
591+ "field1" . to_string ( ) ,
592+ ) ) ) ) as Arc < dyn PhysicalExpr > ;
593+ let get_field_expr = ScalarFunctionExpr :: new (
594+ "get_field" ,
595+ Arc :: new ( ScalarUDF :: from ( GetFieldFunc :: new ( ) ) ) ,
596+ vec ! [ struct_col, field_name] ,
597+ Arc :: new ( Field :: new ( "field1" , DataType :: Utf8 , true ) ) ,
598+ Arc :: new ( ConfigOptions :: new ( ) ) ,
599+ ) ;
// The conversion is expected to succeed, so `unwrap` fails the test on any error.
600+ let result = ExprRef :: try_from_df ( & get_field_expr) . unwrap ( ) ;
// NOTE(review): leading whitespace inside the snapshot may have been collapsed in this
// view; confirm against the checked-in snapshot before editing it by hand.
601+ assert_snapshot ! ( result. display_tree( ) . to_string( ) , @r"
602+ GetItem(field1)
603+ └── GetItem(my_struct)
604+ └── Root
605+ " ) ;
606+ }
607+
/// A `get_field` call on a struct column with a literal field name is reported as
/// pushable by `can_be_pushed_down`.
#[test]
fn test_can_be_pushed_down_get_field() {
    // Schema with one nullable struct column: `my_struct { field1: Utf8, field2: Int32 }`.
    let my_struct_type = DataType::Struct(Fields::from(vec![
        Field::new("field1", DataType::Utf8, true),
        Field::new("field2", DataType::Int32, true),
    ]));
    let schema = Schema::new(vec![Field::new("my_struct", my_struct_type, true)]);

    // Arguments of `get_field`: the source column followed by the field-name literal.
    let args: Vec<Arc<dyn PhysicalExpr>> = vec![
        Arc::new(df_expr::Column::new("my_struct", 0)),
        Arc::new(df_expr::Literal::new(ScalarValue::Utf8(Some(
            "field1".to_string(),
        )))),
    ];
    let udf = Arc::new(ScalarUDF::from(GetFieldFunc::new()));
    let return_field = Arc::new(Field::new("field1", DataType::Utf8, true));

    let get_field_expr: Arc<dyn PhysicalExpr> = Arc::new(ScalarFunctionExpr::new(
        "get_field",
        udf,
        args,
        return_field,
        Arc::new(ConfigOptions::new()),
    ));

    assert!(can_be_pushed_down(&get_field_expr, &schema));
}
521635}
0 commit comments