Skip to content

Commit aa26aad

Browse files
authored
Feat: transpile WINDOW clause (#5097)
1 parent 9d3a929 commit aa26aad

File tree

7 files changed

+138
-69
lines changed

7 files changed

+138
-69
lines changed

sqlglot/dialects/presto.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -220,7 +220,7 @@ def _explode_to_unnest_sql(self: Presto.Generator, expression: exp.Lateral) -> s
220220
return explode_to_unnest_sql(self, expression)
221221

222222

223-
def _amend_exploded_column_table(expression: exp.Expression) -> exp.Expression:
223+
def amend_exploded_column_table(expression: exp.Expression) -> exp.Expression:
224224
# We check for expression.type because the columns can be amended only if types were inferred
225225
if isinstance(expression, exp.Select) and expression.type:
226226
for lateral in expression.args.get("laterals") or []:
@@ -484,11 +484,12 @@ class Generator(generator.Generator):
484484
exp.SchemaCommentProperty: lambda self, e: self.naked_property(e),
485485
exp.Select: transforms.preprocess(
486486
[
487+
transforms.eliminate_window_clause,
487488
transforms.eliminate_qualify,
488489
transforms.eliminate_distinct_on,
489490
transforms.explode_projection_to_unnest(1),
490491
transforms.eliminate_semi_and_anti_joins,
491-
_amend_exploded_column_table,
492+
amend_exploded_column_table,
492493
]
493494
),
494495
exp.SortArray: _no_sort_array,

sqlglot/dialects/redshift.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -197,6 +197,7 @@ class Generator(Postgres.Generator):
197197
exp.Hex: lambda self, e: self.func("UPPER", self.func("TO_HEX", self.sql(e, "this"))),
198198
exp.Select: transforms.preprocess(
199199
[
200+
transforms.eliminate_window_clause,
200201
transforms.eliminate_distinct_on,
201202
transforms.eliminate_semi_and_anti_joins,
202203
transforms.unqualify_unnest,

sqlglot/dialects/snowflake.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1078,6 +1078,7 @@ class Generator(generator.Generator):
10781078
exp.Rand: rename_func("RANDOM"),
10791079
exp.Select: transforms.preprocess(
10801080
[
1081+
transforms.eliminate_window_clause,
10811082
transforms.eliminate_distinct_on,
10821083
transforms.explode_projection_to_unnest(),
10831084
transforms.eliminate_semi_and_anti_joins,

sqlglot/dialects/trino.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
from __future__ import annotations
22

3-
from sqlglot import exp, parser
3+
from sqlglot import exp, parser, transforms
44
from sqlglot.dialects.dialect import (
55
merge_without_target_sql,
66
trim_sql,
77
timestrtotime_sql,
88
groupconcat_sql,
99
)
10-
from sqlglot.dialects.presto import Presto
10+
from sqlglot.dialects.presto import amend_exploded_column_table, Presto
1111
from sqlglot.tokens import TokenType
1212
import typing as t
1313

@@ -75,6 +75,15 @@ class Generator(Presto.Generator):
7575
exp.GroupConcat: lambda self, e: groupconcat_sql(self, e, on_overflow=True),
7676
exp.LocationProperty: lambda self, e: self.property_sql(e),
7777
exp.Merge: merge_without_target_sql,
78+
exp.Select: transforms.preprocess(
79+
[
80+
transforms.eliminate_qualify,
81+
transforms.eliminate_distinct_on,
82+
transforms.explode_projection_to_unnest(1),
83+
transforms.eliminate_semi_and_anti_joins,
84+
amend_exploded_column_table,
85+
]
86+
),
7887
exp.TimeStrToTime: lambda self, e: timestrtotime_sql(self, e, include_precision=True),
7988
exp.Trim: trim_sql,
8089
}

sqlglot/transforms.py

Lines changed: 35 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -965,16 +965,47 @@ def any_to_exists(expression: exp.Expression) -> exp.Expression:
965965
transformation
966966
"""
967967
if isinstance(expression, exp.Select):
968-
for any in expression.find_all(exp.Any):
969-
this = any.this
968+
for any_expr in expression.find_all(exp.Any):
969+
this = any_expr.this
970970
if isinstance(this, exp.Query):
971971
continue
972972

973-
binop = any.parent
973+
binop = any_expr.parent
974974
if isinstance(binop, exp.Binary):
975975
lambda_arg = exp.to_identifier("x")
976-
any.replace(lambda_arg)
976+
any_expr.replace(lambda_arg)
977977
lambda_expr = exp.Lambda(this=binop.copy(), expressions=[lambda_arg])
978978
binop.replace(exp.Exists(this=this.unnest(), expression=lambda_expr))
979979

980980
return expression
981+
982+
983+
def eliminate_window_clause(expression: exp.Expression) -> exp.Expression:
984+
"""Eliminates the `WINDOW` query clause by inlining each named window."""
985+
if isinstance(expression, exp.Select) and expression.args.get("windows"):
986+
from sqlglot.optimizer.scope import find_all_in_scope
987+
988+
windows = expression.args["windows"]
989+
expression.set("windows", None)
990+
991+
window_expression: t.Dict[str, exp.Expression] = {}
992+
993+
def _inline_inherited_window(window: exp.Expression) -> None:
994+
inherited_window = window_expression.get(window.alias.lower())
995+
if not inherited_window:
996+
return
997+
998+
window.set("alias", None)
999+
for key in ("partition_by", "order", "spec"):
1000+
arg = inherited_window.args.get(key)
1001+
if arg:
1002+
window.set(key, arg.copy())
1003+
1004+
for window in windows:
1005+
_inline_inherited_window(window)
1006+
window_expression[window.name.lower()] = window
1007+
1008+
for window in find_all_in_scope(expression, exp.Window):
1009+
_inline_inherited_window(window)
1010+
1011+
return expression

tests/dialects/test_bigquery.py

Lines changed: 74 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -24,67 +24,6 @@ class TestBigQuery(Validator):
2424
maxDiff = None
2525

2626
def test_bigquery(self):
27-
self.validate_all(
28-
"EXTRACT(HOUR FROM DATETIME(2008, 12, 25, 15, 30, 00))",
29-
write={
30-
"bigquery": "EXTRACT(HOUR FROM DATETIME(2008, 12, 25, 15, 30, 00))",
31-
"duckdb": "EXTRACT(HOUR FROM MAKE_TIMESTAMP(2008, 12, 25, 15, 30, 00))",
32-
"snowflake": "DATE_PART(HOUR, TIMESTAMP_FROM_PARTS(2008, 12, 25, 15, 30, 00))",
33-
},
34-
)
35-
self.validate_identity(
36-
"""CREATE TEMPORARY FUNCTION FOO()
37-
RETURNS STRING
38-
LANGUAGE js AS
39-
'return "Hello world!"'""",
40-
pretty=True,
41-
)
42-
self.validate_identity(
43-
"[a, a(1, 2,3,4444444444444444, tttttaoeunthaoentuhaoentuheoantu, toheuntaoheutnahoeunteoahuntaoeh), b(3, 4,5), c, d, tttttttttttttttteeeeeeeeeeeeeett, 12312312312]",
44-
"""[
45-
a,
46-
a(
47-
1,
48-
2,
49-
3,
50-
4444444444444444,
51-
tttttaoeunthaoentuhaoentuheoantu,
52-
toheuntaoheutnahoeunteoahuntaoeh
53-
),
54-
b(3, 4, 5),
55-
c,
56-
d,
57-
tttttttttttttttteeeeeeeeeeeeeett,
58-
12312312312
59-
]""",
60-
pretty=True,
61-
)
62-
63-
self.validate_all(
64-
"SELECT STRUCT(1, 2, 3), STRUCT(), STRUCT('abc'), STRUCT(1, t.str_col), STRUCT(1 as a, 'abc' AS b), STRUCT(str_col AS abc)",
65-
write={
66-
"bigquery": "SELECT STRUCT(1, 2, 3), STRUCT(), STRUCT('abc'), STRUCT(1, t.str_col), STRUCT(1 AS a, 'abc' AS b), STRUCT(str_col AS abc)",
67-
"duckdb": "SELECT {'_0': 1, '_1': 2, '_2': 3}, {}, {'_0': 'abc'}, {'_0': 1, '_1': t.str_col}, {'a': 1, 'b': 'abc'}, {'abc': str_col}",
68-
"hive": "SELECT STRUCT(1, 2, 3), STRUCT(), STRUCT('abc'), STRUCT(1, t.str_col), STRUCT(1, 'abc'), STRUCT(str_col)",
69-
"spark2": "SELECT STRUCT(1, 2, 3), STRUCT(), STRUCT('abc'), STRUCT(1, t.str_col), STRUCT(1 AS a, 'abc' AS b), STRUCT(str_col AS abc)",
70-
"spark": "SELECT STRUCT(1, 2, 3), STRUCT(), STRUCT('abc'), STRUCT(1, t.str_col), STRUCT(1 AS a, 'abc' AS b), STRUCT(str_col AS abc)",
71-
"snowflake": "SELECT OBJECT_CONSTRUCT('_0', 1, '_1', 2, '_2', 3), OBJECT_CONSTRUCT(), OBJECT_CONSTRUCT('_0', 'abc'), OBJECT_CONSTRUCT('_0', 1, '_1', t.str_col), OBJECT_CONSTRUCT('a', 1, 'b', 'abc'), OBJECT_CONSTRUCT('abc', str_col)",
72-
# fallback to unnamed without type inference
73-
"trino": "SELECT ROW(1, 2, 3), ROW(), ROW('abc'), ROW(1, t.str_col), CAST(ROW(1, 'abc') AS ROW(a INTEGER, b VARCHAR)), ROW(str_col)",
74-
},
75-
)
76-
self.validate_all(
77-
"PARSE_TIMESTAMP('%Y-%m-%dT%H:%M:%E6S%z', x)",
78-
write={
79-
"bigquery": "PARSE_TIMESTAMP('%Y-%m-%dT%H:%M:%E6S%z', x)",
80-
"duckdb": "STRPTIME(x, '%Y-%m-%dT%H:%M:%S.%f%z')",
81-
},
82-
)
83-
self.validate_identity(
84-
"PARSE_TIMESTAMP('%Y-%m-%dT%H:%M:%E*S%z', x)",
85-
"PARSE_TIMESTAMP('%Y-%m-%dT%H:%M:%E*S%z', x)",
86-
)
87-
8827
for prefix in ("c.db.", "db.", ""):
8928
with self.subTest(f"Parsing {prefix}INFORMATION_SCHEMA.X into a Table"):
9029
table = self.parse_one(f"`{prefix}INFORMATION_SCHEMA.X`", into=exp.Table)
@@ -116,6 +55,7 @@ def test_bigquery(self):
11655
select_with_quoted_udf = self.validate_identity("SELECT `p.d.UdF`(data) FROM `p.d.t`")
11756
self.assertEqual(select_with_quoted_udf.selects[0].name, "p.d.UdF")
11857

58+
self.validate_identity("PARSE_TIMESTAMP('%Y-%m-%dT%H:%M:%E*S%z', x)")
11959
self.validate_identity("SELECT ARRAY_CONCAT([1])")
12060
self.validate_identity("SELECT * FROM READ_CSV('bla.csv')")
12161
self.validate_identity("CAST(x AS STRUCT<list ARRAY<INT64>>)")
@@ -321,7 +261,80 @@ def test_bigquery(self):
321261
"SELECT CAST(1 AS BYTEINT)",
322262
"SELECT CAST(1 AS INT64)",
323263
)
264+
self.validate_identity(
265+
"""CREATE TEMPORARY FUNCTION FOO()
266+
RETURNS STRING
267+
LANGUAGE js AS
268+
'return "Hello world!"'""",
269+
pretty=True,
270+
)
271+
self.validate_identity(
272+
"[a, a(1, 2,3,4444444444444444, tttttaoeunthaoentuhaoentuheoantu, toheuntaoheutnahoeunteoahuntaoeh), b(3, 4,5), c, d, tttttttttttttttteeeeeeeeeeeeeett, 12312312312]",
273+
"""[
274+
a,
275+
a(
276+
1,
277+
2,
278+
3,
279+
4444444444444444,
280+
tttttaoeunthaoentuhaoentuheoantu,
281+
toheuntaoheutnahoeunteoahuntaoeh
282+
),
283+
b(3, 4, 5),
284+
c,
285+
d,
286+
tttttttttttttttteeeeeeeeeeeeeett,
287+
12312312312
288+
]""",
289+
pretty=True,
290+
)
324291

292+
self.validate_all(
293+
"SELECT purchases, LAST_VALUE(item) OVER item_window AS most_popular FROM Produce WINDOW item_window AS (PARTITION BY purchases ORDER BY purchases ROWS BETWEEN 2 PRECEDING AND 2 FOLLOWING)",
294+
write={
295+
"bigquery": "SELECT purchases, LAST_VALUE(item) OVER item_window AS most_popular FROM Produce WINDOW item_window AS (PARTITION BY purchases ORDER BY purchases ROWS BETWEEN 2 PRECEDING AND 2 FOLLOWING)",
296+
"clickhouse": "SELECT purchases, LAST_VALUE(item) OVER item_window AS most_popular FROM Produce WINDOW item_window AS (PARTITION BY purchases ORDER BY purchases NULLS FIRST ROWS BETWEEN 2 PRECEDING AND 2 FOLLOWING)",
297+
"databricks": "SELECT purchases, LAST_VALUE(item) OVER item_window AS most_popular FROM Produce WINDOW item_window AS (PARTITION BY purchases ORDER BY purchases ROWS BETWEEN 2 PRECEDING AND 2 FOLLOWING)",
298+
"duckdb": "SELECT purchases, LAST_VALUE(item) OVER item_window AS most_popular FROM Produce WINDOW item_window AS (PARTITION BY purchases ORDER BY purchases NULLS FIRST ROWS BETWEEN 2 PRECEDING AND 2 FOLLOWING)",
299+
"mysql": "SELECT purchases, LAST_VALUE(item) OVER item_window AS most_popular FROM Produce WINDOW item_window AS (PARTITION BY purchases ORDER BY purchases ROWS BETWEEN 2 PRECEDING AND 2 FOLLOWING)",
300+
"oracle": "SELECT purchases, LAST_VALUE(item) OVER item_window AS most_popular FROM Produce WINDOW item_window AS (PARTITION BY purchases ORDER BY purchases NULLS FIRST ROWS BETWEEN 2 PRECEDING AND 2 FOLLOWING)",
301+
"postgres": "SELECT purchases, LAST_VALUE(item) OVER item_window AS most_popular FROM Produce WINDOW item_window AS (PARTITION BY purchases ORDER BY purchases NULLS FIRST ROWS BETWEEN 2 PRECEDING AND 2 FOLLOWING)",
302+
"presto": "SELECT purchases, LAST_VALUE(item) OVER (PARTITION BY purchases ORDER BY purchases NULLS FIRST ROWS BETWEEN 2 PRECEDING AND 2 FOLLOWING) AS most_popular FROM Produce",
303+
"redshift": "SELECT purchases, LAST_VALUE(item) OVER (PARTITION BY purchases ORDER BY purchases NULLS FIRST ROWS BETWEEN 2 PRECEDING AND 2 FOLLOWING) AS most_popular FROM Produce",
304+
"snowflake": "SELECT purchases, LAST_VALUE(item) OVER (PARTITION BY purchases ORDER BY purchases NULLS FIRST ROWS BETWEEN 2 PRECEDING AND 2 FOLLOWING) AS most_popular FROM Produce",
305+
"spark": "SELECT purchases, LAST_VALUE(item) OVER item_window AS most_popular FROM Produce WINDOW item_window AS (PARTITION BY purchases ORDER BY purchases ROWS BETWEEN 2 PRECEDING AND 2 FOLLOWING)",
306+
"trino": "SELECT purchases, LAST_VALUE(item) OVER item_window AS most_popular FROM Produce WINDOW item_window AS (PARTITION BY purchases ORDER BY purchases NULLS FIRST ROWS BETWEEN 2 PRECEDING AND 2 FOLLOWING)",
307+
"tsql": "SELECT purchases, LAST_VALUE(item) OVER item_window AS most_popular FROM Produce WINDOW item_window AS (PARTITION BY purchases ORDER BY purchases ROWS BETWEEN 2 PRECEDING AND 2 FOLLOWING)",
308+
},
309+
)
310+
self.validate_all(
311+
"EXTRACT(HOUR FROM DATETIME(2008, 12, 25, 15, 30, 00))",
312+
write={
313+
"bigquery": "EXTRACT(HOUR FROM DATETIME(2008, 12, 25, 15, 30, 00))",
314+
"duckdb": "EXTRACT(HOUR FROM MAKE_TIMESTAMP(2008, 12, 25, 15, 30, 00))",
315+
"snowflake": "DATE_PART(HOUR, TIMESTAMP_FROM_PARTS(2008, 12, 25, 15, 30, 00))",
316+
},
317+
)
318+
self.validate_all(
319+
"SELECT STRUCT(1, 2, 3), STRUCT(), STRUCT('abc'), STRUCT(1, t.str_col), STRUCT(1 as a, 'abc' AS b), STRUCT(str_col AS abc)",
320+
write={
321+
"bigquery": "SELECT STRUCT(1, 2, 3), STRUCT(), STRUCT('abc'), STRUCT(1, t.str_col), STRUCT(1 AS a, 'abc' AS b), STRUCT(str_col AS abc)",
322+
"duckdb": "SELECT {'_0': 1, '_1': 2, '_2': 3}, {}, {'_0': 'abc'}, {'_0': 1, '_1': t.str_col}, {'a': 1, 'b': 'abc'}, {'abc': str_col}",
323+
"hive": "SELECT STRUCT(1, 2, 3), STRUCT(), STRUCT('abc'), STRUCT(1, t.str_col), STRUCT(1, 'abc'), STRUCT(str_col)",
324+
"spark2": "SELECT STRUCT(1, 2, 3), STRUCT(), STRUCT('abc'), STRUCT(1, t.str_col), STRUCT(1 AS a, 'abc' AS b), STRUCT(str_col AS abc)",
325+
"spark": "SELECT STRUCT(1, 2, 3), STRUCT(), STRUCT('abc'), STRUCT(1, t.str_col), STRUCT(1 AS a, 'abc' AS b), STRUCT(str_col AS abc)",
326+
"snowflake": "SELECT OBJECT_CONSTRUCT('_0', 1, '_1', 2, '_2', 3), OBJECT_CONSTRUCT(), OBJECT_CONSTRUCT('_0', 'abc'), OBJECT_CONSTRUCT('_0', 1, '_1', t.str_col), OBJECT_CONSTRUCT('a', 1, 'b', 'abc'), OBJECT_CONSTRUCT('abc', str_col)",
327+
# fallback to unnamed without type inference
328+
"trino": "SELECT ROW(1, 2, 3), ROW(), ROW('abc'), ROW(1, t.str_col), CAST(ROW(1, 'abc') AS ROW(a INTEGER, b VARCHAR)), ROW(str_col)",
329+
},
330+
)
331+
self.validate_all(
332+
"PARSE_TIMESTAMP('%Y-%m-%dT%H:%M:%E6S%z', x)",
333+
write={
334+
"bigquery": "PARSE_TIMESTAMP('%Y-%m-%dT%H:%M:%E6S%z', x)",
335+
"duckdb": "STRPTIME(x, '%Y-%m-%dT%H:%M:%S.%f%z')",
336+
},
337+
)
325338
self.validate_all(
326339
"SELECT DATE_SUB(CURRENT_DATE(), INTERVAL 2 DAY)",
327340
write={

tests/test_transforms.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
eliminate_distinct_on,
66
eliminate_join_marks,
77
eliminate_qualify,
8+
eliminate_window_clause,
89
remove_precision_parameterized_types,
910
unalias_group,
1011
)
@@ -272,3 +273,15 @@ def test_eliminate_join_marks(self):
272273
tree.sql(dialect=dialect)
273274
== "SELECT a.id FROM a LEFT JOIN b ON a.id = b.id AND b.d = const"
274275
)
276+
277+
def test_eliminate_window_clause(self):
278+
self.validate(
279+
eliminate_window_clause,
280+
"SELECT purchases, LAST_VALUE(item) OVER (d) AS most_popular FROM Produce WINDOW a AS (PARTITION BY purchases), b AS (a ORDER BY purchases), c AS (b ROWS BETWEEN 2 PRECEDING AND 2 FOLLOWING), d AS (c)",
281+
"SELECT purchases, LAST_VALUE(item) OVER (PARTITION BY purchases ORDER BY purchases ROWS BETWEEN 2 PRECEDING AND 2 FOLLOWING) AS most_popular FROM Produce",
282+
)
283+
self.validate(
284+
eliminate_window_clause,
285+
"SELECT LAST_VALUE(c) OVER (a) AS c2 FROM (SELECT LAST_VALUE(i) OVER (a) AS c FROM p WINDOW a AS (PARTITION BY x)) AS q(c) WINDOW a AS (PARTITION BY y)",
286+
"SELECT LAST_VALUE(c) OVER (PARTITION BY y) AS c2 FROM (SELECT LAST_VALUE(i) OVER (PARTITION BY x) AS c FROM p) AS q(c)",
287+
)

0 commit comments

Comments
 (0)