Skip to content

Commit 45eef60

Browse files
Fix!: use select star when eliminating distinct on (#4401)
The current approach of re-using the select expressions from the inner query would break in a lot of cases (e.g. using table aliases such as x.a, aliasing/renaming columns, using an expression, etc.). There was a previous attempt to fix this in #4286, and this commit implements the simplest approach suggested in the discussion there (which is SELECT *).
1 parent 8d78add commit 45eef60

File tree

3 files changed

+26
-22
lines changed

3 files changed

+26
-22
lines changed

sqlglot/transforms.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,6 @@ def eliminate_distinct_on(expression: exp.Expression) -> exp.Expression:
183183
and isinstance(expression.args["distinct"].args["on"], exp.Tuple)
184184
):
185185
distinct_cols = expression.args["distinct"].pop().args["on"].expressions
186-
outer_selects = expression.selects
187186
row_number = find_new_name(expression.named_selects, "_row_number")
188187
window = exp.Window(this=exp.RowNumber(), partition_by=distinct_cols)
189188
order = expression.args.get("order")
@@ -197,7 +196,7 @@ def eliminate_distinct_on(expression: exp.Expression) -> exp.Expression:
197196
expression.select(window, copy=False)
198197

199198
return (
200-
exp.select(*outer_selects, copy=False)
199+
exp.select("*", copy=False)
201200
.from_(expression.subquery("_t", copy=False), copy=False)
202201
.where(exp.column(row_number).eq(1), copy=False)
203202
)

