Skip to content

Commit 0e46cc7

Browse files
committed
Fix: refactor DISTINCT ON elimination transformation (#4407)
1 parent b00d857 commit 0e46cc7

File tree

3 files changed

+59
-28
lines changed

3 files changed

+59
-28
lines changed

sqlglot/transforms.py

Lines changed: 23 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -179,26 +179,42 @@ def eliminate_distinct_on(expression: exp.Expression) -> exp.Expression:
179179
if (
180180
isinstance(expression, exp.Select)
181181
and expression.args.get("distinct")
182-
and expression.args["distinct"].args.get("on")
183-
and isinstance(expression.args["distinct"].args["on"], exp.Tuple)
182+
and isinstance(expression.args["distinct"].args.get("on"), exp.Tuple)
184183
):
184+
row_number_window_alias = find_new_name(expression.named_selects, "_row_number")
185+
185186
distinct_cols = expression.args["distinct"].pop().args["on"].expressions
186-
row_number = find_new_name(expression.named_selects, "_row_number")
187187
window = exp.Window(this=exp.RowNumber(), partition_by=distinct_cols)
188-
order = expression.args.get("order")
189188

189+
order = expression.args.get("order")
190190
if order:
191191
window.set("order", order.pop())
192192
else:
193193
window.set("order", exp.Order(expressions=[c.copy() for c in distinct_cols]))
194194

195-
window = exp.alias_(window, row_number)
195+
window = exp.alias_(window, row_number_window_alias)
196196
expression.select(window, copy=False)
197197

198+
# We add aliases to the projections so that we can safely reference them in the outer query
199+
new_selects = []
200+
taken_names = {row_number_window_alias}
201+
for select in expression.selects[:-1]:
202+
if select.is_star:
203+
new_selects = ["*"]
204+
break
205+
206+
if not isinstance(select, exp.Alias):
207+
alias = find_new_name(taken_names, select.output_name or "_col")
208+
select = select.replace(exp.alias_(select, alias))
209+
210+
output_name = select.output_name
211+
taken_names.add(output_name)
212+
new_selects.append(output_name)
213+
198214
return (
199-
exp.select("*", copy=False)
215+
exp.select(*new_selects, copy=False)
200216
.from_(expression.subquery("_t", copy=False), copy=False)
201-
.where(exp.column(row_number).eq(1), copy=False)
217+
.where(exp.column(row_number_window_alias).eq(1), copy=False)
202218
)
203219

204220
return expression

tests/dialects/test_redshift.py

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -228,22 +228,22 @@ def test_redshift(self):
228228
self.validate_all(
229229
"SELECT DISTINCT ON (a) a, b FROM x ORDER BY c DESC",
230230
write={
231-
"bigquery": "SELECT * FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) AS _t WHERE _row_number = 1",
232-
"databricks": "SELECT * FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) AS _t WHERE _row_number = 1",
233-
"drill": "SELECT * FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) AS _t WHERE _row_number = 1",
234-
"hive": "SELECT * FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) AS _t WHERE _row_number = 1",
235-
"mysql": "SELECT * FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY CASE WHEN c IS NULL THEN 1 ELSE 0 END DESC, c DESC) AS _row_number FROM x) AS _t WHERE _row_number = 1",
236-
"oracle": "SELECT * FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC) AS _row_number FROM x) _t WHERE _row_number = 1",
237-
"presto": "SELECT * FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) AS _t WHERE _row_number = 1",
238-
"redshift": "SELECT * FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC) AS _row_number FROM x) AS _t WHERE _row_number = 1",
239-
"snowflake": "SELECT * FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC) AS _row_number FROM x) AS _t WHERE _row_number = 1",
240-
"spark": "SELECT * FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) AS _t WHERE _row_number = 1",
241-
"sqlite": "SELECT * FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) AS _t WHERE _row_number = 1",
242-
"starrocks": "SELECT * FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY CASE WHEN c IS NULL THEN 1 ELSE 0 END DESC, c DESC) AS _row_number FROM x) AS _t WHERE _row_number = 1",
243-
"tableau": "SELECT * FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) AS _t WHERE _row_number = 1",
244-
"teradata": "SELECT * FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) AS _t WHERE _row_number = 1",
245-
"trino": "SELECT * FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) AS _t WHERE _row_number = 1",
246-
"tsql": "SELECT * FROM (SELECT a AS a, b AS b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY CASE WHEN c IS NULL THEN 1 ELSE 0 END DESC, c DESC) AS _row_number FROM x) AS _t WHERE _row_number = 1",
231+
"bigquery": "SELECT a, b FROM (SELECT a AS a, b AS b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) AS _t WHERE _row_number = 1",
232+
"databricks": "SELECT a, b FROM (SELECT a AS a, b AS b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) AS _t WHERE _row_number = 1",
233+
"drill": "SELECT a, b FROM (SELECT a AS a, b AS b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) AS _t WHERE _row_number = 1",
234+
"hive": "SELECT a, b FROM (SELECT a AS a, b AS b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) AS _t WHERE _row_number = 1",
235+
"mysql": "SELECT a, b FROM (SELECT a AS a, b AS b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY CASE WHEN c IS NULL THEN 1 ELSE 0 END DESC, c DESC) AS _row_number FROM x) AS _t WHERE _row_number = 1",
236+
"oracle": "SELECT a, b FROM (SELECT a AS a, b AS b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC) AS _row_number FROM x) _t WHERE _row_number = 1",
237+
"presto": "SELECT a, b FROM (SELECT a AS a, b AS b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) AS _t WHERE _row_number = 1",
238+
"redshift": "SELECT a, b FROM (SELECT a AS a, b AS b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC) AS _row_number FROM x) AS _t WHERE _row_number = 1",
239+
"snowflake": "SELECT a, b FROM (SELECT a AS a, b AS b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC) AS _row_number FROM x) AS _t WHERE _row_number = 1",
240+
"spark": "SELECT a, b FROM (SELECT a AS a, b AS b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) AS _t WHERE _row_number = 1",
241+
"sqlite": "SELECT a, b FROM (SELECT a AS a, b AS b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) AS _t WHERE _row_number = 1",
242+
"starrocks": "SELECT a, b FROM (SELECT a AS a, b AS b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY CASE WHEN c IS NULL THEN 1 ELSE 0 END DESC, c DESC) AS _row_number FROM x) AS _t WHERE _row_number = 1",
243+
"tableau": "SELECT a, b FROM (SELECT a AS a, b AS b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) AS _t WHERE _row_number = 1",
244+
"teradata": "SELECT a, b FROM (SELECT a AS a, b AS b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) AS _t WHERE _row_number = 1",
245+
"trino": "SELECT a, b FROM (SELECT a AS a, b AS b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) AS _t WHERE _row_number = 1",
246+
"tsql": "SELECT a, b FROM (SELECT a AS a, b AS b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY CASE WHEN c IS NULL THEN 1 ELSE 0 END DESC, c DESC) AS _row_number FROM x) AS _t WHERE _row_number = 1",
247247
},
248248
)
249249
self.validate_all(

tests/test_transforms.py

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -55,17 +55,17 @@ def test_eliminate_distinct_on(self):
5555
self.validate(
5656
eliminate_distinct_on,
5757
"SELECT DISTINCT ON (a) a, b FROM x ORDER BY c DESC",
58-
"SELECT * FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC) AS _row_number FROM x) AS _t WHERE _row_number = 1",
58+
"SELECT a, b FROM (SELECT a AS a, b AS b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC) AS _row_number FROM x) AS _t WHERE _row_number = 1",
5959
)
6060
self.validate(
6161
eliminate_distinct_on,
6262
"SELECT DISTINCT ON (a) a, b FROM x",
63-
"SELECT * FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY a) AS _row_number FROM x) AS _t WHERE _row_number = 1",
63+
"SELECT a, b FROM (SELECT a AS a, b AS b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY a) AS _row_number FROM x) AS _t WHERE _row_number = 1",
6464
)
6565
self.validate(
6666
eliminate_distinct_on,
6767
"SELECT DISTINCT ON (a, b) a, b FROM x ORDER BY c DESC",
68-
"SELECT * FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a, b ORDER BY c DESC) AS _row_number FROM x) AS _t WHERE _row_number = 1",
68+
"SELECT a, b FROM (SELECT a AS a, b AS b, ROW_NUMBER() OVER (PARTITION BY a, b ORDER BY c DESC) AS _row_number FROM x) AS _t WHERE _row_number = 1",
6969
)
7070
self.validate(
7171
eliminate_distinct_on,
@@ -75,12 +75,27 @@ def test_eliminate_distinct_on(self):
7575
self.validate(
7676
eliminate_distinct_on,
7777
"SELECT DISTINCT ON (_row_number) _row_number FROM x ORDER BY c DESC",
78-
"SELECT * FROM (SELECT _row_number, ROW_NUMBER() OVER (PARTITION BY _row_number ORDER BY c DESC) AS _row_number_2 FROM x) AS _t WHERE _row_number_2 = 1",
78+
"SELECT _row_number FROM (SELECT _row_number AS _row_number, ROW_NUMBER() OVER (PARTITION BY _row_number ORDER BY c DESC) AS _row_number_2 FROM x) AS _t WHERE _row_number_2 = 1",
7979
)
8080
self.validate(
8181
eliminate_distinct_on,
8282
"SELECT DISTINCT ON (x.a, x.b) x.a, x.b FROM x ORDER BY c DESC",
83-
"SELECT * FROM (SELECT x.a, x.b, ROW_NUMBER() OVER (PARTITION BY x.a, x.b ORDER BY c DESC) AS _row_number FROM x) AS _t WHERE _row_number = 1",
83+
"SELECT a, b FROM (SELECT x.a AS a, x.b AS b, ROW_NUMBER() OVER (PARTITION BY x.a, x.b ORDER BY c DESC) AS _row_number FROM x) AS _t WHERE _row_number = 1",
84+
)
85+
self.validate(
86+
eliminate_distinct_on,
87+
"SELECT DISTINCT ON (a) x.a, y.a FROM x CROSS JOIN y ORDER BY c DESC",
88+
"SELECT a, a_2 FROM (SELECT x.a AS a, y.a AS a_2, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC) AS _row_number FROM x CROSS JOIN y) AS _t WHERE _row_number = 1",
89+
)
90+
self.validate(
91+
eliminate_distinct_on,
92+
"SELECT DISTINCT ON (a) a, a + b FROM x ORDER BY c DESC",
93+
"SELECT a, _col FROM (SELECT a AS a, a + b AS _col, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC) AS _row_number FROM x) AS _t WHERE _row_number = 1",
94+
)
95+
self.validate(
96+
eliminate_distinct_on,
97+
"SELECT DISTINCT ON (a) * FROM x ORDER BY c DESC",
98+
"SELECT * FROM (SELECT *, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC) AS _row_number FROM x) AS _t WHERE _row_number = 1",
8499
)
85100

86101
def test_eliminate_qualify(self):

0 commit comments

Comments
 (0)