Skip to content

Commit c824a32

Browse files
feiniaofeiafeizzzxl1993
authored and committed
[fix](agg) fix rule merge_aggregate (apache#59629)
Related PR: apache#31811. Problem Summary: Before this PR, the outputExpressions of the LogicalAggregate generated by MergeAggregate contained a duplicated column. This bug does not lead to wrong results, but the group-by key is output twice in the LogicalAggregate. After this PR, the duplicated column is no longer generated.
1 parent d4dce7a commit c824a32

File tree

2 files changed

+125
-0
lines changed

2 files changed

+125
-0
lines changed

fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/MergeAggregate.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,7 @@ private Plan mergeAggProjectAgg(LogicalAggregate<LogicalProject<LogicalAggregate
100100
Map<ExprId, Expression> innerAggExprIdToAggFunc = getInnerAggExprIdToAggFuncMap(innerAgg);
101101
// rewrite agg function. e.g. max(max)
102102
List<NamedExpression> replacedAggFunc = replacedOutputExpressions.stream()
103+
.filter(e -> e.containsType(AggregateFunction.class))
103104
.map(e -> rewriteAggregateFunction(e, innerAggExprIdToAggFunc))
104105
.collect(Collectors.toList());
105106
// replace groupByKeys directly refer to the slot below the project
Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
package org.apache.doris.nereids.rules.rewrite;
19+
20+
import org.apache.doris.nereids.trees.expressions.Alias;
21+
import org.apache.doris.nereids.trees.expressions.NamedExpression;
22+
import org.apache.doris.nereids.trees.expressions.Slot;
23+
import org.apache.doris.nereids.trees.expressions.SlotReference;
24+
import org.apache.doris.nereids.trees.expressions.functions.agg.Sum;
25+
import org.apache.doris.nereids.trees.plans.Plan;
26+
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
27+
import org.apache.doris.nereids.trees.plans.logical.LogicalEmptyRelation;
28+
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
29+
import org.apache.doris.nereids.types.IntegerType;
30+
31+
import com.google.common.collect.ImmutableList;
32+
import org.junit.jupiter.api.Assertions;
33+
import org.junit.jupiter.api.BeforeEach;
34+
import org.junit.jupiter.api.Test;
35+
36+
import java.lang.reflect.Method;
37+
import java.util.List;
38+
39+
/**
40+
* Unit tests for {@link MergeAggregate}, specifically testing the fix for filtering
41+
* aggregate functions in mergeAggProjectAgg method.
42+
*/
43+
public class MergeAggregateTest {
44+
45+
private MergeAggregate mergeAggregate;
46+
47+
@BeforeEach
48+
public void setUp() {
49+
mergeAggregate = new MergeAggregate();
50+
}
51+
52+
@Test
53+
public void testMergeAggProjectAggWithMixedExpressions() throws Exception {
54+
// This test verifies the fix at line 103-104 where we filter expressions
55+
// to only process those containing AggregateFunction.
56+
// The bug was that non-aggregate expressions (like SlotReference) were
57+
// being passed to rewriteAggregateFunction, which could cause errors.
58+
59+
// Create inner aggregate: group by a, output a, sum(b) as sumBAlias
60+
SlotReference a = new SlotReference("a", IntegerType.INSTANCE);
61+
SlotReference b = new SlotReference("b", IntegerType.INSTANCE);
62+
Sum sumB = new Sum(b);
63+
Alias sumBAlias = new Alias(sumB, "sumBAlias");
64+
65+
LogicalEmptyRelation emptyRelation = new LogicalEmptyRelation(
66+
org.apache.doris.nereids.trees.expressions.StatementScopeIdGenerator.newRelationId(),
67+
ImmutableList.of());
68+
69+
LogicalAggregate<Plan> innerAgg = new LogicalAggregate<>(
70+
ImmutableList.of(a),
71+
ImmutableList.of(a, sumBAlias),
72+
emptyRelation);
73+
74+
// Create project: projects = [a as colA, sumBAlias]
75+
SlotReference colA = new SlotReference(
76+
sumBAlias.getExprId(), "colA", IntegerType.INSTANCE, true, ImmutableList.of());
77+
// Create a slot reference for sumBAlias from inner aggregate output
78+
Slot sumBSlot = sumBAlias.toSlot();
79+
LogicalProject<LogicalAggregate<Plan>> project = new LogicalProject<>(
80+
ImmutableList.of(colA, sumBSlot),
81+
innerAgg);
82+
83+
// Create outer aggregate: group by colA, output colA, sum(sumBAlias)
84+
Slot col2FromProject = project.getOutput().get(0);
85+
Slot col1FromProject = project.getOutput().get(1);
86+
Sum sumSum = new Sum(col1FromProject);
87+
Alias sumSumAlias = new Alias(sumSum, "sumSum");
88+
89+
// Outer aggregate output contains:
90+
// 1. colA (SlotReference - non-aggregate, should be filtered out)
91+
// 2. max(sumBAlias) (AggregateFunction - should be processed)
92+
// 3. sumBAlias (SlotReference - non-aggregate, should be filtered out)
93+
List<NamedExpression> outerAggOutput = ImmutableList.of(
94+
col2FromProject,
95+
sumSumAlias
96+
);
97+
98+
LogicalAggregate<LogicalProject<LogicalAggregate<Plan>>> outerAgg = new LogicalAggregate<>(
99+
ImmutableList.of(col2FromProject),
100+
outerAggOutput,
101+
project);
102+
103+
// Use reflection to call the private method
104+
Method method = mergeAggregate.getClass().getDeclaredMethod(
105+
"mergeAggProjectAgg", LogicalAggregate.class);
106+
method.setAccessible(true);
107+
108+
// This should not throw an exception
109+
// Before the fix, non-aggregate expressions would be passed to rewriteAggregateFunction
110+
// which could cause errors. After the fix, only expressions containing AggregateFunction
111+
// are processed.
112+
Plan result = (Plan) method.invoke(mergeAggregate, outerAgg);
113+
114+
Assertions.assertNotNull(result);
115+
Assertions.assertTrue(result instanceof LogicalProject);
116+
117+
LogicalProject<Plan> resultProject = (LogicalProject<Plan>) result;
118+
Assertions.assertNotNull(resultProject.child(0));
119+
Assertions.assertTrue(resultProject.child(0) instanceof LogicalAggregate);
120+
121+
LogicalAggregate<Plan> aggregate = (LogicalAggregate<Plan>) resultProject.child(0);
122+
Assertions.assertEquals(aggregate.getOutput().size(), 2);
123+
}
124+
}

0 commit comments

Comments
 (0)