@@ -228,8 +228,11 @@ pub(crate) fn can_be_pushed_down(df_expr: &PhysicalExprRef, schema: &Schema) ->
228228 } else if let Some ( in_list) = expr. downcast_ref :: < df_expr:: InListExpr > ( ) {
229229 can_be_pushed_down ( in_list. expr ( ) , schema)
230230 && in_list. list ( ) . iter ( ) . all ( |e| can_be_pushed_down ( e, schema) )
231- } else if expr. downcast_ref :: < ScalarFunctionExpr > ( ) . is_some ( ) {
232- get_source_data_type ( df_expr, schema) . is_some ( )
231+ } else if let Some ( scalar_fn) = expr. downcast_ref :: < ScalarFunctionExpr > ( ) {
232+ // Only get_field expressions should be pushed down. Note, we know that
233+ // the GetFieldFunc call should be well-formed, because the DataFusion planner
234+ // checks that for us before we even get to the DataSource.
235+ ScalarFunctionExpr :: try_downcast_func :: < GetFieldFunc > ( scalar_fn) . is_some ( )
233236 } else {
234237 tracing:: debug!( %df_expr, "DataFusion expression can't be pushed down" ) ;
235238 false
@@ -276,64 +279,37 @@ fn supported_data_types(dt: &DataType) -> bool {
276279 is_supported
277280}
278281
279- /// Evaluate the source `expr` within the scope of `schema` and return its data type. If the source
280- /// expression is not composed of valid field accesses that we can pushdown to Vortex, fail.
281- fn get_source_data_type ( expr : & Arc < dyn PhysicalExpr > , schema : & Schema ) -> Option < DataType > {
282- if let Some ( col) = expr. as_any ( ) . downcast_ref :: < df_expr:: Column > ( ) {
283- // Column expression handler
284- let Ok ( field) = schema. field_with_name ( col. name ( ) ) else {
285- return None ;
286- } ;
287-
288- // Get back the data type here instead.
289- Some ( field. data_type ( ) . clone ( ) )
290- } else if let Some ( scalar_fn) = expr. as_any ( ) . downcast_ref :: < ScalarFunctionExpr > ( ) {
291- // Struct field access handler
292- let get_field_fn = ScalarFunctionExpr :: try_downcast_func :: < GetFieldFunc > ( scalar_fn) ?;
293-
294- let args = get_field_fn. args ( ) ;
295- if args. len ( ) != 2 {
296- return None ;
297- }
298-
299- let source = & args[ 0 ] ;
300- let field_name_expr = & args[ 1 ] ;
301-
302- let DataType :: Struct ( fields) = get_source_data_type ( source, schema) ? else {
303- return None ;
304- } ;
305-
306- let field_name = field_name_expr
307- . as_any ( )
308- . downcast_ref :: < df_expr:: Literal > ( )
309- . and_then ( |l| l. value ( ) . try_as_str ( ) )
310- . flatten ( ) ?;
311-
312- // Extract the named field from the struct type
313- fields
314- . find ( field_name)
315- . map ( |( _, dt) | dt. data_type ( ) . clone ( ) )
316- } else {
317- None
318- }
319- }
320-
321282#[ cfg( test) ]
322283mod tests {
284+ use std:: any:: Any ;
323285 use std:: sync:: Arc ;
324286
325- use arrow_schema:: { DataType , Field , Fields , Schema , TimeUnit as ArrowTimeUnit } ;
287+ use arrow_schema:: {
288+ DataType , Field , Schema , SchemaBuilder , SchemaRef , TimeUnit as ArrowTimeUnit ,
289+ } ;
326290 use datafusion:: functions:: core:: getfield:: GetFieldFunc ;
327- use datafusion_common :: ScalarValue ;
291+ use datafusion :: logical_expr :: { ColumnarValue , Signature } ;
328292 use datafusion_common:: config:: ConfigOptions ;
329- use datafusion_expr:: { Operator as DFOperator , ScalarUDF } ;
330- use datafusion_physical_expr:: PhysicalExpr ;
293+ use datafusion_common:: { ScalarValue , ToDFSchema } ;
294+ use datafusion_datasource:: file:: FileSource ;
295+ use datafusion_expr:: execution_props:: ExecutionProps ;
296+ use datafusion_expr:: expr:: ScalarFunction ;
297+ use datafusion_expr:: {
298+ Expr , Operator as DFOperator , ScalarFunctionArgs , ScalarUDF , ScalarUDFImpl , Volatility , col,
299+ } ;
300+ use datafusion_functions:: expr_fn:: get_field;
301+ use datafusion_physical_expr:: { PhysicalExpr , create_physical_expr} ;
331302 use datafusion_physical_plan:: expressions as df_expr;
303+ use datafusion_physical_plan:: filter_pushdown:: PushedDown ;
332304 use insta:: assert_snapshot;
333305 use rstest:: rstest;
306+ use vortex:: VortexSessionDefault ;
334307 use vortex:: expr:: { Expression , Operator } ;
308+ use vortex:: session:: VortexSession ;
335309
336310 use super :: * ;
311+ use crate :: VortexSource ;
312+ use crate :: persistent:: cache:: VortexFileCache ;
337313
338314 #[ rstest:: fixture]
339315 fn test_schema ( ) -> Schema {
@@ -505,7 +481,8 @@ mod tests {
505481 DataType :: List ( Arc :: new( Field :: new( "item" , DataType :: Int32 , true ) ) ) ,
506482 false
507483 ) ]
508- #[ case:: struct_type( DataType :: Struct ( vec![ Field :: new( "field" , DataType :: Int32 , true ) ] . into( ) ) , false ) ]
484+ #[ case:: struct_type( DataType :: Struct ( vec![ Field :: new( "field" , DataType :: Int32 , true ) ] . into( )
485+ ) , false ) ]
509486 // Dictionary types - should be supported if value type is supported
510487 #[ case:: dict_utf8(
511488 DataType :: Dictionary ( Box :: new( DataType :: UInt32 ) , Box :: new( DataType :: Utf8 ) ) ,
@@ -647,33 +624,142 @@ mod tests {
647624 "# ) ;
648625 }
649626
650- #[ rstest]
651- #[ case:: valid_field( "field1" , true ) ]
652- #[ case:: missing_field( "nonexistent_field" , false ) ]
653- fn test_can_be_pushed_down_get_field ( #[ case] field_name : & str , #[ case] expected : bool ) {
654- let struct_fields = Fields :: from ( vec ! [
655- Field :: new( "field1" , DataType :: Utf8 , true ) ,
656- Field :: new( "field2" , DataType :: Int32 , true ) ,
657- ] ) ;
658- let schema = Schema :: new ( vec ! [ Field :: new(
659- "my_struct" ,
660- DataType :: Struct ( struct_fields) ,
661- true ,
662- ) ] ) ;
#[test]
fn test_pushdown_nested_filter() {
    // Two sibling top-level structs, each wrapping a single non-nullable i32:
    //  a: struct
    //  |- one: i32
    //  b: struct
    //  |- two: i32
    let mut builder = SchemaBuilder::new();
    for (parent, child) in [("a", "one"), ("b", "two")] {
        builder.push(Field::new_struct(
            parent,
            vec![Field::new(child, DataType::Int32, false)],
            false,
        ));
    }
    let schema: SchemaRef = Arc::new(builder.finish());

    // Logical predicate `b.two = 10`, lowered to a physical expression against the schema.
    let predicate: Expr = get_field(col("b"), "two").eq(datafusion_expr::lit(10i32));
    let logical_schema = Arc::clone(&schema).to_dfschema().unwrap();
    let physical_predicate =
        create_physical_expr(&predicate, &logical_schema, &ExecutionProps::default()).unwrap();

    // The source must report the struct-field predicate as pushed down.
    let outcome = vortex_source(&schema)
        .try_pushdown_filters(vec![physical_predicate], &ConfigOptions::default())
        .unwrap();
    assert!(matches!(outcome.filters.first(), Some(PushedDown::Yes)));
}
662+
#[test]
fn test_pushdown_deeply_nested_filter() {
    // A struct nested two levels deep around a single non-nullable i32 leaf:
    //  a: struct
    //  |- b: struct
    //     |- c: i32
    let leaf = Field::new("c", DataType::Int32, false);
    let inner = Field::new_struct("b", vec![leaf], false);
    let outer = Field::new_struct("a", vec![inner], false);

    let mut builder = SchemaBuilder::new();
    builder.push(outer);
    let schema: SchemaRef = Arc::new(builder.finish());
    let logical_schema = Arc::clone(&schema).to_dfschema().unwrap();

    let source = vortex_source(&schema);

    // Logical predicate `a.b.c = 10`: a get_field call whose input is itself a get_field call.
    let inner_access = get_field(col("a"), "b");
    let predicate: Expr = get_field(inner_access, "c").eq(datafusion_expr::lit(10i32));
    let physical_predicate =
        create_physical_expr(&predicate, &logical_schema, &ExecutionProps::default()).unwrap();

    // The chained field access must still be reported as pushed down.
    let outcome = source
        .try_pushdown_filters(vec![physical_predicate], &ConfigOptions::default())
        .unwrap();
    assert!(matches!(outcome.filters.first(), Some(PushedDown::Yes)));
}
691+
#[test]
fn test_unknown_scalar_function() {
    /// Stand-in scalar UDF that is not `GetFieldFunc`; it always evaluates to the
    /// scalar `Int32(1)` and declares `Int32` as its return type.
    #[derive(Debug, PartialEq, Eq, Hash)]
    pub struct UnknownImpl {
        signature: Signature,
    }

    impl ScalarUDFImpl for UnknownImpl {
        fn as_any(&self) -> &dyn Any {
            self
        }

        fn name(&self) -> &str {
            "unknown"
        }

        fn signature(&self) -> &Signature {
            &self.signature
        }

        fn return_type(&self, _arg_types: &[DataType]) -> datafusion_common::Result<DataType> {
            Ok(DataType::Int32)
        }

        fn invoke_with_args(
            &self,
            _args: ScalarFunctionArgs,
        ) -> datafusion_common::Result<ColumnarValue> {
            let constant = ScalarValue::Int32(Some(1));
            Ok(ColumnarValue::Scalar(constant))
        }
    }

    // Same nested layout as the get_field tests:
    //  a: struct
    //  |- b: struct
    //     |- c: i32
    let leaf = Field::new("c", DataType::Int32, false);
    let inner = Field::new_struct("b", vec![leaf], false);
    let outer = Field::new_struct("a", vec![inner], false);

    let mut builder = SchemaBuilder::new();
    builder.push(outer);
    let schema: SchemaRef = Arc::new(builder.finish());
    let logical_schema = Arc::clone(&schema).to_dfschema().unwrap();

    let source = vortex_source(&schema);

    // Wrap the stand-in implementation as a zero-argument scalar function call.
    let udf = Arc::new(ScalarUDF::new_from_impl(UnknownImpl {
        signature: Signature::nullary(Volatility::Immutable),
    }));
    let unknown_call = Expr::ScalarFunction(ScalarFunction {
        func: udf,
        args: Vec::new(),
    });

    // Predicate `unknown() = 10`: a scalar function other than get_field.
    let predicate = unknown_call.eq(datafusion_expr::lit(10i32));
    let physical_predicate =
        create_physical_expr(&predicate, &logical_schema, &ExecutionProps::default()).unwrap();

    // The source must decline to push this predicate down.
    let outcome = source
        .try_pushdown_filters(vec![physical_predicate], &ConfigOptions::default())
        .unwrap();
    assert!(matches!(outcome.filters.first(), Some(PushedDown::No)));
}
758+
759+ fn vortex_source ( schema : & SchemaRef ) -> Arc < dyn FileSource > {
760+ let session = VortexSession :: default ( ) ;
761+ let cache = VortexFileCache :: new ( 1024 , 1024 , session. clone ( ) ) ;
676762
677- assert_eq ! ( can_be_pushed_down ( & get_field_expr , & schema) , expected ) ;
763+ Arc :: new ( VortexSource :: new ( session . clone ( ) , cache ) ) . with_schema ( schema. clone ( ) )
678764 }
679765}
0 commit comments