@@ -47,7 +47,7 @@ use uwheel::{
4747 sum:: { F64SumAggregator , U32SumAggregator } ,
4848 } ,
4949 wheels:: read:: ReaderWheel ,
50- Aggregator , Conf , Entry , HawConf , RwWheel , WheelRange ,
50+ Aggregator , Conf , Duration , Entry , HawConf , RwWheel , WheelRange ,
5151} ;
5252
5353/// Custom aggregator implementations that are used by this crate.
@@ -314,6 +314,64 @@ impl UWheelOptimizer {
314314 _ => None ,
315315 }
316316 }
317+
318+ LogicalPlan :: Aggregate ( agg) => {
319+ let group_expr = agg. group_expr . first ( ) ?;
320+
321+ // Only continue if the aggregation has a filter
322+ let LogicalPlan :: Filter ( filter) = agg. input . as_ref ( ) else {
323+ return None ;
324+ } ;
325+
326+ let ( wheel_range, _) = extract_filter_expr ( & filter. predicate , & self . time_column ) ?;
327+
328+ match group_expr {
329+ Expr :: ScalarFunction ( func) if func. name ( ) == "date_trunc" => {
330+ let interval = func. args . first ( ) ?;
331+ if let Expr :: Literal ( ScalarValue :: Utf8 ( duration) ) = interval {
332+ match duration. as_ref ( ) ?. as_str ( ) {
333+ "second" => {
334+ unimplemented ! ( "date_trunc('second') group by is not supported" )
335+ }
336+ "minute" => {
337+ unimplemented ! ( "date_trunc('minute') group by is not supported" )
338+ }
339+ "hour" => {
340+ unimplemented ! ( "date_trunc('hour') group by is not supported" )
341+ }
342+ "day" => {
343+ let res = self
344+ . wheels
345+ . count
346+ . group_by ( wheel_range, Duration :: DAY )
347+ . unwrap_or_default ( )
348+ . iter ( )
349+ . map ( |( k, v) | ( ( * k * 1_000 ) as i64 , * v as i64 ) ) // transform milliseconds to microseconds by multiplying by 1_000
350+ . collect ( ) ;
351+
352+ let schema = Arc :: new ( plan. schema ( ) . clone ( ) . as_arrow ( ) . clone ( ) ) ;
353+
354+ return uwheel_group_by_to_table_scan ( res, schema) . ok ( ) ;
355+ }
356+ "week" => {
357+ unimplemented ! ( "date_trunc('week') group by is not supported" )
358+ }
359+ "month" => {
360+ unimplemented ! ( "date_trunc('month') group by is not supported" )
361+ }
362+ "year" => {
363+ unimplemented ! ( "date_trunc('year') group by is not supported" )
364+ }
365+ _ => { }
366+ }
367+ }
368+ }
369+ _ => {
370+ unimplemented ! ( "We only support scalar function date_trunc for group by expression now" )
371+ }
372+ }
373+ None
374+ }
317375 // Check whether it follows the pattern: SELECT * FROM X WHERE TIME >= X AND TIME <= Y
318376 LogicalPlan :: Filter ( filter) => self . try_rewrite_filter ( filter, plan) ,
319377 _ => None ,
@@ -440,6 +498,25 @@ fn uwheel_agg_to_table_scan(result: f64, schema: SchemaRef) -> Result<LogicalPla
440498 mem_table_as_table_scan ( mem_table, df_schema)
441499}
442500
501+ // Converts a uwheel group by result to a TableScan with a MemTable as source
502+ // currently only supports timestamp group by
503+ fn uwheel_group_by_to_table_scan (
504+ result : Vec < ( i64 , i64 ) > ,
505+ schema : SchemaRef ,
506+ ) -> Result < LogicalPlan > {
507+ let group_by =
508+ TimestampMicrosecondArray :: from ( result. iter ( ) . map ( |( k, _) | * k) . collect :: < Vec < _ > > ( ) ) ;
509+
510+ let agg = Int64Array :: from ( result. iter ( ) . map ( |( _, v) | * v) . collect :: < Vec < _ > > ( ) ) ;
511+
512+ let record_batch =
513+ RecordBatch :: try_new ( schema. clone ( ) , vec ! [ Arc :: new( group_by) , Arc :: new( agg) ] ) ?;
514+
515+ let df_schema = Arc :: new ( DFSchema :: try_from ( schema. clone ( ) ) ?) ;
516+ let mem_table = MemTable :: try_new ( schema, vec ! [ vec![ record_batch] ] ) ?;
517+ mem_table_as_table_scan ( mem_table, df_schema)
518+ }
519+
443520// helper for possibly removing the table name from the expression key
444521fn maybe_replace_table_name ( expr : & Expr , table_name : & str ) -> String {
445522 let expr_str = expr. to_string ( ) ;
@@ -568,7 +645,7 @@ fn is_count_star_aggregate(aggregate_function: &AggregateFunction) -> bool {
568645 func,
569646 args,
570647 ..
571- } if func. name( ) == "COUNT" && ( args. len( ) == 1 && is_wildcard( & args[ 0 ] ) || args. is_empty( ) ) )
648+ } if ( func. name( ) == "COUNT" || func . name ( ) == "count" ) && ( args. len( ) == 1 && is_wildcard( & args[ 0 ] ) || args. is_empty( ) ) )
572649}
573650
574651// Helper methods to build the UWheelOptimizer
@@ -942,9 +1019,11 @@ mod tests {
9421019 use chrono:: TimeZone ;
9431020 use datafusion:: arrow:: datatypes:: { Field , Schema , TimeUnit } ;
9441021 use datafusion:: execution:: SessionStateBuilder ;
1022+ use datafusion:: functions_aggregate:: count:: count;
9451023 use datafusion:: functions_aggregate:: expr_fn:: avg;
9461024 use datafusion:: functions_aggregate:: min_max:: { max, min} ;
947- use datafusion:: logical_expr:: test:: function_stub:: { count, sum} ;
1025+ use datafusion:: functions_aggregate:: sum:: sum;
1026+ use datafusion:: prelude:: date_trunc;
9481027
9491028 use super :: * ;
9501029 use builder:: Builder ;
@@ -1405,4 +1484,95 @@ mod tests {
14051484
14061485 Ok ( ( ) )
14071486 }
1487+
1488+ #[ tokio:: test]
1489+ async fn group_by_count_aggregation_rewrite ( ) -> Result < ( ) > {
1490+ let optimizer = test_optimizer ( ) . await ?;
1491+
1492+ let temporal_filter = col ( "timestamp" )
1493+ . gt_eq ( lit ( "2024-05-10T00:00:00Z" ) )
1494+ . and ( col ( "timestamp" ) . lt ( lit ( "2024-05-10T00:00:10Z" ) ) ) ;
1495+
1496+ let plan =
1497+ LogicalPlanBuilder :: scan ( "test" , provider_as_source ( optimizer. provider ( ) ) , None ) ?
1498+ . filter ( temporal_filter) ?
1499+ . aggregate (
1500+ vec ! [ date_trunc( lit( "day" ) , col( "timestamp" ) ) ] , // GROUP BY date_trunc('day', timestamp)
1501+ vec ! [ count( wildcard( ) ) ] ,
1502+ ) ?
1503+ . project ( vec ! [
1504+ date_trunc( lit( "day" ) , col( "timestamp" ) ) ,
1505+ count( wildcard( ) ) ,
1506+ ] ) ?
1507+ . build ( ) ?;
1508+
1509+ // Assert that the original plan is a Projection
1510+ assert ! ( matches!( plan, LogicalPlan :: Projection ( _) ) ) ;
1511+
1512+ let rewritten = optimizer. try_rewrite ( & plan) . unwrap ( ) ;
1513+ // assert it was rewritten to a TableScan
1514+ assert ! ( matches!( rewritten, LogicalPlan :: TableScan ( _) ) ) ;
1515+
1516+ Ok ( ( ) )
1517+ }
1518+
1519+ #[ tokio:: test]
1520+ async fn group_by_count_aggregation_exec ( ) -> Result < ( ) > {
1521+ let optimizer = test_optimizer ( ) . await ?;
1522+
1523+ let temporal_filter = col ( "timestamp" )
1524+ . gt_eq ( lit ( "2024-05-10T00:00:00Z" ) )
1525+ . and ( col ( "timestamp" ) . lt ( lit ( "2024-05-10T00:00:10Z" ) ) ) ;
1526+
1527+ let plan =
1528+ LogicalPlanBuilder :: scan ( "test" , provider_as_source ( optimizer. provider ( ) ) , None ) ?
1529+ . filter ( temporal_filter) ?
1530+ . aggregate (
1531+ vec ! [ date_trunc( lit( "day" ) , col( "timestamp" ) ) ] , // GROUP BY date_trunc('day', timestamp)
1532+ vec ! [ count( wildcard( ) ) ] ,
1533+ ) ?
1534+ . project ( vec ! [
1535+ date_trunc( lit( "day" ) , col( "timestamp" ) ) ,
1536+ count( wildcard( ) ) ,
1537+ ] ) ?
1538+ . build ( ) ?;
1539+
1540+ let ctx = SessionContext :: new ( ) ;
1541+ ctx. register_table ( "test" , optimizer. provider ( ) . clone ( ) ) ?;
1542+
1543+ // Set UWheelOptimizer as optimizer rule
1544+ let session_state = SessionStateBuilder :: new ( )
1545+ . with_optimizer_rules ( vec ! [ optimizer. clone( ) ] )
1546+ . build ( ) ;
1547+ let uwheel_ctx = SessionContext :: new_with_state ( session_state) ;
1548+
1549+ // Run the query through the ctx that has our OptimizerRule
1550+ let df = uwheel_ctx. execute_logical_plan ( plan) . await ?;
1551+ let results = df. collect ( ) . await ?;
1552+
1553+ assert_eq ! ( results. len( ) , 1 ) ;
1554+
1555+ assert_eq ! (
1556+ results[ 0 ]
1557+ . column( 0 )
1558+ . as_any( )
1559+ . downcast_ref:: <TimestampMicrosecondArray >( )
1560+ . unwrap( )
1561+ . value( 0 )
1562+ / 1000 ,
1563+ 1_715_299_200_000
1564+ ) ;
1565+
1566+ assert_eq ! (
1567+ results[ 0 ]
1568+ . column( 1 )
1569+ . as_any( )
1570+ . downcast_ref:: <Int64Array >( )
1571+ . unwrap( )
1572+ . value( 0 ) ,
1573+ 10
1574+ ) ;
1575+
1576+ Ok ( ( ) )
1577+ }
14081578}
0 commit comments