Skip to content

Commit d82a5c6

Browse files
authored
branch-3.1: [fix](nereids) Fix "not in aggregate's output" error after group-by keys are eliminated by uniform when grouping sets exist apache#56942 (apache#57885)
picked from apache#56942
1 parent 463044a commit d82a5c6

File tree

5 files changed

+157
-7
lines changed

5 files changed

+157
-7
lines changed

fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/EliminateGroupByKeyByUniform.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
3333
import org.apache.doris.nereids.trees.plans.logical.LogicalLimit;
3434
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
35+
import org.apache.doris.nereids.trees.plans.logical.LogicalRepeat;
3536
import org.apache.doris.nereids.trees.plans.visitor.CustomRewriter;
3637
import org.apache.doris.nereids.trees.plans.visitor.DefaultPlanRewriter;
3738
import org.apache.doris.nereids.util.Utils;
@@ -81,6 +82,11 @@ public Plan visit(Plan plan, Map<ExprId, ExprId> replaceMap) {
8182
public Plan visitLogicalAggregate(LogicalAggregate<? extends Plan> aggregate, Map<ExprId, ExprId> replaceMap) {
8283
aggregate = visitChildren(this, aggregate, replaceMap);
8384
aggregate = (LogicalAggregate<? extends Plan>) exprIdReplacer.rewriteExpr(aggregate, replaceMap);
85+
if (aggregate.getSourceRepeat().isPresent()) {
86+
LogicalRepeat<?> sourceRepeat = (LogicalRepeat<?>) exprIdReplacer.rewriteExpr(
87+
aggregate.getSourceRepeat().get(), replaceMap);
88+
aggregate = aggregate.withSourceRepeat(sourceRepeat);
89+
}
8490

8591
if (aggregate.getGroupByExpressions().isEmpty() || aggregate.getSourceRepeat().isPresent()) {
8692
return aggregate;

fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/ExprIdRewriter.java

Lines changed: 46 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -28,12 +28,16 @@
2828
import org.apache.doris.nereids.trees.expressions.Expression;
2929
import org.apache.doris.nereids.trees.expressions.Slot;
3030
import org.apache.doris.nereids.trees.expressions.SlotReference;
31+
import org.apache.doris.nereids.trees.expressions.VirtualSlotReference;
32+
import org.apache.doris.nereids.trees.expressions.functions.scalar.GroupingScalarFunction;
33+
import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionRewriter;
3134
import org.apache.doris.nereids.trees.plans.Plan;
3235

3336
import com.google.common.collect.ImmutableList;
3437

3538
import java.util.List;
3639
import java.util.Map;
40+
import java.util.Optional;
3741

3842
/** replace SlotReference ExprId in logical plans */
3943
public class ExprIdRewriter extends ExpressionRewrite {
@@ -74,6 +78,25 @@ public Plan rewriteExpr(Plan plan, Map<ExprId, ExprId> replaceMap) {
7478
* SlotReference:a#0 -> a#3, a#1 -> a#7
7579
* */
7680
public static class ReplaceRule implements ExpressionPatternRuleFactory {
81+
/**
 * Shared expression rewriter that replaces the ExprId of every SlotReference it visits
 * according to the replace map passed as the visitor context.
 * Mappings are followed transitively: with a#0 -> a#3 and a#3 -> a#7, a#0 becomes a#7.
 * NOTE(review): a cyclic replace map (e.g. a#0 -> a#3 -> a#0) would make the loop below
 * spin forever; callers presumably never build one - confirm.
 */
private static final DefaultExpressionRewriter<Map<ExprId, ExprId>> SLOT_REPLACER =
82+
new DefaultExpressionRewriter<Map<ExprId, ExprId>>() {
83+
@Override
84+
public Expression visitSlotReference(SlotReference slot, Map<ExprId, ExprId> replaceMap) {
85+
// First hop: a slot whose id is not a key of the map is returned unchanged.
ExprId newId = replaceMap.get(slot.getExprId());
86+
if (newId == null) {
87+
return slot;
88+
}
89+
// lastId always holds the most recent id that was actually found in the map.
ExprId lastId = newId;
90+
// Keep following the chain until an id has no further mapping.
while (true) {
91+
newId = replaceMap.get(lastId);
92+
if (newId == null) {
93+
// End of the chain: rebuild the slot with the final id.
return slot.withExprId(lastId);
94+
} else {
95+
lastId = newId;
96+
}
97+
}
98+
}
99+
};
77100
private final Map<ExprId, ExprId> replaceMap;
78101

79102
public ReplaceRule(Map<ExprId, ExprId> replaceMap) {
@@ -85,14 +108,30 @@ public List<ExpressionPatternMatcher<? extends Expression>> buildRules() {
85108
return ImmutableList.of(
86109
matchesType(SlotReference.class).thenApply(ctx -> {
87110
Slot slot = ctx.expr;
88-
if (replaceMap.containsKey(slot.getExprId())) {
89-
ExprId newId = replaceMap.get(slot.getExprId());
90-
while (replaceMap.containsKey(newId)) {
91-
newId = replaceMap.get(newId);
111+
return slot.accept(SLOT_REPLACER, replaceMap);
112+
}),
113+
matchesType(VirtualSlotReference.class).thenApply(ctx -> {
114+
VirtualSlotReference virtualSlot = ctx.expr;
115+
return virtualSlot.accept(new DefaultExpressionRewriter<Map<ExprId, ExprId>>() {
116+
@Override
117+
public Expression visitVirtualReference(VirtualSlotReference virtualSlot,
118+
Map<ExprId, ExprId> replaceMap) {
119+
Optional<GroupingScalarFunction> originExpression = virtualSlot.getOriginExpression();
120+
if (!originExpression.isPresent()) {
121+
return virtualSlot;
122+
}
123+
GroupingScalarFunction groupingScalarFunction = originExpression.get();
124+
GroupingScalarFunction rewrittenFunction =
125+
(GroupingScalarFunction) groupingScalarFunction.accept(
126+
SLOT_REPLACER, replaceMap);
127+
if (!rewrittenFunction.children().equals(groupingScalarFunction.children())) {
128+
return virtualSlot.withOriginExpressionAndComputeLongValueMethod(
129+
Optional.of(rewrittenFunction),
130+
rewrittenFunction::computeVirtualSlotValue);
131+
}
132+
return virtualSlot;
92133
}
93-
return slot.withExprId(newId);
94-
}
95-
return slot;
134+
}, replaceMap);
96135
})
97136
);
98137
}

fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/VirtualSlotReference.java

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,13 @@ public VirtualSlotReference withExprId(ExprId exprId) {
156156
originExpression, computeLongValueMethod);
157157
}
158158

159+
/**
 * Returns a new VirtualSlotReference that keeps this slot's exprId, name, dataType,
 * nullable flag and qualifier, but uses the given origin expression and compute method.
 *
 * @param originExpression the grouping scalar function to store as the origin expression
 * @param computeLongValueMethod function from GroupingSetShapes to a list of Long values
 *        to store alongside the origin expression
 * @return a new VirtualSlotReference; this instance is not modified
 */
public VirtualSlotReference withOriginExpressionAndComputeLongValueMethod(
160+
Optional<GroupingScalarFunction> originExpression,
161+
Function<GroupingSetShapes, List<Long>> computeLongValueMethod) {
162+
return new VirtualSlotReference(exprId, name.get(), dataType, nullable, qualifier,
163+
originExpression, computeLongValueMethod);
164+
}
165+
159166
@Override
160167
public Slot withIndexInSql(Pair<Integer, Integer> index) {
161168
return this;

fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalAggregate.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -311,6 +311,12 @@ public LogicalAggregate<Plan> withNormalized(List<Expression> normalizedGroupBy,
311311
hasPushed, sourceRepeat, Optional.empty(), Optional.empty(), normalizedChild);
312312
}
313313

314+
/**
 * Returns a new LogicalAggregate identical to this one except for its source repeat.
 * Group-by expressions, output expressions, the normalized / ordinalIsResolved /
 * generated / hasPushed flags and the child are all carried over.
 *
 * @param sourceRepeat the repeat node to record; null results in an empty Optional
 * @return a new LogicalAggregate holding the given source repeat
 */
public LogicalAggregate<Plan> withSourceRepeat(LogicalRepeat<?> sourceRepeat) {
315+
return new LogicalAggregate<>(groupByExpressions, outputExpressions, normalized, ordinalIsResolved,
316+
generated, hasPushed, Optional.ofNullable(sourceRepeat),
317+
// NOTE(review): the two empty Optionals are presumably the group expression and the
// logical properties, reset for the new node - confirm against the constructor.
Optional.empty(), Optional.empty(), child());
318+
}
319+
314320
private boolean isUniqueGroupByUnique(NamedExpression namedExpression) {
315321
if (namedExpression.children().size() != 1) {
316322
return false;

regression-test/suites/nereids_rules_p0/eliminate_gby_key/eliminate_group_by_key_by_uniform.groovy

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -237,4 +237,96 @@ suite("eliminate_group_by_key_by_uniform") {
237237
qt_to_limit_join_project_shape "explain shape plan select 1 as c1 from test1 t1 inner join (select * from test2 where b=105) t2 on t1.a=t2.a group by c1;"
238238
qt_to_limit_project_uniform_shape "explain shape plan select 1 as c1 from eli_gbk_by_uniform_t group by c1"
239239
qt_to_limit_multi_group_by_shape "explain shape plan select 2 as c1 from eli_gbk_by_uniform_t where a=1 group by c1,a"
240+
241+
// test the case where a repeat (GROUPING SETS) sits above an aggregate
242+
243+
sql """drop table if exists test_event"""
244+
sql """
245+
CREATE TABLE `test_event` (
246+
`@dt` DATETIME NOT NULL COMMENT '',
247+
`@event_name` VARCHAR(255) NOT NULL COMMENT '',
248+
`@user_id` VARCHAR(100) NOT NULL COMMENT '',
249+
`@event_time` DATETIME NOT NULL COMMENT '',
250+
`@event_property_1` VARCHAR(255) NULL
251+
)
252+
ENGINE=OLAP
253+
DUPLICATE KEY(`@dt`, `@event_name`, `@user_id`)
254+
COMMENT ''
255+
PARTITION BY RANGE(`@dt`)
256+
(
257+
PARTITION p202509 VALUES [('2025-09-01 00:00:00'), ('2025-10-05 00:00:00'))
258+
)
259+
DISTRIBUTED BY HASH(`@user_id`) BUCKETS 10
260+
PROPERTIES (
261+
"replication_num" = "1",
262+
"dynamic_partition.enable" = "true",
263+
"dynamic_partition.time_unit" = "MONTH",
264+
"dynamic_partition.start" = "-2147483648",
265+
"dynamic_partition.end" = "3",
266+
"dynamic_partition.prefix" = "p",
267+
"dynamic_partition.buckets" = "10"
268+
);
269+
"""
270+
271+
sql """
272+
INSERT INTO `test_event` (`@dt`, `@event_name`, `@user_id`, `@event_time`, `@event_property_1`)
273+
VALUES
274+
('2025-09-03 10:00:00', 'shop_buy', 'user_A', '2025-09-03 10:00:00', 'prop_A1'),
275+
('2025-09-03 10:01:00', 'shop_buy', 'user_A', '2025-09-03 10:01:00', 'prop_A2'),
276+
('2025-09-04 15:30:00', 'shop_buy', 'user_A', '2025-09-04 15:30:00', 'prop_A3'),
277+
('2025-09-05 08:00:00', 'shop_buy', 'user_B', '2025-09-05 08:00:00', 'prop_B1'),
278+
('2025-09-05 08:05:00', 'shop_buy', 'user_B', '2025-09-05 08:05:00', 'prop_B2'),
279+
('2025-09-09 23:59:59', 'shop_buy', 'user_C', '2025-09-09 23:59:59', 'prop_C1'),
280+
('2025-10-01 00:00:00', 'shop_buy', 'user_D', '2025-10-01 00:00:00', 'prop_D1');
281+
"""
282+
283+
sql """
284+
SELECT
285+
CASE WHEN GROUPING(event_date) = 1 THEN '(TOTAL)' ELSE CAST(event_date AS VARCHAR) END AS event_date,
286+
user_id,
287+
MAX(conversion_level) AS conversion_level,
288+
CASE WHEN GROUPING(event_name_group) = 1 THEN '(TOTAL)' ELSE event_name_group END AS event_name_group
289+
FROM
290+
(
291+
SELECT
292+
src.event_date,
293+
src.user_id,
294+
WINDOW_FUNNEL(
295+
3600 * 24 * 1,
296+
'default',
297+
src.event_time,
298+
src.event_name = 'shop_buy',
299+
src.event_name = 'shop_buy'
300+
) AS conversion_level,
301+
src.event_name_group
302+
FROM
303+
(
304+
SELECT
305+
CAST(etb.`@dt` AS DATE) AS event_date,
306+
etb.`@event_name` AS event_name,
307+
etb.`@event_time` AS event_time,
308+
etb.`@event_name` AS event_name_group,
309+
etb.`@user_id` AS user_id
310+
FROM
311+
`test_event` AS etb
312+
WHERE
313+
etb.`@dt` between '2025-09-03 02:00:00' AND '2025-09-10 01:59:59'
314+
AND etb.`@event_name` = 'shop_buy'
315+
AND etb.`@user_id` IS NOT NULL
316+
AND etb.`@user_id` > '0'
317+
) AS src
318+
GROUP BY
319+
src.event_date,
320+
src.user_id,
321+
src.event_name_group
322+
) AS fwt
323+
GROUP BY
324+
GROUPING SETS (
325+
(user_id),
326+
(user_id, event_date),
327+
(user_id, event_name_group),
328+
(user_id, event_date, event_name_group)
329+
);
330+
331+
"""
240332
}

0 commit comments

Comments
 (0)