tests/dialects/test_redshift.py

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -228,22 +228,22 @@ def test_redshift(self):
228228
self.validate_all(
229229
"SELECT DISTINCT ON (a) a, b FROM x ORDER BY c DESC",
230230
write={
231-
"bigquery": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) AS _t WHERE _row_number = 1",
232-
"databricks": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) AS _t WHERE _row_number = 1",
233-
"drill": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) AS _t WHERE _row_number = 1",
234-
"hive": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) AS _t WHERE _row_number = 1",
235-
"mysql": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY CASE WHEN c IS NULL THEN 1 ELSE 0 END DESC, c DESC) AS _row_number FROM x) AS _t WHERE _row_number = 1",
236-
"oracle": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC) AS _row_number FROM x) _t WHERE _row_number = 1",
237-
"presto": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) AS _t WHERE _row_number = 1",
238-
"redshift": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC) AS _row_number FROM x) AS _t WHERE _row_number = 1",
239-
"snowflake": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC) AS _row_number FROM x) AS _t WHERE _row_number = 1",
240-
"spark": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) AS _t WHERE _row_number = 1",
241-
"sqlite": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) AS _t WHERE _row_number = 1",
242-
"starrocks": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY CASE WHEN c IS NULL THEN 1 ELSE 0 END DESC, c DESC) AS _row_number FROM x) AS _t WHERE _row_number = 1",
243-
"tableau": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) AS _t WHERE _row_number = 1",
244-
"teradata": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) AS _t WHERE _row_number = 1",
245-
"trino": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) AS _t WHERE _row_number = 1",
246-
"tsql": "SELECT a, b FROM (SELECT a AS a, b AS b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY CASE WHEN c IS NULL THEN 1 ELSE 0 END DESC, c DESC) AS _row_number FROM x) AS _t WHERE _row_number = 1",
231+
"bigquery": "SELECT * FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) AS _t WHERE _row_number = 1",
232+
"databricks": "SELECT * FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) AS _t WHERE _row_number = 1",
233+
"drill": "SELECT * FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) AS _t WHERE _row_number = 1",
234+
"hive": "SELECT * FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) AS _t WHERE _row_number = 1",
235+
"mysql": "SELECT * FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY CASE WHEN c IS NULL THEN 1 ELSE 0 END DESC, c DESC) AS _row_number FROM x) AS _t WHERE _row_number = 1",
236+
"oracle": "SELECT * FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC) AS _row_number FROM x) _t WHERE _row_number = 1",
237+
"presto": "SELECT * FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) AS _t WHERE _row_number = 1",
238+
"redshift": "SELECT * FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC) AS _row_number FROM x) AS _t WHERE _row_number = 1",
239+
"snowflake": "SELECT * FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC) AS _row_number FROM x) AS _t WHERE _row_number = 1",
240+
"spark": "SELECT * FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) AS _t WHERE _row_number = 1",
241+
"sqlite": "SELECT * FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) AS _t WHERE _row_number = 1",
242+
"starrocks": "SELECT * FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY CASE WHEN c IS NULL THEN 1 ELSE 0 END DESC, c DESC) AS _row_number FROM x) AS _t WHERE _row_number = 1",
243+
"tableau": "SELECT * FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) AS _t WHERE _row_number = 1",
244+
"teradata": "SELECT * FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) AS _t WHERE _row_number = 1",
245+
"trino": "SELECT * FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) AS _t WHERE _row_number = 1",
246+
"tsql": "SELECT * FROM (SELECT a AS a, b AS b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY CASE WHEN c IS NULL THEN 1 ELSE 0 END DESC, c DESC) AS _row_number FROM x) AS _t WHERE _row_number = 1",
247247
},
248248
)
249249
self.validate_all(

tests/test_transforms.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -55,17 +55,17 @@ def test_eliminate_distinct_on(self):
5555
self.validate(
5656
eliminate_distinct_on,
5757
"SELECT DISTINCT ON (a) a, b FROM x ORDER BY c DESC",
58-
"SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC) AS _row_number FROM x) AS _t WHERE _row_number = 1",
58+
"SELECT * FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC) AS _row_number FROM x) AS _t WHERE _row_number = 1",
5959
)
6060
self.validate(
6161
eliminate_distinct_on,
6262
"SELECT DISTINCT ON (a) a, b FROM x",
63-
"SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY a) AS _row_number FROM x) AS _t WHERE _row_number = 1",
63+
"SELECT * FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY a) AS _row_number FROM x) AS _t WHERE _row_number = 1",
6464
)
6565
self.validate(
6666
eliminate_distinct_on,
6767
"SELECT DISTINCT ON (a, b) a, b FROM x ORDER BY c DESC",
68-
"SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a, b ORDER BY c DESC) AS _row_number FROM x) AS _t WHERE _row_number = 1",
68+
"SELECT * FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a, b ORDER BY c DESC) AS _row_number FROM x) AS _t WHERE _row_number = 1",
6969
)
7070
self.validate(
7171
eliminate_distinct_on,
@@ -75,7 +75,12 @@ def test_eliminate_distinct_on(self):
7575
self.validate(
7676
eliminate_distinct_on,
7777
"SELECT DISTINCT ON (_row_number) _row_number FROM x ORDER BY c DESC",
78-
"SELECT _row_number FROM (SELECT _row_number, ROW_NUMBER() OVER (PARTITION BY _row_number ORDER BY c DESC) AS _row_number_2 FROM x) AS _t WHERE _row_number_2 = 1",
78+
"SELECT * FROM (SELECT _row_number, ROW_NUMBER() OVER (PARTITION BY _row_number ORDER BY c DESC) AS _row_number_2 FROM x) AS _t WHERE _row_number_2 = 1",
79+
)
80+
self.validate(
81+
eliminate_distinct_on,
82+
"SELECT DISTINCT ON (x.a, x.b) x.a, x.b FROM x ORDER BY c DESC",
83+
"SELECT * FROM (SELECT x.a, x.b, ROW_NUMBER() OVER (PARTITION BY x.a, x.b ORDER BY c DESC) AS _row_number FROM x) AS _t WHERE _row_number = 1",
7984
)
8085

8186
def test_eliminate_qualify(self):

0 commit comments

Comments
 (0)