chore(exasol): custom transformation of pivot clause in exasol dialect #6558
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,8 +1,9 @@ | ||
| from __future__ import annotations | ||
|
|
||
| import typing as t | ||
| from typing import Any | ||
|
|
||
| from sqlglot import exp, generator, parser, tokens, transforms | ||
| from sqlglot import exp, generator, parser, tokens, transforms, Expression | ||
|
Collaborator
Let's not import `Expression` directly from `sqlglot`.
||
| from sqlglot.dialects.dialect import ( | ||
| Dialect, | ||
| NormalizationStrategy, | ||
|
|
@@ -17,6 +18,7 @@ | |
| no_last_day_sql, | ||
| DATE_ADD_OR_SUB, | ||
| ) | ||
| from sqlglot.expressions import Paren, Tuple | ||
| from sqlglot.generator import unsupported_args | ||
| from sqlglot.helper import seq_get | ||
| from sqlglot.tokens import TokenType | ||
|
|
@@ -564,3 +566,205 @@ def rank_sql(self, expression: exp.Rank) -> str: | |
| if expression.args.get("expressions"): | ||
| self.unsupported("Exasol does not support arguments in RANK") | ||
| return self.func("RANK") | ||
|
|
||
| def table_sql(self, expression: exp.Table, sep: str = " AS ") -> str: | ||
| """ | ||
|         If a table has PIVOTs attached, let the Pivot render a derived table. | ||
| """ | ||
| pivots = expression.args.get("pivots") or [] | ||
| if not pivots: | ||
| return super().table_sql(expression) | ||
|
|
||
| if len(pivots) > 1: | ||
| self.unsupported("Multiple PIVOT clauses are not supported by Exasol") | ||
| return super().table_sql(expression) | ||
| pivot = pivots[0] | ||
| return self.sql(pivot) | ||
|
Comment on lines +581 to +582
Collaborator
Why are we generating the pivot instead of the table here? That feels wrong. Shouldn't the whole pivot transpilation happen with a transform?
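For context, a minimal sketch of the transform-based approach this comment seems to point at, assuming sqlglot's `transforms.preprocess` hook; `eliminate_pivot` is a hypothetical rewrite function, not code from this PR:

```python
from sqlglot import exp, generator, transforms


def eliminate_pivot(expression: exp.Expression) -> exp.Expression:
    # Hypothetical rewrite: replace any pivoted table under this SELECT with a
    # derived table built from GROUP BY + aggregated CASE expressions. The
    # actual rewrite logic is elided; the expression is returned unchanged.
    return expression


class Generator(generator.Generator):
    TRANSFORMS = {
        **generator.Generator.TRANSFORMS,
        # preprocess() rewrites the AST before the SELECT is generated, so no
        # table_sql/subquery_sql overrides are needed.
        exp.Select: transforms.preprocess([eliminate_pivot]),
    }
```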
||
|
|
||
| def subquery_sql(self, expression: exp.Subquery, sep: str = " AS ") -> str: | ||
| pivots = expression.args.get("pivots") or [] | ||
|
Comment on lines +584 to +585
Collaborator
Do we test subqueries with pivots? Does the transformation work at any nesting level?
||
| if not pivots: | ||
| return super().subquery_sql(expression) | ||
|
|
||
| if len(pivots) > 1: | ||
| self.unsupported("Multiple PIVOT clauses are not supported by Exasol") | ||
| return super().subquery_sql(expression) | ||
| pivot = pivots[0] | ||
| return self.sql(pivot) | ||
|
Comment on lines +592 to +593
Collaborator
Ditto.
||
|
|
||
| def pivot_sql(self, expression: exp.Pivot) -> str: | ||
| """ | ||
| Exasol does not support PIVOT, so we rewrite it. | ||
|
|
||
| Rewrite: | ||
| SELECT ... FROM T PIVOT (...) | ||
|
|
||
| Into: | ||
| SELECT ... FROM ( | ||
| SELECT <group cols>, | ||
| <agg(CASE WHEN ...)> | ||
| FROM T | ||
| GROUP BY <group cols> | ||
| ) | ||
|
Comment on lines +603 to +608
Collaborator
This conversion is easy to mess up if you change the scopes by moving both the pivot and …
||
| """ | ||
|
|
||
| if expression.unpivot: | ||
| self.unsupported("UNPIVOT is not supported in Exasol.") | ||
| return super().pivot_sql(expression) | ||
| source_relation = expression.find_ancestor(exp.From) | ||
|
|
||
| if not source_relation: | ||
| return super().pivot_sql(expression) | ||
|
Comment on lines
+616
to
+617
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why do we traverse ancestors up to |
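A small sketch of why the ancestor walk may be unnecessary, assuming sqlglot keeps parent pointers populated (not verified as part of this PR):

```python
from sqlglot import exp, parse_one

tree = parse_one(
    "SELECT * FROM sales PIVOT (SUM(amount) FOR quarter IN (1 AS q1))",
    read="databricks",
)
pivot = tree.find(exp.Pivot)
# An exp.Pivot is stored in Table.args["pivots"], so its parent should be the
# pivoted table (or subquery) itself rather than something reached via exp.From.
print(type(pivot.parent).__name__)  # expected: Table
```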
||
|
|
||
| if isinstance(source_relation.this, exp.Table) or isinstance( | ||
| source_relation.this, exp.Subquery | ||
| ): | ||
|
Comment on lines +619 to +621
Collaborator
We can merge the isinstance checks, e.g. `isinstance(source_relation.this, (exp.Table, exp.Subquery))`.
||
| source_table_sql = ( | ||
| f"({self.sql(source_relation.this.this)})" | ||
| if isinstance(source_relation.this, exp.Subquery) | ||
| else self.sql(source_relation.this.this) | ||
| ) | ||
| source_alias_expr = source_relation.this.args.get("alias") | ||
| from_source_sql = ( | ||
| f"{source_table_sql} AS {self.sql(source_alias_expr)}" | ||
| if source_alias_expr | ||
| else source_table_sql | ||
| ) | ||
| source_name = ( | ||
| source_alias_expr.this | ||
| if isinstance(source_alias_expr, exp.TableAlias) | ||
| else source_table_sql | ||
|
Comment on lines +622 to +636
Collaborator
We should not be generating strings; we can instead build AST nodes and let the generator do the work.
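As an illustration of the AST-first style (a sketch with made-up column names, not the PR's code), a single `SUM(CASE WHEN ...)` projection and the surrounding derived table can be assembled with sqlglot's builder helpers and only rendered at the very end:

```python
from sqlglot import exp

# Build SUM(CASE WHEN "sales"."quarter" = 1 THEN "sales"."sales" END) AS "q1"
case = exp.case().when(
    exp.column("quarter", table="sales").eq(exp.Literal.number(1)),
    exp.column("sales", table="sales"),
)
projection = exp.alias_(exp.func("SUM", case), "q1", quoted=True)

# Assemble the derived table as AST nodes and let the generator render it.
inner = (
    exp.select(exp.column("year", table="sales"), projection)
    .from_(exp.to_table("sales"))
    .group_by(exp.column("year", table="sales"))
)
print(inner.sql(dialect="exasol"))
```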
||
| ) | ||
| else: | ||
| return super().pivot_sql(expression) | ||
|
|
||
| aggregate_aliases = expression.expressions or [] | ||
|
|
||
| if not aggregate_aliases: | ||
| return super().pivot_sql(expression) | ||
|
Comment on lines
+643
to
+644
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can we not have aggregate expressions? |
||
|
|
||
| pivot_fields = expression.fields or [] | ||
|
|
||
| if len(pivot_fields) != 1 or not isinstance(pivot_fields[0], exp.In): | ||
| return super().pivot_sql(expression) | ||
|
|
||
| pivot_in_condition = pivot_fields[0] | ||
| pivot_key_expr = pivot_in_condition.this | ||
| pivot_values_nodes = pivot_in_condition.expressions or [] | ||
|
|
||
| if not pivot_values_nodes: | ||
| return super().pivot_sql(expression) | ||
|
|
||
| pivot_alias_expr = expression.args.get("alias") | ||
|
|
||
| has_pivot_alias = isinstance(pivot_alias_expr, exp.TableAlias) | ||
|
|
||
| def unwrap_tuple( | ||
| node: exp.Expression | None, | ||
| ) -> Tuple | Expression | None | Paren | Any: | ||
|
Collaborator
These type hints here are confused:
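A tighter signature for what the helper actually returns could look like the sketch below; `Tuple`, `Paren`, and `Any` in the original union are redundant because the function only ever yields an `exp.Tuple` or `None`:

```python
import typing as t

from sqlglot import exp


def unwrap_tuple(node: t.Optional[exp.Expression]) -> t.Optional[exp.Tuple]:
    """Return the Tuple itself, or the Tuple wrapped in a Paren, else None."""
    if isinstance(node, exp.Tuple):
        return node
    if isinstance(node, exp.Paren) and isinstance(node.this, exp.Tuple):
        return node.this
    return None
```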
|
||
| if isinstance(node, exp.Tuple): | ||
| return node | ||
| if isinstance(node, exp.Paren) and isinstance(node.this, exp.Tuple): | ||
| return node.this | ||
| return None | ||
|
|
||
| pivot_case_specs: list[tuple[exp.Expression, str, exp.Expression, str]] = [] | ||
| pivot_output_column_names: set[str] = set() | ||
|
|
||
| for pivot_value in pivot_values_nodes: | ||
| pivot_value_alias = ( | ||
| pivot_value.alias_or_name | ||
| if isinstance(pivot_value, exp.PivotAlias) | ||
| else pivot_value.sql(dialect=self.dialect) | ||
| ) | ||
| pivot_value_expr = ( | ||
| pivot_value.this if isinstance(pivot_value, exp.PivotAlias) else pivot_value | ||
| ) | ||
|
|
||
| for aggregate_alias in aggregate_aliases: | ||
| aggregate_func_expr = aggregate_alias.this | ||
| aggregate_func_name = ( | ||
| getattr(aggregate_func_expr, "key", None) or aggregate_func_expr.name | ||
| ).upper() | ||
| aggregate_input_expr = aggregate_func_expr.this | ||
| aggregate_result_suffix = aggregate_alias.alias | ||
|
|
||
| output_column_name = ( | ||
| f"{pivot_value_alias}_{aggregate_result_suffix}" | ||
| if aggregate_result_suffix and len(aggregate_aliases) > 1 | ||
| else pivot_value_alias | ||
| ) | ||
| pivot_case_specs.append( | ||
| ( | ||
| pivot_value_expr, | ||
| aggregate_func_name, | ||
| aggregate_input_expr, | ||
| output_column_name, | ||
| ) | ||
| ) | ||
|
|
||
| pivot_output_column_names.add(output_column_name) | ||
|
|
||
| group_by_columns = list(expression.args.get("group") or []) | ||
| if not group_by_columns and has_pivot_alias and source_name: | ||
| outer_select = expression.find_ancestor(exp.Select) | ||
|
|
||
| if isinstance(outer_select, exp.Select): | ||
| for projection in outer_select.expressions or []: | ||
| projected_expr = ( | ||
| projection.this if isinstance(projection, exp.Alias) else projection | ||
| ) | ||
| if not isinstance(projected_expr, exp.Column): | ||
| continue | ||
|
|
||
| projected_column_name = projected_expr.name | ||
|
|
||
| if projected_column_name in pivot_output_column_names: | ||
| continue | ||
|
|
||
| group_by_columns.append( | ||
| exp.Column( | ||
| this=exp.to_identifier(projected_column_name, True), | ||
| table=exp.to_identifier(source_name), | ||
| ) | ||
| ) | ||
|
|
||
| group_columns_sql = [self.sql(col) for col in group_by_columns] | ||
| select_list_sql_parts: list[str] = list(group_columns_sql) | ||
|
|
||
| for ( | ||
| pivot_value_expr, | ||
| aggregate_func_name, | ||
| aggregate_input_expr, | ||
| output_column_name, | ||
| ) in pivot_case_specs: | ||
| key_tuple = unwrap_tuple(pivot_key_expr) | ||
| value_tuple = unwrap_tuple(pivot_value_expr) | ||
|
|
||
| if key_tuple and value_tuple: | ||
| comparisons: list[str] = [] | ||
| for key_part, value_part in zip( | ||
| key_tuple.expressions or [], value_tuple.expressions or [] | ||
| ): | ||
| comparisons.append(f"{self.sql(key_part)} = {self.sql(value_part)}") | ||
| condition_sql = " AND ".join(comparisons) | ||
| else: | ||
| condition_sql = f"{self.sql(pivot_key_expr)} = {self.sql(pivot_value_expr)}" | ||
| aggregate_input_sql = self.sql(aggregate_input_expr) | ||
| output_column_sql = self.sql(exp.to_identifier(output_column_name, True)) | ||
| case_expr_sql = f"CASE WHEN {condition_sql} THEN {aggregate_input_sql} END" | ||
| select_list_sql_parts.append( | ||
| f"{aggregate_func_name}({case_expr_sql}) AS {output_column_sql}" | ||
| ) | ||
|
|
||
| inner_select_sql = ", ".join(select_list_sql_parts) | ||
|
Comment on lines
+750
to
+760
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ditto about creating strings, we should try to create AST nodes such as much possible and delegating their generation to |
||
|
|
||
| if group_columns_sql: | ||
| group_by_sql = ", ".join(group_columns_sql) | ||
| inner_query = ( | ||
| f"SELECT {inner_select_sql} FROM {from_source_sql} GROUP BY {group_by_sql}" | ||
| ) | ||
| else: | ||
| inner_query = f"SELECT {inner_select_sql} FROM {from_source_sql}" | ||
| pivot_alias_sql = f"{self.sql(pivot_alias_expr)}" if pivot_alias_expr else "" | ||
| return f"({inner_query}){pivot_alias_sql}" | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,4 +1,6 @@ | ||
| from tests.dialects.test_dialect import Validator | ||
| from sqlglot import parse_one | ||
| from sqlglot.optimizer import optimize | ||
|
|
||
|
|
||
| class TestExasol(Validator): | ||
|
|
@@ -762,3 +764,109 @@ def test_local_prefix_for_alias(self): | |
| exasol_sql, | ||
| write={"exasol": exasol_sql, "databricks": dbx_sql}, | ||
| ) | ||
|
|
||
| def test_pivot(self): | ||
| test_cases = [ | ||
| ( | ||
| "Single-column pivot rewrite", | ||
| """ | ||
| SELECT | ||
| "_0"."year" AS "year", | ||
| "_0"."region" AS "region", | ||
| "_0"."q1" AS "q1", | ||
| "_0"."q2" AS "q2", | ||
| "_0"."q3" AS "q3", | ||
| "_0"."q4" AS "q4" | ||
| FROM (SELECT "sales"."year", "sales"."region", SUM(CASE WHEN "sales"."quarter" = 1 THEN "sales"."sales" END) AS "q1", SUM(CASE WHEN "sales"."quarter" = 2 THEN "sales"."sales" END) AS "q2", SUM(CASE WHEN "sales"."quarter" = 3 THEN "sales"."sales" END) AS "q3", SUM(CASE WHEN "sales"."quarter" = 4 THEN "sales"."sales" END) AS "q4" FROM "sales" AS "sales" GROUP BY "sales"."year", "sales"."region")"_0" | ||
| """, | ||
|
Comment on lines +770 to +781
Collaborator
Have you verified that these work in Exasol? I can see that they're the Databricks …
Contributor
Author
I will look into it @VaggelisD. Does DBX also mean Databricks?
Collaborator
That is right! Ideally your code here would work in ANSI SQL and we could reuse it in other dialects that don't support PIVOT.
||
| "SELECT year, region, q1, q2, q3, q4 FROM sales PIVOT (sum(sales) AS sales FOR quarter IN (1 AS q1, 2 AS q2, 3 AS q3, 4 AS q4))", | ||
| ), | ||
| ( | ||
| "Tuple pivot rewrite (multi-column FOR)", | ||
| """ | ||
| SELECT | ||
| "_0"."year" AS "year", | ||
| "_0"."q1_east" AS "q1_east", | ||
| "_0"."q1_west" AS "q1_west", | ||
| "_0"."q2_east" AS "q2_east", | ||
| "_0"."q2_west" AS "q2_west", | ||
| "_0"."q3_east" AS "q3_east", | ||
| "_0"."q3_west" AS "q3_west", | ||
| "_0"."q4_east" AS "q4_east", | ||
| "_0"."q4_west" AS "q4_west" | ||
| FROM (SELECT "sales"."year", SUM(CASE WHEN "sales"."quarter" = 1 AND "sales"."region" = 'east' THEN "sales"."sales" END) AS "q1_east", SUM(CASE WHEN "sales"."quarter" = 1 AND "sales"."region" = 'west' THEN "sales"."sales" END) AS "q1_west", SUM(CASE WHEN "sales"."quarter" = 2 AND "sales"."region" = 'east' THEN "sales"."sales" END) AS "q2_east", SUM(CASE WHEN "sales"."quarter" = 2 AND "sales"."region" = 'west' THEN "sales"."sales" END) AS "q2_west", SUM(CASE WHEN "sales"."quarter" = 3 AND "sales"."region" = 'east' THEN "sales"."sales" END) AS "q3_east", SUM(CASE WHEN "sales"."quarter" = 3 AND "sales"."region" = 'west' THEN "sales"."sales" END) AS "q3_west", SUM(CASE WHEN "sales"."quarter" = 4 AND "sales"."region" = 'east' THEN "sales"."sales" END) AS "q4_east", SUM(CASE WHEN "sales"."quarter" = 4 AND "sales"."region" = 'west' THEN "sales"."sales" END) AS "q4_west" FROM "sales" AS "sales" GROUP BY "sales"."year")"_0" | ||
| """, | ||
| """ | ||
| SELECT year, q1_east, q1_west, q2_east, q2_west, q3_east, q3_west, q4_east, q4_west | ||
| FROM sales | ||
| PIVOT (sum(sales) AS sales | ||
| FOR (quarter, region) | ||
| IN ((1, 'east') AS q1_east, (1, 'west') AS q1_west, (2, 'east') AS q2_east, (2, 'west') AS q2_west, | ||
| (3, 'east') AS q3_east, (3, 'west') AS q3_west, (4, 'east') AS q4_east, (4, 'west') AS q4_west)) | ||
| """, | ||
| ), | ||
| ( | ||
| "Pivot rewrite over derived table source", | ||
| """ | ||
| SELECT | ||
| "_0"."year" AS "year", | ||
| "_0"."q1" AS "q1", | ||
| "_0"."q2" AS "q2", | ||
| "_0"."q3" AS "q3", | ||
| "_0"."q4" AS "q4" | ||
| FROM (SELECT "s"."year", SUM(CASE WHEN "s"."quarter" = 1 THEN "s"."sales" END) AS "q1", SUM(CASE WHEN "s"."quarter" = 2 THEN "s"."sales" END) AS "q2", SUM(CASE WHEN "s"."quarter" = 3 THEN "s"."sales" END) AS "q3", SUM(CASE WHEN "s"."quarter" = 4 THEN "s"."sales" END) AS "q4" FROM (SELECT | ||
| "sales"."year" AS "year", | ||
| "sales"."quarter" AS "quarter", | ||
| "sales"."sales" AS "sales" | ||
| FROM "sales" AS "sales") AS "s" GROUP BY "s"."year")"_0" | ||
| """, | ||
| """ | ||
| SELECT year, q1, q2, q3, q4 | ||
| FROM (SELECT year, quarter, sales FROM sales) AS s | ||
| PIVOT (sum(sales) AS sales | ||
| FOR quarter | ||
| IN (1 AS q1, 2 AS q2, 3 AS q3, 4 AS q4)) | ||
| """, | ||
| ), | ||
| ( | ||
| "Pivot rewrite with multiple aggregates", | ||
| """ | ||
| SELECT | ||
| "_0"."year" AS "year", | ||
| "_0"."q1_total" AS "q1_total", | ||
| "_0"."q1_avg" AS "q1_avg", | ||
| "_0"."q2_total" AS "q2_total", | ||
| "_0"."q2_avg" AS "q2_avg", | ||
| "_0"."q3_total" AS "q3_total", | ||
| "_0"."q3_avg" AS "q3_avg", | ||
| "_0"."q4_total" AS "q4_total", | ||
| "_0"."q4_avg" AS "q4_avg" | ||
| FROM (SELECT "s"."year", SUM(CASE WHEN "s"."quarter" = 1 THEN "s"."sales" END) AS "q1_total", AVG(CASE WHEN "s"."quarter" = 1 THEN "s"."sales" END) AS "q1_avg", SUM(CASE WHEN "s"."quarter" = 2 THEN "s"."sales" END) AS "q2_total", AVG(CASE WHEN "s"."quarter" = 2 THEN "s"."sales" END) AS "q2_avg", SUM(CASE WHEN "s"."quarter" = 3 THEN "s"."sales" END) AS "q3_total", AVG(CASE WHEN "s"."quarter" = 3 THEN "s"."sales" END) AS "q3_avg", SUM(CASE WHEN "s"."quarter" = 4 THEN "s"."sales" END) AS "q4_total", AVG(CASE WHEN "s"."quarter" = 4 THEN "s"."sales" END) AS "q4_avg" FROM (SELECT | ||
| "sales"."year" AS "year", | ||
| "sales"."quarter" AS "quarter", | ||
| "sales"."sales" AS "sales" | ||
| FROM "sales" AS "sales") AS "s" GROUP BY "s"."year")"_0" | ||
| """, | ||
| """ | ||
| SELECT year, q1_total, q1_avg, q2_total, q2_avg, q3_total, q3_avg, q4_total, q4_avg | ||
| FROM (SELECT year, quarter, sales FROM sales) AS s | ||
| PIVOT (sum(sales) AS total, avg(sales) AS avg | ||
| FOR quarter | ||
| IN (1 AS q1, 2 AS q2, 3 AS q3, 4 AS q4)) | ||
| """, | ||
| ), | ||
| ] | ||
| for title, exasol_sql, dbx_sql in test_cases: | ||
| with self.subTest(clause=title): | ||
| schema = { | ||
| "sales": {"year": "INT", "quarter": "INT", "region": "STRING", "sales": "INT"} | ||
| } | ||
| expr = parse_one(dbx_sql, read="databricks") | ||
| optimize_expr = optimize(expr, schema) | ||
|
|
||
| transpile = optimize_expr.sql(dialect="exasol") | ||
|
|
||
|
Comment on lines +865 to +868
Collaborator
Why are we optimizing and reparsing?
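For reference, the usual pattern in these dialect tests compares transpiled SQL directly without an optimize/re-parse round trip; a sketch using the existing `Validator.validate_all` helper (the SQL strings below are placeholders for the per-case values above):

```python
from tests.dialects.test_dialect import Validator


class TestExasol(Validator):
    dialect = "exasol"

    def test_pivot(self):
        expected_exasol_sql = "..."  # placeholder: expected Exasol output
        dbx_sql = "..."  # placeholder: Databricks input query
        # validate_all parses dbx_sql as Databricks and asserts that generating
        # it with the Exasol dialect yields expected_exasol_sql.
        self.validate_all(expected_exasol_sql, read={"databricks": dbx_sql})
```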
||
| self.assertEqual( | ||
| parse_one(transpile, read="exasol").sql(dialect="exasol"), | ||
| parse_one(exasol_sql, read="exasol").sql(dialect="exasol"), | ||
| ) | ||
Why do we import `Any` here? It's not a pattern in SQLGlot.