
Commit 72cf4a4

feat(duckdb)!: Add support for PIVOT multiple IN clauses (#4964)
* feat(duckdb): Add support for PIVOT's multi IN clauses
* PR Feedback 1
* PR Feedback 2
1 parent ad5b595 commit 72cf4a4
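
What the commit enables, in one hedged sketch (the query is adapted from the new test in tests/test_parser.py; exact output formatting may differ slightly):

import sqlglot
from sqlglot import exp

# Adapted from the new parser test: two IN clauses after a single FOR.
sql = """
SELECT * FROM cities PIVOT (
    sum(population) AS total,
    count(population) AS count
    FOR
        year IN (2000, 2010)
        country IN ('NL', 'US')
)
"""

expression = sqlglot.parse_one(sql, read="duckdb")
print(expression.sql(dialect="duckdb"))  # the multi IN-clause PIVOT now round-trips

pivot = expression.find(exp.Pivot)
# The parser infers the output columns as the cross product of the IN values
# plus the aggregation aliases (see the parser.py diff below).
print([col.sql(dialect="duckdb") for col in pivot.args.get("columns", [])])
# ['"2000_NL_total"', '"2000_NL_count"', '"2000_US_total"', '"2000_US_count"',
#  '"2010_NL_total"', '"2010_NL_count"', '"2010_US_total"', '"2010_US_count"']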

9 files changed: +145 −43 lines changed

sqlglot/dialects/snowflake.py

Lines changed: 5 additions & 5 deletions
@@ -193,12 +193,12 @@ def _unqualify_pivot_columns(expression: exp.Expression) -> exp.Expression:
     if expression.unpivot:
         expression = transforms.unqualify_columns(expression)
     else:
-        field = expression.args.get("field")
-        field_expr = seq_get(field.expressions if field else [], 0)
+        for field in expression.fields:
+            field_expr = seq_get(field.expressions if field else [], 0)
 
-        if isinstance(field_expr, exp.PivotAny):
-            unqualified_field_expr = transforms.unqualify_columns(field_expr)
-            t.cast(exp.Expression, field).set("expressions", unqualified_field_expr, 0)
+            if isinstance(field_expr, exp.PivotAny):
+                unqualified_field_expr = transforms.unqualify_columns(field_expr)
+                t.cast(exp.Expression, field).set("expressions", unqualified_field_expr, 0)
 
     return expression

sqlglot/dialects/spark2.py

Lines changed: 4 additions & 2 deletions
@@ -104,7 +104,9 @@ def _unqualify_pivot_columns(expression: exp.Expression) -> exp.Expression:
     SELECT * FROM tbl PIVOT(SUM(tbl.sales) FOR quarter IN ('Q1', 'Q1'))
     """
     if isinstance(expression, exp.Pivot):
-        expression.set("field", transforms.unqualify_columns(expression.args["field"]))
+        expression.set(
+            "fields", [transforms.unqualify_columns(field) for field in expression.fields]
+        )
 
     return expression

@@ -237,7 +239,7 @@ def _parse_drop_column(self) -> t.Optional[exp.Drop | exp.Command]:
 
         def _pivot_column_names(self, aggregations: t.List[exp.Expression]) -> t.List[str]:
             if len(aggregations) == 1:
-                return [""]
+                return []
             return pivot_column_names(aggregations, dialect="spark")
 
     class Generator(Hive.Generator):
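
A hedged sketch of the Spark transform above, with illustrative table and column names (not taken from this commit): a qualified FOR column should be unqualified per IN clause when generating Spark SQL, matching the docstring example in the diff.

import sqlglot

# Illustrative input where the pivoted FOR column is qualified; Spark expects it bare,
# so the preprocess transform above should strip the "tbl." qualifier on generation.
sql = "SELECT * FROM tbl PIVOT(SUM(tbl.sales) FOR tbl.quarter IN ('Q1', 'Q2'))"
print(sqlglot.transpile(sql, read="spark", write="spark")[0])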

sqlglot/expressions.py

Lines changed: 5 additions & 1 deletion
@@ -4322,7 +4322,7 @@ class Pivot(Expression):
         "this": False,
         "alias": False,
         "expressions": False,
-        "field": False,
+        "fields": False,
         "unpivot": False,
         "using": False,
         "group": False,

@@ -4336,6 +4336,10 @@ class Pivot(Expression):
     def unpivot(self) -> bool:
         return bool(self.args.get("unpivot"))
 
+    @property
+    def fields(self) -> t.List[Expression]:
+        return self.args.get("fields", [])
+
 
 # https://duckdb.org/docs/sql/statements/unpivot#simplified-unpivot-syntax
 # UNPIVOT ... INTO [NAME <col_name> VALUE <col_value>][...,]
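
Since the single `field` arg becomes a `fields` list (hence the breaking-change marker in the commit title), downstream code that read `pivot.args.get("field")` should iterate the new property instead. A minimal sketch with illustrative names:

import sqlglot
from sqlglot import exp

sql = "SELECT * FROM t PIVOT(SUM(v) FOR a IN (1, 2) b IN ('x', 'y'))"  # illustrative query
pivot = sqlglot.parse_one(sql, read="duckdb").find(exp.Pivot)

# Previously a single exp.In lived under pivot.args["field"]; now there is one
# node per IN clause, and the property defaults to an empty list when absent.
for field in pivot.fields:
    print(field.sql(dialect="duckdb"))  # a IN (1, 2), then b IN ('x', 'y')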

sqlglot/generator.py

Lines changed: 10 additions & 2 deletions
@@ -2068,7 +2068,15 @@ def pivot_sql(self, expression: exp.Pivot) -> str:
         alias = self.sql(expression, "alias")
         alias = f" AS {alias}" if alias else ""
 
-        field = self.sql(expression, "field")
+        fields = self.expressions(
+            expression,
+            "fields",
+            sep=" ",
+            dynamic=True,
+            new_line=True,
+            skip_first=True,
+            skip_last=True,
+        )
 
         include_nulls = expression.args.get("include_nulls")
         if include_nulls is not None:

@@ -2078,7 +2086,7 @@ def pivot_sql(self, expression: exp.Pivot) -> str:
 
         default_on_null = self.sql(expression, "default_on_null")
         default_on_null = f" DEFAULT ON NULL ({default_on_null})" if default_on_null else ""
-        return f"{self.seg(direction)}{nulls}({expressions} FOR {field}{default_on_null}{group}){alias}"
+        return f"{self.seg(direction)}{nulls}({expressions} FOR {fields}{default_on_null}{group}){alias}"
 
     def version_sql(self, expression: exp.Version) -> str:
         this = f"FOR {expression.name}"
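
A hedged sketch of the generation side (illustrative names; the exact pretty-printed layout may differ): the `dynamic`/`new_line` options above let each IN clause land on its own line when pretty-printing, which is what the updated optimizer fixture further down reflects.

import sqlglot

sql = (
    "SELECT * FROM cities PIVOT("
    "SUM(population) AS total FOR year IN (2000, 2010) country IN ('NL', 'US'))"
)

# Compact and pretty-printed renderings of the same multi IN-clause pivot.
print(sqlglot.transpile(sql, read="duckdb", write="duckdb")[0])
print(sqlglot.transpile(sql, read="duckdb", write="duckdb", pretty=True)[0])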

sqlglot/optimizer/qualify_columns.py

Lines changed: 14 additions & 12 deletions
@@ -140,13 +140,14 @@ def validate_qualify_columns(expression: E) -> E:
 
 
 def _unpivot_columns(unpivot: exp.Pivot) -> t.Iterator[exp.Column]:
-    name_column = []
-    field = unpivot.args.get("field")
-    if isinstance(field, exp.In) and isinstance(field.this, exp.Column):
-        name_column.append(field.this)
-
+    name_columns = [
+        field.this
+        for field in unpivot.fields
+        if isinstance(field, exp.In) and isinstance(field.this, exp.Column)
+    ]
     value_columns = (c for e in unpivot.expressions for c in e.find_all(exp.Column))
-    return itertools.chain(name_column, value_columns)
+
+    return itertools.chain(name_columns, value_columns)
 
 
 def _pop_table_column_aliases(derived_tables: t.List[exp.CTE | exp.Subquery]) -> None:

@@ -608,18 +609,19 @@ def _expand_stars(
     dialect = resolver.schema.dialect
 
     pivot_output_columns = None
-    pivot_exclude_columns = None
+    pivot_exclude_columns: t.Set[str] = set()
 
     pivot = t.cast(t.Optional[exp.Pivot], seq_get(scope.pivots, 0))
     if isinstance(pivot, exp.Pivot) and not pivot.alias_column_names:
         if pivot.unpivot:
             pivot_output_columns = [c.output_name for c in _unpivot_columns(pivot)]
 
-            field = pivot.args.get("field")
-            if isinstance(field, exp.In):
-                pivot_exclude_columns = {
-                    c.output_name for e in field.expressions for c in e.find_all(exp.Column)
-                }
+            for field in pivot.fields:
+                if isinstance(field, exp.In):
+                    pivot_exclude_columns.update(
+                        c.output_name for e in field.expressions for c in e.find_all(exp.Column)
+                    )
+
         else:
             pivot_exclude_columns = set(c.output_name for c in pivot.find_all(exp.Column))
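
For reference, a sketch of how this surfaces through the optimizer, reusing the query from the new qualify_columns fixture below: the columns named in the IN clauses are excluded from the expanded star, and the inferred combination columns are selected instead.

from sqlglot import parse_one
from sqlglot.optimizer.qualify import qualify

# Same query as the new fixture in tests/fixtures/optimizer/qualify_columns.sql.
sql = """
WITH cities AS (
    SELECT * FROM (VALUES ('nl', 'amsterdam', 2000, 1005)) AS t(country, name, year, population)
)
SELECT *
FROM cities
PIVOT(
    SUM(population) AS total,
    COUNT(population) AS count
    FOR country IN ('nl', 'us') year IN (2000, 2010) name IN ('amsterdam', 'seattle')
)
"""

qualified = qualify(parse_one(sql, read="duckdb"), dialect="duckdb")
# The star expands to nl_2000_amsterdam_total, nl_2000_amsterdam_count, ... while
# country/year/name themselves are excluded from the projection.
print(qualified.sql(dialect="duckdb"))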

sqlglot/parser.py

Lines changed: 38 additions & 14 deletions
@@ -2,6 +2,7 @@
 
 import logging
 import typing as t
+import itertools
 from collections import defaultdict
 
 from sqlglot import exp

@@ -4242,7 +4243,13 @@ def _parse_pivot(self) -> t.Optional[exp.Pivot]:
         if not self._match(TokenType.FOR):
             self.raise_error("Expecting FOR")
 
-        field = self._parse_pivot_in()
+        fields = []
+        while True:
+            field = self._try_parse(self._parse_pivot_in)
+            if not field:
+                break
+            fields.append(field)
+
         default_on_null = self._match_text_seq("DEFAULT", "ON", "NULL") and self._parse_wrapped(
             self._parse_bitwise
         )

@@ -4254,7 +4261,7 @@ def _parse_pivot(self) -> t.Optional[exp.Pivot]:
         pivot = self.expression(
             exp.Pivot,
             expressions=expressions,
-            field=field,
+            fields=fields,
             unpivot=unpivot,
             include_nulls=include_nulls,
             default_on_null=default_on_null,

@@ -4268,26 +4275,43 @@ def _parse_pivot(self) -> t.Optional[exp.Pivot]:
             names = self._pivot_column_names(t.cast(t.List[exp.Expression], expressions))
 
             columns: t.List[exp.Expression] = []
-            pivot_field_expressions = pivot.args["field"].expressions
+            all_fields = []
+            for pivot_field in pivot.fields:
+                pivot_field_expressions = pivot_field.expressions
+
+                # The `PivotAny` expression corresponds to `ANY ORDER BY <column>`; we can't infer in this case.
+                if isinstance(seq_get(pivot_field_expressions, 0), exp.PivotAny):
+                    continue
+
+                all_fields.append(
+                    [
+                        fld.sql() if self.IDENTIFY_PIVOT_STRINGS else fld.alias_or_name
+                        for fld in pivot_field_expressions
+                    ]
+                )
+
+            if all_fields:
+                if names:
+                    all_fields.append(names)
+
+                # Generate all possible combinations of the pivot columns
+                # e.g PIVOT(sum(...) as total FOR year IN (2000, 2010) FOR country IN ('NL', 'US'))
+                # generates the product between [[2000, 2010], ['NL', 'US'], ['total']]
+                for fld_parts_tuple in itertools.product(*all_fields):
+                    fld_parts = list(fld_parts_tuple)
 
-            # The `PivotAny` expression corresponds to `ANY ORDER BY <column>`; we can't infer in this case.
-            if not isinstance(seq_get(pivot_field_expressions, 0), exp.PivotAny):
-                for fld in pivot_field_expressions:
-                    field_name = fld.sql() if self.IDENTIFY_PIVOT_STRINGS else fld.alias_or_name
-                    for name in names:
-                        if self.PREFIXED_PIVOT_COLUMNS:
-                            name = f"{name}_{field_name}" if name else field_name
-                        else:
-                            name = f"{field_name}_{name}" if name else field_name
+                    if names and self.PREFIXED_PIVOT_COLUMNS:
+                        # Move the "name" to the front of the list
+                        fld_parts.insert(0, fld_parts.pop(-1))
 
-                        columns.append(exp.to_identifier(name))
+                    columns.append(exp.to_identifier("_".join(fld_parts)))
 
             pivot.set("columns", columns)
 
         return pivot
 
     def _pivot_column_names(self, aggregations: t.List[exp.Expression]) -> t.List[str]:
-        return [agg.alias for agg in aggregations]
+        return [agg.alias for agg in aggregations if agg.alias]
 
     def _parse_prewhere(self, skip_where_token: bool = False) -> t.Optional[exp.PreWhere]:
         if not skip_where_token and not self._match(TokenType.PREWHERE):
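
The naming logic above boils down to a cartesian product of the per-clause values plus the aggregation aliases. A minimal standalone sketch using the values from the new parser test:

import itertools

# One list of candidate name parts per IN clause, in parse order.
all_fields = [["2000", "2010"], ["NL", "US"]]
# Aggregation aliases are appended last, so they become the suffix when
# PREFIXED_PIVOT_COLUMNS is falsy, as the DuckDB test expectations below suggest.
names = ["total", "count"]
all_fields.append(names)

columns = ["_".join(parts) for parts in itertools.product(*all_fields)]
print(columns)
# ['2000_NL_total', '2000_NL_count', '2000_US_total', '2000_US_count',
#  '2010_NL_total', '2010_NL_count', '2010_US_total', '2010_US_count']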

tests/fixtures/optimizer/optimizer.sql

Lines changed: 6 additions & 4 deletions
@@ -760,10 +760,12 @@ SELECT
   `_q_0`.`first_half_sales` AS `first_half_sales`,
   `_q_0`.`second_half_sales` AS `second_half_sales`
 FROM `produce` AS `produce`
-UNPIVOT((`first_half_sales`, `second_half_sales`) FOR `semesters` IN (
-  (`produce`.`q1`, `produce`.`q2`) AS 'semester_1',
-  (`produce`.`q3`, `produce`.`q4`) AS 'semester_2'
-)) AS `_q_0`;
+UNPIVOT((`first_half_sales`, `second_half_sales`) FOR
+  `semesters` IN (
+    (`produce`.`q1`, `produce`.`q2`) AS 'semester_1',
+    (`produce`.`q3`, `produce`.`q4`) AS 'semester_2'
+  )
+) AS `_q_0`;
 
 # title: quoting is preserved
 # dialect: snowflake

tests/fixtures/optimizer/qualify_columns.sql

Lines changed: 4 additions & 0 deletions
@@ -512,6 +512,10 @@ SELECT _q_0.c1 AS c1, _q_0.c2 AS c2 FROM VALUES ((1, 1), (2, 2)) AS _q_0(c1, c2)
 SELECT * FROM VALUES (1, 2, 3);
 SELECT _q_0.c1 AS c1 FROM VALUES ((1), (2), (3)) AS _q_0(c1);
 
+# title: Expand PIVOT column combinations
+# dialect: duckdb
+WITH cities AS (SELECT * FROM (VALUES ('nl', 'amsterdam', 2000, 1005)) AS t(country, name, year, population)) SELECT * FROM cities PIVOT(SUM(population) AS total, COUNT(population) AS count FOR country IN ('nl', 'us') year IN (2000, 2010) name IN ('amsterdam', 'seattle'));
+WITH cities AS (SELECT t.country AS country, t.name AS name, t.year AS year, t.population AS population FROM (VALUES ('nl', 'amsterdam', 2000, 1005)) AS t(country, name, year, population)) SELECT _q_0.nl_2000_amsterdam_total AS nl_2000_amsterdam_total, _q_0.nl_2000_amsterdam_count AS nl_2000_amsterdam_count, _q_0.nl_2000_seattle_total AS nl_2000_seattle_total, _q_0.nl_2000_seattle_count AS nl_2000_seattle_count, _q_0.nl_2010_amsterdam_total AS nl_2010_amsterdam_total, _q_0.nl_2010_amsterdam_count AS nl_2010_amsterdam_count, _q_0.nl_2010_seattle_total AS nl_2010_seattle_total, _q_0.nl_2010_seattle_count AS nl_2010_seattle_count, _q_0.us_2000_amsterdam_total AS us_2000_amsterdam_total, _q_0.us_2000_amsterdam_count AS us_2000_amsterdam_count, _q_0.us_2000_seattle_total AS us_2000_seattle_total, _q_0.us_2000_seattle_count AS us_2000_seattle_count, _q_0.us_2010_amsterdam_total AS us_2010_amsterdam_total, _q_0.us_2010_amsterdam_count AS us_2010_amsterdam_count, _q_0.us_2010_seattle_total AS us_2010_seattle_total, _q_0.us_2010_seattle_count AS us_2010_seattle_count FROM cities AS cities PIVOT(SUM(population) AS total, COUNT(population) AS count FOR country IN ('nl', 'us') year IN (2000, 2010) name IN ('amsterdam', 'seattle')) AS _q_0;
 
 --------------------------------------
 -- CTEs

tests/test_parser.py

Lines changed: 59 additions & 3 deletions
@@ -646,6 +646,27 @@ def test_pivot_columns(self):
         ) PIVOT (AVG("PrIcE"), MAX(quality) FOR partname IN ('prop' AS prop1, 'rudder'))
         """
 
+        two_in_clauses_duckdb = """
+            SELECT * FROM cities PIVOT (
+                sum(population) AS total,
+                count(population) AS count
+                FOR
+                    year IN (2000, 2010)
+                    country IN ('NL', 'US')
+            )
+        """
+
+        three_in_clauses_duckdb = """
+            SELECT * FROM cities PIVOT (
+                sum(population) AS total,
+                count(population) AS count
+                FOR
+                    year IN (2000, 2010)
+                    country IN ('NL', 'US')
+                    name IN ('Amsterdam', 'Seattle')
+            )
+        """
+
         query_to_column_names = {
             nothing_aliased: {
                 "bigquery": ["prop", "rudder"],

@@ -707,13 +728,48 @@ def test_pivot_columns(self):
                     '"rudder_max(quality)"',
                 ],
             },
+            two_in_clauses_duckdb: {
+                "duckdb": [
+                    '"2000_NL_total"',
+                    '"2000_NL_count"',
+                    '"2000_US_total"',
+                    '"2000_US_count"',
+                    '"2010_NL_total"',
+                    '"2010_NL_count"',
+                    '"2010_US_total"',
+                    '"2010_US_count"',
+                ],
+            },
+            three_in_clauses_duckdb: {
+                "duckdb": [
+                    '"2000_NL_Amsterdam_total"',
+                    '"2000_NL_Amsterdam_count"',
+                    '"2000_NL_Seattle_total"',
+                    '"2000_NL_Seattle_count"',
+                    '"2000_US_Amsterdam_total"',
+                    '"2000_US_Amsterdam_count"',
+                    '"2000_US_Seattle_total"',
+                    '"2000_US_Seattle_count"',
+                    '"2010_NL_Amsterdam_total"',
+                    '"2010_NL_Amsterdam_count"',
+                    '"2010_NL_Seattle_total"',
+                    '"2010_NL_Seattle_count"',
+                    '"2010_US_Amsterdam_total"',
+                    '"2010_US_Amsterdam_count"',
+                    '"2010_US_Seattle_total"',
+                    '"2010_US_Seattle_count"',
+                ],
+            },
         }
 
         for query, dialect_columns in query_to_column_names.items():
             for dialect, expected_columns in dialect_columns.items():
-                expr = parse_one(query, read=dialect)
-                columns = expr.args["from"].this.args["pivots"][0].args["columns"]
-                self.assertEqual(expected_columns, [col.sql(dialect=dialect) for col in columns])
+                with self.subTest(f"Testing query '{query}' for dialect {dialect}"):
+                    expr = parse_one(query, read=dialect)
+                    columns = expr.args["from"].this.args["pivots"][0].args["columns"]
+                    self.assertEqual(
+                        expected_columns, [col.sql(dialect=dialect) for col in columns]
+                    )
 
     def test_parse_nested(self):
         def warn_over_threshold(query: str, max_threshold: float = 0.2):
