Skip to content

Commit c486fec

Browse files
authored
Merge pull request #23 from LYZJU2019/multiple-agg
Add support for multiple aggregations
2 parents 261b601 + 46ab697 commit c486fec

File tree

1 file changed

+206
-10
lines changed

1 file changed

+206
-10
lines changed

datafusion-uwheel/src/lib.rs

Lines changed: 206 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -265,6 +265,16 @@ impl UWheelOptimizer {
265265
agg.group_expr.is_empty() && agg.aggr_expr.len() == 1
266266
}
267267

268+
/// checks whether the Aggregate has a single group_by expression
269+
fn single_group_by(agg: &Aggregate) -> bool {
270+
agg.group_expr.len() == 1
271+
}
272+
273+
/// check whether the Aggregate has no group_expr and aggr_expr has a length greater than 1
274+
fn multiple_aggregates(agg: &Aggregate) -> bool {
275+
agg.group_expr.is_empty() && agg.aggr_expr.len() > 1
276+
}
277+
268278
// Attemps to rewrite a top-level Projection plan
269279
fn try_rewrite_projection(
270280
&self,
@@ -315,7 +325,7 @@ impl UWheelOptimizer {
315325
}
316326
}
317327

318-
LogicalPlan::Aggregate(agg) => {
328+
LogicalPlan::Aggregate(agg) if Self::single_group_by(agg) => {
319329
let group_expr = agg.group_expr.first()?;
320330

321331
// Only continue if the aggregation has a filter
@@ -372,6 +382,57 @@ impl UWheelOptimizer {
372382
}
373383
None
374384
}
385+
386+
LogicalPlan::Aggregate(agg) if Self::multiple_aggregates(agg) => {
387+
// Only continue if the aggregation has a filter
388+
let LogicalPlan::Filter(filter) = agg.input.as_ref() else {
389+
return None;
390+
};
391+
392+
let agg_exprs = &agg.aggr_expr;
393+
394+
let mut agg_results = Vec::new();
395+
396+
for agg_expr in agg_exprs {
397+
match agg_expr {
398+
// Single Aggregate Function (e.g., SUM(col))
399+
Expr::AggregateFunction(agg) if agg.args.len() == 1 => {
400+
if let Expr::Column(col) = &agg.args[0] {
401+
// Fetch temporal filter range and expr key which is used to identify a wheel
402+
let (range, expr_key) = match extract_filter_expr(
403+
&filter.predicate,
404+
&self.time_column,
405+
)? {
406+
(range, Some(expr)) => {
407+
(range, maybe_replace_table_name(&expr, &self.name))
408+
}
409+
(range, None) => (range, STAR_AGGREGATION_ALIAS.to_string()),
410+
};
411+
412+
// build the key for the wheel
413+
let wheel_key = format!("{}.{}.{}", self.name, col.name, expr_key);
414+
415+
let agg_type = func_def_to_aggregate_type(&agg.func)?;
416+
417+
// get aggregation result
418+
let result =
419+
self.get_aggregate_result(agg_type, &wheel_key, range)?;
420+
421+
agg_results.push(result);
422+
} else {
423+
return None;
424+
}
425+
}
426+
_ => {
427+
return None;
428+
}
429+
}
430+
}
431+
432+
let schema = Arc::new(plan.schema().clone().as_arrow().clone());
433+
434+
uwheel_multiple_aggregations_to_table_scan(agg_results, schema).ok()
435+
}
375436
// Check whether it follows the pattern: SELECT * FROM X WHERE TIME >= X AND TIME <= Y
376437
LogicalPlan::Filter(filter) => self.try_rewrite_filter(filter, plan),
377438
_ => None,
@@ -453,27 +514,32 @@ impl UWheelOptimizer {
453514
range: WheelRange,
454515
schema: SchemaRef,
455516
) -> Option<LogicalPlan> {
517+
let result = self.get_aggregate_result(agg_type, wheel_key, range)?;
518+
uwheel_agg_to_table_scan(result, schema).ok()
519+
}
520+
521+
fn get_aggregate_result(
522+
&self,
523+
agg_type: UWheelAggregate,
524+
wheel_key: &str,
525+
range: WheelRange,
526+
) -> Option<f64> {
456527
match agg_type {
457528
UWheelAggregate::Sum => {
458529
let wheel = self.wheels.sum.lock().unwrap().get(wheel_key)?.clone();
459-
let result = wheel.combine_range_and_lower(range)?;
460-
uwheel_agg_to_table_scan(result, schema).ok()
530+
wheel.combine_range_and_lower(range)
461531
}
462532
UWheelAggregate::Avg => {
463533
let wheel = self.wheels.avg.lock().unwrap().get(wheel_key)?.clone();
464-
let result = wheel.combine_range_and_lower(range)?;
465-
466-
uwheel_agg_to_table_scan(result, schema).ok()
534+
wheel.combine_range_and_lower(range)
467535
}
468536
UWheelAggregate::Min => {
469537
let wheel = self.wheels.min.lock().unwrap().get(wheel_key)?.clone();
470-
let result = wheel.combine_range_and_lower(range)?;
471-
uwheel_agg_to_table_scan(result, schema).ok()
538+
wheel.combine_range_and_lower(range)
472539
}
473540
UWheelAggregate::Max => {
474541
let wheel = self.wheels.max.lock().unwrap().get(wheel_key)?.clone();
475-
let result = wheel.combine_range_and_lower(range)?;
476-
uwheel_agg_to_table_scan(result, schema).ok()
542+
wheel.combine_range_and_lower(range)
477543
}
478544
_ => unimplemented!(),
479545
}
@@ -517,6 +583,24 @@ fn uwheel_group_by_to_table_scan(
517583
mem_table_as_table_scan(mem_table, df_schema)
518584
}
519585

586+
fn uwheel_multiple_aggregations_to_table_scan(
587+
agg_results: Vec<f64>,
588+
schema: SchemaRef,
589+
) -> Result<LogicalPlan> {
590+
let mut columns = Vec::new();
591+
592+
for result in agg_results {
593+
let data = Float64Array::from(vec![result]);
594+
columns.push(Arc::new(data) as Arc<dyn Array>);
595+
}
596+
597+
let record_batch = RecordBatch::try_new(schema.clone(), columns)?;
598+
599+
let df_schema = Arc::new(DFSchema::try_from(schema.clone())?);
600+
let mem_table = MemTable::try_new(schema, vec![vec![record_batch]])?;
601+
mem_table_as_table_scan(mem_table, df_schema)
602+
}
603+
520604
// helper for possibly removing the table name from the expression key
521605
fn maybe_replace_table_name(expr: &Expr, table_name: &str) -> String {
522606
let expr_str = expr.to_string();
@@ -1575,4 +1659,116 @@ mod tests {
15751659

15761660
Ok(())
15771661
}
1662+
1663+
#[tokio::test]
1664+
async fn multiple_aggregation_rewrite() -> Result<()> {
1665+
let optimizer = test_optimizer().await?;
1666+
1667+
optimizer
1668+
.build_index(IndexBuilder::with_col_and_aggregate(
1669+
"agg_col",
1670+
UWheelAggregate::Avg,
1671+
))
1672+
.await?;
1673+
1674+
optimizer
1675+
.build_index(IndexBuilder::with_col_and_aggregate(
1676+
"agg_col",
1677+
UWheelAggregate::Sum,
1678+
))
1679+
.await?;
1680+
1681+
let temporal_filter = col("timestamp")
1682+
.gt_eq(lit("2024-05-10T00:00:00Z"))
1683+
.and(col("timestamp").lt(lit("2024-05-10T00:00:10Z")));
1684+
1685+
let plan =
1686+
LogicalPlanBuilder::scan("test", provider_as_source(optimizer.provider()), None)?
1687+
.filter(temporal_filter)?
1688+
.aggregate(
1689+
Vec::<Expr>::new(),
1690+
vec![avg(col("agg_col")), sum(col("agg_col"))],
1691+
)?
1692+
.project(vec![avg(col("agg_col")), sum(col("agg_col"))])?
1693+
.build()?;
1694+
1695+
// Assert that the original plan is a Projection
1696+
assert!(matches!(plan, LogicalPlan::Projection(_)));
1697+
1698+
let rewritten = optimizer.try_rewrite(&plan).unwrap();
1699+
// assert it was rewritten to a TableScan
1700+
assert!(matches!(rewritten, LogicalPlan::TableScan(_)));
1701+
1702+
Ok(())
1703+
}
1704+
1705+
#[tokio::test]
1706+
async fn multiple_aggregation_exec() -> Result<()> {
1707+
let optimizer = test_optimizer().await?;
1708+
1709+
optimizer
1710+
.build_index(IndexBuilder::with_col_and_aggregate(
1711+
"agg_col",
1712+
UWheelAggregate::Avg,
1713+
))
1714+
.await?;
1715+
1716+
optimizer
1717+
.build_index(IndexBuilder::with_col_and_aggregate(
1718+
"agg_col",
1719+
UWheelAggregate::Sum,
1720+
))
1721+
.await?;
1722+
1723+
let temporal_filter = col("timestamp")
1724+
.gt_eq(lit("2024-05-10T00:00:00Z"))
1725+
.and(col("timestamp").lt(lit("2024-05-10T00:00:10Z")));
1726+
1727+
let plan =
1728+
LogicalPlanBuilder::scan("test", provider_as_source(optimizer.provider()), None)?
1729+
.filter(temporal_filter)?
1730+
.aggregate(
1731+
Vec::<Expr>::new(),
1732+
vec![avg(col("agg_col")), sum(col("agg_col"))],
1733+
)?
1734+
.project(vec![avg(col("agg_col")), sum(col("agg_col"))])?
1735+
.build()?;
1736+
1737+
let ctx = SessionContext::new();
1738+
ctx.register_table("test", optimizer.provider().clone())?;
1739+
1740+
// Set UWheelOptimizer as optimizer rule
1741+
let session_state = SessionStateBuilder::new()
1742+
.with_optimizer_rules(vec![optimizer.clone()])
1743+
.build();
1744+
let uwheel_ctx = SessionContext::new_with_state(session_state);
1745+
1746+
// Run the query through the ctx that has our OptimizerRule
1747+
let df = uwheel_ctx.execute_logical_plan(plan).await?;
1748+
let results = df.collect().await?;
1749+
1750+
assert_eq!(results.len(), 1);
1751+
1752+
assert_eq!(
1753+
results[0]
1754+
.column(0)
1755+
.as_any()
1756+
.downcast_ref::<Float64Array>()
1757+
.unwrap()
1758+
.value(0),
1759+
5.5
1760+
);
1761+
1762+
assert_eq!(
1763+
results[0]
1764+
.column(1)
1765+
.as_any()
1766+
.downcast_ref::<Float64Array>()
1767+
.unwrap()
1768+
.value(0),
1769+
55.0
1770+
);
1771+
1772+
Ok(())
1773+
}
15781774
}

0 commit comments

Comments
 (0)