Commit 4a30403

Support Group By Aggregation Queries
1 parent f20d520 commit 4a30403
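
For context, a minimal sketch (not part of the commit) of the query shape this change targets: a temporal range filter combined with a date_trunc group by and COUNT(*), which the optimizer can now answer directly from its uwheel count wheel. The table and column names ("test", "timestamp") mirror the tests added below, and `ctx` is assumed to be a SessionContext configured with the UWheelOptimizer rule.

use datafusion::error::Result;
use datafusion::prelude::SessionContext;

// Hypothetical driver: run the GROUP BY date_trunc('day', ..) + COUNT(*) pattern
// over a temporal range, which the new LogicalPlan::Aggregate rewrite handles.
async fn run_group_by_count(ctx: &SessionContext) -> Result<()> {
    let df = ctx
        .sql(
            r#"SELECT date_trunc('day', "timestamp"), COUNT(*)
               FROM test
               WHERE "timestamp" >= '2024-05-10T00:00:00Z'
                 AND "timestamp" < '2024-05-10T00:00:10Z'
               GROUP BY 1"#,
        )
        .await?;
    df.show().await?;
    Ok(())
}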

File tree

1 file changed: +173 -3 lines

datafusion-uwheel/src/lib.rs

Lines changed: 173 additions & 3 deletions
@@ -47,7 +47,7 @@ use uwheel::{
         sum::{F64SumAggregator, U32SumAggregator},
     },
     wheels::read::ReaderWheel,
-    Aggregator, Conf, Entry, HawConf, RwWheel, WheelRange,
+    Aggregator, Conf, Duration, Entry, HawConf, RwWheel, WheelRange,
 };
 
 /// Custom aggregator implementations that are used by this crate.
@@ -314,6 +314,64 @@ impl UWheelOptimizer {
                     _ => None,
                 }
             }
+
+            LogicalPlan::Aggregate(agg) => {
+                let group_expr = agg.group_expr.first()?;
+
+                // Only continue if the aggregation has a filter
+                let LogicalPlan::Filter(filter) = agg.input.as_ref() else {
+                    return None;
+                };
+
+                let (wheel_range, _) = extract_filter_expr(&filter.predicate, &self.time_column)?;
+
+                match group_expr {
+                    Expr::ScalarFunction(func) if func.name() == "date_trunc" => {
+                        let interval = func.args.first()?;
+                        if let Expr::Literal(ScalarValue::Utf8(duration)) = interval {
+                            match duration.as_ref()?.as_str() {
+                                "second" => {
+                                    unimplemented!("date_trunc('second') group by is not supported")
+                                }
+                                "minute" => {
+                                    unimplemented!("date_trunc('minute') group by is not supported")
+                                }
+                                "hour" => {
+                                    unimplemented!("date_trunc('hour') group by is not supported")
+                                }
+                                "day" => {
+                                    let res = self
+                                        .wheels
+                                        .count
+                                        .group_by(wheel_range, Duration::DAY)
+                                        .unwrap_or_default()
+                                        .iter()
+                                        .map(|(k, v)| ((*k * 1_000) as i64, *v as i64)) // uwheel keys are in milliseconds; convert to microseconds
+                                        .collect();
+
+                                    let schema = Arc::new(plan.schema().clone().as_arrow().clone());
+
+                                    return uwheel_group_by_to_table_scan(res, schema).ok();
+                                }
+                                "week" => {
+                                    unimplemented!("date_trunc('week') group by is not supported")
+                                }
+                                "month" => {
+                                    unimplemented!("date_trunc('month') group by is not supported")
+                                }
+                                "year" => {
+                                    unimplemented!("date_trunc('year') group by is not supported")
+                                }
+                                _ => {}
+                            }
+                        }
+                    }
+                    _ => {
+                        unimplemented!("only the date_trunc scalar function is supported as a group by expression for now")
+                    }
+                }
+                None
+            }
+
             // Check whether it follows the pattern: SELECT * FROM X WHERE TIME >= X AND TIME <= Y
             LogicalPlan::Filter(filter) => self.try_rewrite_filter(filter, plan),
             _ => None,
@@ -440,6 +498,25 @@ fn uwheel_agg_to_table_scan(result: f64, schema: SchemaRef) -> Result<LogicalPlan>
     mem_table_as_table_scan(mem_table, df_schema)
 }
 
+// Converts a uwheel group by result to a TableScan with a MemTable as source
+// currently only supports timestamp group by
+fn uwheel_group_by_to_table_scan(
+    result: Vec<(i64, i64)>,
+    schema: SchemaRef,
+) -> Result<LogicalPlan> {
+    let group_by =
+        TimestampMicrosecondArray::from(result.iter().map(|(k, _)| *k).collect::<Vec<_>>());
+
+    let agg = Int64Array::from(result.iter().map(|(_, v)| *v).collect::<Vec<_>>());
+
+    let record_batch =
+        RecordBatch::try_new(schema.clone(), vec![Arc::new(group_by), Arc::new(agg)])?;
+
+    let df_schema = Arc::new(DFSchema::try_from(schema.clone())?);
+    let mem_table = MemTable::try_new(schema, vec![vec![record_batch]])?;
+    mem_table_as_table_scan(mem_table, df_schema)
+}
+
 // helper for possibly removing the table name from the expression key
 fn maybe_replace_table_name(expr: &Expr, table_name: &str) -> String {
     let expr_str = expr.to_string();
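
To make the new helper's expected input concrete, here is a hypothetical, self-contained sketch (not part of the commit): the result keys are epoch microseconds and the values are counts, and the schema handed in must list the timestamp key column first and the Int64 count column second, or RecordBatch::try_new will reject the arrays. The field names below are made up for illustration.

use std::sync::Arc;

use datafusion::arrow::array::{Int64Array, TimestampMicrosecondArray};
use datafusion::arrow::datatypes::{DataType, Field, Schema, TimeUnit};
use datafusion::arrow::record_batch::RecordBatch;

fn example_group_by_batch() -> RecordBatch {
    // One day bucket starting at 2024-05-10T00:00:00Z: uwheel reports the key in
    // milliseconds (1_715_299_200_000), which the rewrite scales to microseconds.
    let result: Vec<(i64, i64)> = vec![(1_715_299_200_000_000, 10)];

    // Column order mirrors what uwheel_group_by_to_table_scan builds: timestamp key, then count.
    let schema = Arc::new(Schema::new(vec![
        Field::new("day", DataType::Timestamp(TimeUnit::Microsecond, None), true),
        Field::new("count", DataType::Int64, true),
    ]));

    let keys =
        TimestampMicrosecondArray::from(result.iter().map(|(k, _)| *k).collect::<Vec<_>>());
    let counts = Int64Array::from(result.iter().map(|(_, v)| *v).collect::<Vec<_>>());

    RecordBatch::try_new(schema, vec![Arc::new(keys), Arc::new(counts)]).unwrap()
}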
@@ -568,7 +645,7 @@ fn is_count_star_aggregate(aggregate_function: &AggregateFunction) -> bool {
            func,
            args,
            ..
-        } if func.name() == "COUNT" && (args.len() == 1 && is_wildcard(&args[0]) || args.is_empty()))
+        } if (func.name() == "COUNT" || func.name() == "count") && (args.len() == 1 && is_wildcard(&args[0]) || args.is_empty()))
 }
 
 // Helper methods to build the UWheelOptimizer
@@ -942,9 +1019,11 @@ mod tests {
     use chrono::TimeZone;
     use datafusion::arrow::datatypes::{Field, Schema, TimeUnit};
     use datafusion::execution::SessionStateBuilder;
+    use datafusion::functions_aggregate::count::count;
     use datafusion::functions_aggregate::expr_fn::avg;
     use datafusion::functions_aggregate::min_max::{max, min};
-    use datafusion::logical_expr::test::function_stub::{count, sum};
+    use datafusion::functions_aggregate::sum::sum;
+    use datafusion::prelude::date_trunc;
 
     use super::*;
     use builder::Builder;
@@ -1405,4 +1484,95 @@ mod tests {
 
         Ok(())
     }
+
+    #[tokio::test]
+    async fn group_by_count_aggregation_rewrite() -> Result<()> {
+        let optimizer = test_optimizer().await?;
+
+        let temporal_filter = col("timestamp")
+            .gt_eq(lit("2024-05-10T00:00:00Z"))
+            .and(col("timestamp").lt(lit("2024-05-10T00:00:10Z")));
+
+        let plan =
+            LogicalPlanBuilder::scan("test", provider_as_source(optimizer.provider()), None)?
+                .filter(temporal_filter)?
+                .aggregate(
+                    vec![date_trunc(lit("day"), col("timestamp"))], // GROUP BY date_trunc('day', timestamp)
+                    vec![count(wildcard())],
+                )?
+                .project(vec![
+                    date_trunc(lit("day"), col("timestamp")),
+                    count(wildcard()),
+                ])?
+                .build()?;
+
+        // Assert that the original plan is a Projection
+        assert!(matches!(plan, LogicalPlan::Projection(_)));
+
+        let rewritten = optimizer.try_rewrite(&plan).unwrap();
+        // assert it was rewritten to a TableScan
+        assert!(matches!(rewritten, LogicalPlan::TableScan(_)));
+
+        Ok(())
+    }
+
+    #[tokio::test]
+    async fn group_by_count_aggregation_exec() -> Result<()> {
+        let optimizer = test_optimizer().await?;
+
+        let temporal_filter = col("timestamp")
+            .gt_eq(lit("2024-05-10T00:00:00Z"))
+            .and(col("timestamp").lt(lit("2024-05-10T00:00:10Z")));
+
+        let plan =
+            LogicalPlanBuilder::scan("test", provider_as_source(optimizer.provider()), None)?
+                .filter(temporal_filter)?
+                .aggregate(
+                    vec![date_trunc(lit("day"), col("timestamp"))], // GROUP BY date_trunc('day', timestamp)
+                    vec![count(wildcard())],
+                )?
+                .project(vec![
+                    date_trunc(lit("day"), col("timestamp")),
+                    count(wildcard()),
+                ])?
+                .build()?;
+
+        let ctx = SessionContext::new();
+        ctx.register_table("test", optimizer.provider().clone())?;
+
+        // Set UWheelOptimizer as optimizer rule
+        let session_state = SessionStateBuilder::new()
+            .with_optimizer_rules(vec![optimizer.clone()])
+            .build();
+        let uwheel_ctx = SessionContext::new_with_state(session_state);
+
+        // Run the query through the ctx that has our OptimizerRule
+        let df = uwheel_ctx.execute_logical_plan(plan).await?;
+        let results = df.collect().await?;
+
+        assert_eq!(results.len(), 1);
+
+        assert_eq!(
+            results[0]
+                .column(0)
+                .as_any()
+                .downcast_ref::<TimestampMicrosecondArray>()
+                .unwrap()
+                .value(0)
+                / 1000,
+            1_715_299_200_000
+        );
+
+        assert_eq!(
+            results[0]
+                .column(1)
+                .as_any()
+                .downcast_ref::<Int64Array>()
+                .unwrap()
+                .value(0),
+            10
+        );
+
+        Ok(())
+    }
 }
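
A quick sanity check (an editorial note, not part of the commit) for the values asserted in group_by_count_aggregation_exec: the group-key column holds epoch microseconds, so value(0) / 1000 yields epoch milliseconds, and 2024-05-10T00:00:00Z, the start of the day bucket selected by the filter, is 1_715_299_200_000 ms. The chrono crate is already imported by the test module, so the expected key can be derived like this:

use chrono::{TimeZone, Utc};

// Epoch milliseconds for the start of the day bucket asserted by the exec test.
fn expected_day_bucket_ms() -> i64 {
    Utc.with_ymd_and_hms(2024, 5, 10, 0, 0, 0)
        .unwrap()
        .timestamp_millis() // = 1_715_299_200_000
}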
