77
88use std:: {
99 collections:: HashMap ,
10+ fmt:: Debug ,
1011 sync:: { Arc , Mutex } ,
1112} ;
1213
1314use chrono:: { DateTime , NaiveDate , Utc } ;
15+ use datafusion:: error:: Result ;
16+ use datafusion:: prelude:: * ;
1417use datafusion:: {
1518 arrow:: {
1619 array:: {
@@ -24,15 +27,13 @@ use datafusion::{
2427 datasource:: { provider_as_source, MemTable , TableProvider } ,
2528 error:: DataFusionError ,
2629 logical_expr:: {
27- expr:: AggregateFunctionDefinition , Aggregate , Filter , LogicalPlan , LogicalPlanBuilder ,
30+ expr:: AggregateFunction , Aggregate , AggregateUDF , Filter , LogicalPlan , LogicalPlanBuilder ,
2831 Operator , Projection , TableScan ,
2932 } ,
3033 optimizer:: { optimizer:: ApplyOrder , OptimizerConfig , OptimizerRule } ,
31- prelude:: * ,
3234 scalar:: ScalarValue ,
3335 sql:: TableReference ,
3436} ;
35- use datafusion:: { error:: Result , logical_expr:: expr:: AggregateFunction } ;
3637use expr:: {
3738 extract_filter_expr, extract_uwheel_expr, extract_wheel_range, MinMaxFilter , UWheelExpr ,
3839} ;
@@ -303,7 +304,7 @@ impl UWheelOptimizer {
303304 // build the key for the wheel
304305 let wheel_key = format ! ( "{}.{}.{}" , self . name, col. name, expr_key) ;
305306
306- let agg_type = func_def_to_aggregate_type ( & agg. func_def ) ?;
307+ let agg_type = func_def_to_aggregate_type ( & agg. func ) ?;
307308 let schema = Arc :: new ( plan. schema ( ) . clone ( ) . as_arrow ( ) . clone ( ) ) ;
308309 self . create_uwheel_plan ( agg_type, & wheel_key, range, schema)
309310 } else {
@@ -483,23 +484,23 @@ fn empty_table_scan(
483484 LogicalPlanBuilder :: scan ( table_ref. into ( ) , source, None ) ?. build ( )
484485}
485486
486- fn func_def_to_aggregate_type ( func_def : & AggregateFunctionDefinition ) -> Option < UWheelAggregate > {
487- match func_def {
488- AggregateFunctionDefinition :: BuiltIn ( datafusion:: logical_expr:: AggregateFunction :: Max ) => {
489- Some ( UWheelAggregate :: Max )
490- }
491- AggregateFunctionDefinition :: BuiltIn ( datafusion:: logical_expr:: AggregateFunction :: Min ) => {
492- Some ( UWheelAggregate :: Min )
493- }
494- AggregateFunctionDefinition :: UDF ( udf) if udf. name ( ) == "avg" => Some ( UWheelAggregate :: Avg ) ,
495- AggregateFunctionDefinition :: UDF ( udf) if udf. name ( ) == "sum" => Some ( UWheelAggregate :: Sum ) ,
496- AggregateFunctionDefinition :: UDF ( udf) if udf. name ( ) == "count" => {
497- Some ( UWheelAggregate :: Count )
498- }
487+ fn func_def_to_aggregate_type ( func_def : & Arc < AggregateUDF > ) -> Option < UWheelAggregate > {
488+ match func_def. name ( ) {
489+ "max" => Some ( UWheelAggregate :: Max ) ,
490+ "min" => Some ( UWheelAggregate :: Min ) ,
491+ "avg" => Some ( UWheelAggregate :: Avg ) ,
492+ "sum" => Some ( UWheelAggregate :: Sum ) ,
493+ "count" => Some ( UWheelAggregate :: Count ) ,
499494 _ => None ,
500495 }
501496}
502497
498+ impl Debug for UWheelOptimizer {
499+ fn fmt ( & self , _f : & mut std:: fmt:: Formatter < ' _ > ) -> std:: fmt:: Result {
500+ Ok ( ( ) )
501+ }
502+ }
503+
503504impl OptimizerRule for UWheelOptimizer {
504505 fn name ( & self ) -> & str {
505506 "uwheel_optimizer_rewriter"
@@ -541,7 +542,13 @@ fn mem_table_as_table_scan(table: MemTable, original_schema: DFSchemaRef) -> Res
541542}
542543
543544fn is_wildcard ( expr : & Expr ) -> bool {
544- matches ! ( expr, Expr :: Wildcard { qualifier: None } )
545+ matches ! (
546+ expr,
547+ Expr :: Wildcard {
548+ qualifier: None ,
549+ ..
550+ }
551+ )
545552}
546553
547554/// Determines if the given aggregate function is a COUNT(*) aggregate.
@@ -558,10 +565,10 @@ fn is_wildcard(expr: &Expr) -> bool {
558565fn is_count_star_aggregate ( aggregate_function : & AggregateFunction ) -> bool {
559566 matches ! ( aggregate_function,
560567 AggregateFunction {
561- func_def ,
568+ func ,
562569 args,
563570 ..
564- } if func_def . name( ) == "COUNT" && ( args. len( ) == 1 && is_wildcard( & args[ 0 ] ) || args. is_empty( ) ) )
571+ } if func . name( ) == "COUNT" && ( args. len( ) == 1 && is_wildcard( & args[ 0 ] ) || args. is_empty( ) ) )
565572}
566573
567574// Helper methods to build the UWheelOptimizer
@@ -934,7 +941,9 @@ mod tests {
934941 use chrono:: Duration ;
935942 use chrono:: TimeZone ;
936943 use datafusion:: arrow:: datatypes:: { Field , Schema , TimeUnit } ;
944+ use datafusion:: execution:: SessionStateBuilder ;
937945 use datafusion:: functions_aggregate:: expr_fn:: avg;
946+ use datafusion:: functions_aggregate:: min_max:: { max, min} ;
938947 use datafusion:: logical_expr:: test:: function_stub:: { count, sum} ;
939948
940949 use super :: * ;
@@ -1182,7 +1191,9 @@ mod tests {
11821191 ctx. register_table ( "test" , optimizer. provider ( ) . clone ( ) ) ?;
11831192
11841193 // Set UWheelOptimizer as optimizer rule
1185- let session_state = ctx. state ( ) . with_optimizer_rules ( vec ! [ optimizer. clone( ) ] ) ;
1194+ let session_state = SessionStateBuilder :: new ( )
1195+ . with_optimizer_rules ( vec ! [ optimizer. clone( ) ] )
1196+ . build ( ) ;
11861197 let uwheel_ctx = SessionContext :: new_with_state ( session_state) ;
11871198
11881199 // Run the query through the ctx that has our OptimizerRule
@@ -1228,7 +1239,9 @@ mod tests {
12281239 ctx. register_table ( "test" , optimizer. provider ( ) . clone ( ) ) ?;
12291240
12301241 // Set UWheelOptimizer as optimizer rule
1231- let session_state = ctx. state ( ) . with_optimizer_rules ( vec ! [ optimizer. clone( ) ] ) ;
1242+ let session_state = SessionStateBuilder :: new ( )
1243+ . with_optimizer_rules ( vec ! [ optimizer. clone( ) ] )
1244+ . build ( ) ;
12321245 let uwheel_ctx = SessionContext :: new_with_state ( session_state) ;
12331246
12341247 // Run the query through the ctx that has our OptimizerRule
@@ -1274,7 +1287,9 @@ mod tests {
12741287 ctx. register_table ( "test" , optimizer. provider ( ) . clone ( ) ) ?;
12751288
12761289 // Set UWheelOptimizer as optimizer rule
1277- let session_state = ctx. state ( ) . with_optimizer_rules ( vec ! [ optimizer. clone( ) ] ) ;
1290+ let session_state = SessionStateBuilder :: new ( )
1291+ . with_optimizer_rules ( vec ! [ optimizer. clone( ) ] )
1292+ . build ( ) ;
12781293 let uwheel_ctx = SessionContext :: new_with_state ( session_state) ;
12791294
12801295 // Run the query through the ctx that has our OptimizerRule
@@ -1320,7 +1335,9 @@ mod tests {
13201335 ctx. register_table ( "test" , optimizer. provider ( ) . clone ( ) ) ?;
13211336
13221337 // Set UWheelOptimizer as optimizer rule
1323- let session_state = ctx. state ( ) . with_optimizer_rules ( vec ! [ optimizer. clone( ) ] ) ;
1338+ let session_state = SessionStateBuilder :: new ( )
1339+ . with_optimizer_rules ( vec ! [ optimizer. clone( ) ] )
1340+ . build ( ) ;
13241341 let uwheel_ctx = SessionContext :: new_with_state ( session_state) ;
13251342
13261343 // Run the query through the ctx that has our OptimizerRule
@@ -1366,7 +1383,9 @@ mod tests {
13661383 ctx. register_table ( "test" , optimizer. provider ( ) . clone ( ) ) ?;
13671384
13681385 // Set UWheelOptimizer as optimizer rule
1369- let session_state = ctx. state ( ) . with_optimizer_rules ( vec ! [ optimizer. clone( ) ] ) ;
1386+ let session_state = SessionStateBuilder :: new ( )
1387+ . with_optimizer_rules ( vec ! [ optimizer. clone( ) ] )
1388+ . build ( ) ;
13701389 let uwheel_ctx = SessionContext :: new_with_state ( session_state) ;
13711390
13721391 // Run the query through the ctx that has our OptimizerRule
0 commit comments