@@ -265,6 +265,16 @@ impl UWheelOptimizer {
265265 agg. group_expr . is_empty ( ) && agg. aggr_expr . len ( ) == 1
266266 }
267267
268+ /// checks whether the Aggregate has a single group_by expression
269+ fn single_group_by ( agg : & Aggregate ) -> bool {
270+ agg. group_expr . len ( ) == 1
271+ }
272+
273+ /// check whether the Aggregate has no group_expr and aggr_expr has a length greater than 1
274+ fn multiple_aggregates ( agg : & Aggregate ) -> bool {
275+ agg. group_expr . is_empty ( ) && agg. aggr_expr . len ( ) > 1
276+ }
277+
268278 // Attemps to rewrite a top-level Projection plan
269279 fn try_rewrite_projection (
270280 & self ,
@@ -315,7 +325,7 @@ impl UWheelOptimizer {
315325 }
316326 }
317327
318- LogicalPlan :: Aggregate ( agg) => {
328+ LogicalPlan :: Aggregate ( agg) if Self :: single_group_by ( agg ) => {
319329 let group_expr = agg. group_expr . first ( ) ?;
320330
321331 // Only continue if the aggregation has a filter
@@ -372,6 +382,57 @@ impl UWheelOptimizer {
372382 }
373383 None
374384 }
385+
386+ LogicalPlan :: Aggregate ( agg) if Self :: multiple_aggregates ( agg) => {
387+ // Only continue if the aggregation has a filter
388+ let LogicalPlan :: Filter ( filter) = agg. input . as_ref ( ) else {
389+ return None ;
390+ } ;
391+
392+ let agg_exprs = & agg. aggr_expr ;
393+
394+ let mut agg_results = Vec :: new ( ) ;
395+
396+ for agg_expr in agg_exprs {
397+ match agg_expr {
398+ // Single Aggregate Function (e.g., SUM(col))
399+ Expr :: AggregateFunction ( agg) if agg. args . len ( ) == 1 => {
400+ if let Expr :: Column ( col) = & agg. args [ 0 ] {
401+ // Fetch temporal filter range and expr key which is used to identify a wheel
402+ let ( range, expr_key) = match extract_filter_expr (
403+ & filter. predicate ,
404+ & self . time_column ,
405+ ) ? {
406+ ( range, Some ( expr) ) => {
407+ ( range, maybe_replace_table_name ( & expr, & self . name ) )
408+ }
409+ ( range, None ) => ( range, STAR_AGGREGATION_ALIAS . to_string ( ) ) ,
410+ } ;
411+
412+ // build the key for the wheel
413+ let wheel_key = format ! ( "{}.{}.{}" , self . name, col. name, expr_key) ;
414+
415+ let agg_type = func_def_to_aggregate_type ( & agg. func ) ?;
416+
417+ // get aggregation result
418+ let result =
419+ self . get_aggregate_result ( agg_type, & wheel_key, range) ?;
420+
421+ agg_results. push ( result) ;
422+ } else {
423+ return None ;
424+ }
425+ }
426+ _ => {
427+ return None ;
428+ }
429+ }
430+ }
431+
432+ let schema = Arc :: new ( plan. schema ( ) . clone ( ) . as_arrow ( ) . clone ( ) ) ;
433+
434+ uwheel_multiple_aggregations_to_table_scan ( agg_results, schema) . ok ( )
435+ }
375436 // Check whether it follows the pattern: SELECT * FROM X WHERE TIME >= X AND TIME <= Y
376437 LogicalPlan :: Filter ( filter) => self . try_rewrite_filter ( filter, plan) ,
377438 _ => None ,
@@ -453,27 +514,32 @@ impl UWheelOptimizer {
453514 range : WheelRange ,
454515 schema : SchemaRef ,
455516 ) -> Option < LogicalPlan > {
517+ let result = self . get_aggregate_result ( agg_type, wheel_key, range) ?;
518+ uwheel_agg_to_table_scan ( result, schema) . ok ( )
519+ }
520+
521+ fn get_aggregate_result (
522+ & self ,
523+ agg_type : UWheelAggregate ,
524+ wheel_key : & str ,
525+ range : WheelRange ,
526+ ) -> Option < f64 > {
456527 match agg_type {
457528 UWheelAggregate :: Sum => {
458529 let wheel = self . wheels . sum . lock ( ) . unwrap ( ) . get ( wheel_key) ?. clone ( ) ;
459- let result = wheel. combine_range_and_lower ( range) ?;
460- uwheel_agg_to_table_scan ( result, schema) . ok ( )
530+ wheel. combine_range_and_lower ( range)
461531 }
462532 UWheelAggregate :: Avg => {
463533 let wheel = self . wheels . avg . lock ( ) . unwrap ( ) . get ( wheel_key) ?. clone ( ) ;
464- let result = wheel. combine_range_and_lower ( range) ?;
465-
466- uwheel_agg_to_table_scan ( result, schema) . ok ( )
534+ wheel. combine_range_and_lower ( range)
467535 }
468536 UWheelAggregate :: Min => {
469537 let wheel = self . wheels . min . lock ( ) . unwrap ( ) . get ( wheel_key) ?. clone ( ) ;
470- let result = wheel. combine_range_and_lower ( range) ?;
471- uwheel_agg_to_table_scan ( result, schema) . ok ( )
538+ wheel. combine_range_and_lower ( range)
472539 }
473540 UWheelAggregate :: Max => {
474541 let wheel = self . wheels . max . lock ( ) . unwrap ( ) . get ( wheel_key) ?. clone ( ) ;
475- let result = wheel. combine_range_and_lower ( range) ?;
476- uwheel_agg_to_table_scan ( result, schema) . ok ( )
542+ wheel. combine_range_and_lower ( range)
477543 }
478544 _ => unimplemented ! ( ) ,
479545 }
@@ -517,6 +583,24 @@ fn uwheel_group_by_to_table_scan(
517583 mem_table_as_table_scan ( mem_table, df_schema)
518584}
519585
586+ fn uwheel_multiple_aggregations_to_table_scan (
587+ agg_results : Vec < f64 > ,
588+ schema : SchemaRef ,
589+ ) -> Result < LogicalPlan > {
590+ let mut columns = Vec :: new ( ) ;
591+
592+ for result in agg_results {
593+ let data = Float64Array :: from ( vec ! [ result] ) ;
594+ columns. push ( Arc :: new ( data) as Arc < dyn Array > ) ;
595+ }
596+
597+ let record_batch = RecordBatch :: try_new ( schema. clone ( ) , columns) ?;
598+
599+ let df_schema = Arc :: new ( DFSchema :: try_from ( schema. clone ( ) ) ?) ;
600+ let mem_table = MemTable :: try_new ( schema, vec ! [ vec![ record_batch] ] ) ?;
601+ mem_table_as_table_scan ( mem_table, df_schema)
602+ }
603+
520604// helper for possibly removing the table name from the expression key
521605fn maybe_replace_table_name ( expr : & Expr , table_name : & str ) -> String {
522606 let expr_str = expr. to_string ( ) ;
@@ -1575,4 +1659,116 @@ mod tests {
15751659
15761660 Ok ( ( ) )
15771661 }
1662+
1663+ #[ tokio:: test]
1664+ async fn multiple_aggregation_rewrite ( ) -> Result < ( ) > {
1665+ let optimizer = test_optimizer ( ) . await ?;
1666+
1667+ optimizer
1668+ . build_index ( IndexBuilder :: with_col_and_aggregate (
1669+ "agg_col" ,
1670+ UWheelAggregate :: Avg ,
1671+ ) )
1672+ . await ?;
1673+
1674+ optimizer
1675+ . build_index ( IndexBuilder :: with_col_and_aggregate (
1676+ "agg_col" ,
1677+ UWheelAggregate :: Sum ,
1678+ ) )
1679+ . await ?;
1680+
1681+ let temporal_filter = col ( "timestamp" )
1682+ . gt_eq ( lit ( "2024-05-10T00:00:00Z" ) )
1683+ . and ( col ( "timestamp" ) . lt ( lit ( "2024-05-10T00:00:10Z" ) ) ) ;
1684+
1685+ let plan =
1686+ LogicalPlanBuilder :: scan ( "test" , provider_as_source ( optimizer. provider ( ) ) , None ) ?
1687+ . filter ( temporal_filter) ?
1688+ . aggregate (
1689+ Vec :: < Expr > :: new ( ) ,
1690+ vec ! [ avg( col( "agg_col" ) ) , sum( col( "agg_col" ) ) ] ,
1691+ ) ?
1692+ . project ( vec ! [ avg( col( "agg_col" ) ) , sum( col( "agg_col" ) ) ] ) ?
1693+ . build ( ) ?;
1694+
1695+ // Assert that the original plan is a Projection
1696+ assert ! ( matches!( plan, LogicalPlan :: Projection ( _) ) ) ;
1697+
1698+ let rewritten = optimizer. try_rewrite ( & plan) . unwrap ( ) ;
1699+ // assert it was rewritten to a TableScan
1700+ assert ! ( matches!( rewritten, LogicalPlan :: TableScan ( _) ) ) ;
1701+
1702+ Ok ( ( ) )
1703+ }
1704+
1705+ #[ tokio:: test]
1706+ async fn multiple_aggregation_exec ( ) -> Result < ( ) > {
1707+ let optimizer = test_optimizer ( ) . await ?;
1708+
1709+ optimizer
1710+ . build_index ( IndexBuilder :: with_col_and_aggregate (
1711+ "agg_col" ,
1712+ UWheelAggregate :: Avg ,
1713+ ) )
1714+ . await ?;
1715+
1716+ optimizer
1717+ . build_index ( IndexBuilder :: with_col_and_aggregate (
1718+ "agg_col" ,
1719+ UWheelAggregate :: Sum ,
1720+ ) )
1721+ . await ?;
1722+
1723+ let temporal_filter = col ( "timestamp" )
1724+ . gt_eq ( lit ( "2024-05-10T00:00:00Z" ) )
1725+ . and ( col ( "timestamp" ) . lt ( lit ( "2024-05-10T00:00:10Z" ) ) ) ;
1726+
1727+ let plan =
1728+ LogicalPlanBuilder :: scan ( "test" , provider_as_source ( optimizer. provider ( ) ) , None ) ?
1729+ . filter ( temporal_filter) ?
1730+ . aggregate (
1731+ Vec :: < Expr > :: new ( ) ,
1732+ vec ! [ avg( col( "agg_col" ) ) , sum( col( "agg_col" ) ) ] ,
1733+ ) ?
1734+ . project ( vec ! [ avg( col( "agg_col" ) ) , sum( col( "agg_col" ) ) ] ) ?
1735+ . build ( ) ?;
1736+
1737+ let ctx = SessionContext :: new ( ) ;
1738+ ctx. register_table ( "test" , optimizer. provider ( ) . clone ( ) ) ?;
1739+
1740+ // Set UWheelOptimizer as optimizer rule
1741+ let session_state = SessionStateBuilder :: new ( )
1742+ . with_optimizer_rules ( vec ! [ optimizer. clone( ) ] )
1743+ . build ( ) ;
1744+ let uwheel_ctx = SessionContext :: new_with_state ( session_state) ;
1745+
1746+ // Run the query through the ctx that has our OptimizerRule
1747+ let df = uwheel_ctx. execute_logical_plan ( plan) . await ?;
1748+ let results = df. collect ( ) . await ?;
1749+
1750+ assert_eq ! ( results. len( ) , 1 ) ;
1751+
1752+ assert_eq ! (
1753+ results[ 0 ]
1754+ . column( 0 )
1755+ . as_any( )
1756+ . downcast_ref:: <Float64Array >( )
1757+ . unwrap( )
1758+ . value( 0 ) ,
1759+ 5.5
1760+ ) ;
1761+
1762+ assert_eq ! (
1763+ results[ 0 ]
1764+ . column( 1 )
1765+ . as_any( )
1766+ . downcast_ref:: <Float64Array >( )
1767+ . unwrap( )
1768+ . value( 0 ) ,
1769+ 55.0
1770+ ) ;
1771+
1772+ Ok ( ( ) )
1773+ }
15781774}
0 commit comments