From 2d962d2daf7bee7c055732ce2382beef6514f977 Mon Sep 17 00:00:00 2001 From: Nnamdi Nwabuokei Date: Mon, 15 Dec 2025 15:17:03 +0100 Subject: [PATCH] chore(exasol): custom transformation of pivot clause in exasol dialect --- sqlglot/dialects/exasol.py | 206 +++++++++++++++++++++++++++++++++- tests/dialects/test_exasol.py | 108 ++++++++++++++++++ 2 files changed, 313 insertions(+), 1 deletion(-) diff --git a/sqlglot/dialects/exasol.py b/sqlglot/dialects/exasol.py index fd61dac18d..a3439e1cac 100644 --- a/sqlglot/dialects/exasol.py +++ b/sqlglot/dialects/exasol.py @@ -1,8 +1,9 @@ from __future__ import annotations import typing as t +from typing import Any -from sqlglot import exp, generator, parser, tokens, transforms +from sqlglot import exp, generator, parser, tokens, transforms, Expression from sqlglot.dialects.dialect import ( Dialect, NormalizationStrategy, @@ -17,6 +18,7 @@ no_last_day_sql, DATE_ADD_OR_SUB, ) +from sqlglot.expressions import Paren, Tuple from sqlglot.generator import unsupported_args from sqlglot.helper import seq_get from sqlglot.tokens import TokenType @@ -564,3 +566,205 @@ def rank_sql(self, expression: exp.Rank) -> str: if expression.args.get("expressions"): self.unsupported("Exasol does not support arguments in RANK") return self.func("RANK") + + def table_sql(self, expression: exp.Table, sep: str = " AS ") -> str: + """ + If a table has PIVOTs attached, let the Pivot render a derived table . + """ + pivots = expression.args.get("pivots") or [] + if not pivots: + return super().table_sql(expression) + + if len(pivots) > 1: + self.unsupported("Multiple PIVOT clauses are not supported by Exasol") + return super().table_sql(expression) + pivot = pivots[0] + return self.sql(pivot) + + def subquery_sql(self, expression: exp.Subquery, sep: str = " AS ") -> str: + pivots = expression.args.get("pivots") or [] + if not pivots: + return super().subquery_sql(expression) + + if len(pivots) > 1: + self.unsupported("Multiple PIVOT clauses are not supported by Exasol") + return super().subquery_sql(expression) + pivot = pivots[0] + return self.sql(pivot) + + def pivot_sql(self, expression: exp.Pivot) -> str: + """ + Exasol does not support PIVOT, so we rewrite it. + + Rewrite: + SELECT ... FROM T PIVOT (...) + + Into: + SELECT ... FROM ( + SELECT , + + FROM T + GROUP BY + ) + """ + + if expression.unpivot: + self.unsupported("UNPIVOT is not supported in Exasol.") + return super().pivot_sql(expression) + source_relation = expression.find_ancestor(exp.From) + + if not source_relation: + return super().pivot_sql(expression) + + if isinstance(source_relation.this, exp.Table) or isinstance( + source_relation.this, exp.Subquery + ): + source_table_sql = ( + f"({self.sql(source_relation.this.this)})" + if isinstance(source_relation.this, exp.Subquery) + else self.sql(source_relation.this.this) + ) + source_alias_expr = source_relation.this.args.get("alias") + from_source_sql = ( + f"{source_table_sql} AS {self.sql(source_alias_expr)}" + if source_alias_expr + else source_table_sql + ) + source_name = ( + source_alias_expr.this + if isinstance(source_alias_expr, exp.TableAlias) + else source_table_sql + ) + else: + return super().pivot_sql(expression) + + aggregate_aliases = expression.expressions or [] + + if not aggregate_aliases: + return super().pivot_sql(expression) + + pivot_fields = expression.fields or [] + + if len(pivot_fields) != 1 or not isinstance(pivot_fields[0], exp.In): + return super().pivot_sql(expression) + + pivot_in_condition = pivot_fields[0] + pivot_key_expr = pivot_in_condition.this + pivot_values_nodes = pivot_in_condition.expressions or [] + + if not pivot_values_nodes: + return super().pivot_sql(expression) + + pivot_alias_expr = expression.args.get("alias") + + has_pivot_alias = isinstance(pivot_alias_expr, exp.TableAlias) + + def unwrap_tuple( + node: exp.Expression | None, + ) -> Tuple | Expression | None | Paren | Any: + if isinstance(node, exp.Tuple): + return node + if isinstance(node, exp.Paren) and isinstance(node.this, exp.Tuple): + return node.this + return None + + pivot_case_specs: list[tuple[exp.Expression, str, exp.Expression, str]] = [] + pivot_output_column_names: set[str] = set() + + for pivot_value in pivot_values_nodes: + pivot_value_alias = ( + pivot_value.alias_or_name + if isinstance(pivot_value, exp.PivotAlias) + else pivot_value.sql(dialect=self.dialect) + ) + pivot_value_expr = ( + pivot_value.this if isinstance(pivot_value, exp.PivotAlias) else pivot_value + ) + + for aggregate_alias in aggregate_aliases: + aggregate_func_expr = aggregate_alias.this + aggregate_func_name = ( + getattr(aggregate_func_expr, "key", None) or aggregate_func_expr.name + ).upper() + aggregate_input_expr = aggregate_func_expr.this + aggregate_result_suffix = aggregate_alias.alias + + output_column_name = ( + f"{pivot_value_alias}_{aggregate_result_suffix}" + if aggregate_result_suffix and len(aggregate_aliases) > 1 + else pivot_value_alias + ) + pivot_case_specs.append( + ( + pivot_value_expr, + aggregate_func_name, + aggregate_input_expr, + output_column_name, + ) + ) + + pivot_output_column_names.add(output_column_name) + + group_by_columns = list(expression.args.get("group") or []) + if not group_by_columns and has_pivot_alias and source_name: + outer_select = expression.find_ancestor(exp.Select) + + if isinstance(outer_select, exp.Select): + for projection in outer_select.expressions or []: + projected_expr = ( + projection.this if isinstance(projection, exp.Alias) else projection + ) + if not isinstance(projected_expr, exp.Column): + continue + + projected_column_name = projected_expr.name + + if projected_column_name in pivot_output_column_names: + continue + + group_by_columns.append( + exp.Column( + this=exp.to_identifier(projected_column_name, True), + table=exp.to_identifier(source_name), + ) + ) + + group_columns_sql = [self.sql(col) for col in group_by_columns] + select_list_sql_parts: list[str] = list(group_columns_sql) + + for ( + pivot_value_expr, + aggregate_func_name, + aggregate_input_expr, + output_column_name, + ) in pivot_case_specs: + key_tuple = unwrap_tuple(pivot_key_expr) + value_tuple = unwrap_tuple(pivot_value_expr) + + if key_tuple and value_tuple: + comparisons: list[str] = [] + for key_part, value_part in zip( + key_tuple.expressions or [], value_tuple.expressions or [] + ): + comparisons.append(f"{self.sql(key_part)} = {self.sql(value_part)}") + condition_sql = " AND ".join(comparisons) + else: + condition_sql = f"{self.sql(pivot_key_expr)} = {self.sql(pivot_value_expr)}" + aggregate_input_sql = self.sql(aggregate_input_expr) + output_column_sql = self.sql(exp.to_identifier(output_column_name, True)) + case_expr_sql = f"CASE WHEN {condition_sql} THEN {aggregate_input_sql} END" + select_list_sql_parts.append( + f"{aggregate_func_name}({case_expr_sql}) AS {output_column_sql}" + ) + + inner_select_sql = ", ".join(select_list_sql_parts) + + if group_columns_sql: + group_by_sql = ", ".join(group_columns_sql) + inner_query = ( + f"SELECT {inner_select_sql} FROM {from_source_sql} GROUP BY {group_by_sql}" + ) + else: + inner_query = f"SELECT {inner_select_sql} FROM {from_source_sql}" + pivot_alias_sql = f"{self.sql(pivot_alias_expr)}" if pivot_alias_expr else "" + return f"({inner_query}){pivot_alias_sql}" diff --git a/tests/dialects/test_exasol.py b/tests/dialects/test_exasol.py index bf5e9b8f05..3c516557ba 100644 --- a/tests/dialects/test_exasol.py +++ b/tests/dialects/test_exasol.py @@ -1,4 +1,6 @@ from tests.dialects.test_dialect import Validator +from sqlglot import parse_one +from sqlglot.optimizer import optimize class TestExasol(Validator): @@ -762,3 +764,109 @@ def test_local_prefix_for_alias(self): exasol_sql, write={"exasol": exasol_sql, "databricks": dbx_sql}, ) + + def test_pivot(self): + test_cases = [ + ( + "Single-column pivot rewrite", + """ + SELECT + "_0"."year" AS "year", + "_0"."region" AS "region", + "_0"."q1" AS "q1", + "_0"."q2" AS "q2", + "_0"."q3" AS "q3", + "_0"."q4" AS "q4" + FROM (SELECT "sales"."year", "sales"."region", SUM(CASE WHEN "sales"."quarter" = 1 THEN "sales"."sales" END) AS "q1", SUM(CASE WHEN "sales"."quarter" = 2 THEN "sales"."sales" END) AS "q2", SUM(CASE WHEN "sales"."quarter" = 3 THEN "sales"."sales" END) AS "q3", SUM(CASE WHEN "sales"."quarter" = 4 THEN "sales"."sales" END) AS "q4" FROM "sales" AS "sales" GROUP BY "sales"."year", "sales"."region")"_0" + """, + "SELECT year, region, q1, q2, q3, q4 FROM sales PIVOT (sum(sales) AS sales FOR quarter IN (1 AS q1, 2 AS q2, 3 AS q3, 4 AS q4))", + ), + ( + "Tuple pivot rewrite (multi-column FOR)", + """ + SELECT + "_0"."year" AS "year", + "_0"."q1_east" AS "q1_east", + "_0"."q1_west" AS "q1_west", + "_0"."q2_east" AS "q2_east", + "_0"."q2_west" AS "q2_west", + "_0"."q3_east" AS "q3_east", + "_0"."q3_west" AS "q3_west", + "_0"."q4_east" AS "q4_east", + "_0"."q4_west" AS "q4_west" + FROM (SELECT "sales"."year", SUM(CASE WHEN "sales"."quarter" = 1 AND "sales"."region" = 'east' THEN "sales"."sales" END) AS "q1_east", SUM(CASE WHEN "sales"."quarter" = 1 AND "sales"."region" = 'west' THEN "sales"."sales" END) AS "q1_west", SUM(CASE WHEN "sales"."quarter" = 2 AND "sales"."region" = 'east' THEN "sales"."sales" END) AS "q2_east", SUM(CASE WHEN "sales"."quarter" = 2 AND "sales"."region" = 'west' THEN "sales"."sales" END) AS "q2_west", SUM(CASE WHEN "sales"."quarter" = 3 AND "sales"."region" = 'east' THEN "sales"."sales" END) AS "q3_east", SUM(CASE WHEN "sales"."quarter" = 3 AND "sales"."region" = 'west' THEN "sales"."sales" END) AS "q3_west", SUM(CASE WHEN "sales"."quarter" = 4 AND "sales"."region" = 'east' THEN "sales"."sales" END) AS "q4_east", SUM(CASE WHEN "sales"."quarter" = 4 AND "sales"."region" = 'west' THEN "sales"."sales" END) AS "q4_west" FROM "sales" AS "sales" GROUP BY "sales"."year")"_0" + """, + """ + SELECT year, q1_east, q1_west, q2_east, q2_west, q3_east, q3_west, q4_east, q4_west + FROM sales + PIVOT (sum(sales) AS sales + FOR (quarter, region) + IN ((1, 'east') AS q1_east, (1, 'west') AS q1_west, (2, 'east') AS q2_east, (2, 'west') AS q2_west, + (3, 'east') AS q3_east, (3, 'west') AS q3_west, (4, 'east') AS q4_east, (4, 'west') AS q4_west)) + """, + ), + ( + "Pivot rewrite over derived table source", + """ + SELECT + "_0"."year" AS "year", + "_0"."q1" AS "q1", + "_0"."q2" AS "q2", + "_0"."q3" AS "q3", + "_0"."q4" AS "q4" + FROM (SELECT "s"."year", SUM(CASE WHEN "s"."quarter" = 1 THEN "s"."sales" END) AS "q1", SUM(CASE WHEN "s"."quarter" = 2 THEN "s"."sales" END) AS "q2", SUM(CASE WHEN "s"."quarter" = 3 THEN "s"."sales" END) AS "q3", SUM(CASE WHEN "s"."quarter" = 4 THEN "s"."sales" END) AS "q4" FROM (SELECT + "sales"."year" AS "year", + "sales"."quarter" AS "quarter", + "sales"."sales" AS "sales" + FROM "sales" AS "sales") AS "s" GROUP BY "s"."year")"_0" + """, + """ + SELECT year, q1, q2, q3, q4 + FROM (SELECT year, quarter, sales FROM sales) AS s + PIVOT (sum(sales) AS sales + FOR quarter + IN (1 AS q1, 2 AS q2, 3 AS q3, 4 AS q4)) + """, + ), + ( + "Pivot rewrite with multiple aggregates", + """ + SELECT + "_0"."year" AS "year", + "_0"."q1_total" AS "q1_total", + "_0"."q1_avg" AS "q1_avg", + "_0"."q2_total" AS "q2_total", + "_0"."q2_avg" AS "q2_avg", + "_0"."q3_total" AS "q3_total", + "_0"."q3_avg" AS "q3_avg", + "_0"."q4_total" AS "q4_total", + "_0"."q4_avg" AS "q4_avg" + FROM (SELECT "s"."year", SUM(CASE WHEN "s"."quarter" = 1 THEN "s"."sales" END) AS "q1_total", AVG(CASE WHEN "s"."quarter" = 1 THEN "s"."sales" END) AS "q1_avg", SUM(CASE WHEN "s"."quarter" = 2 THEN "s"."sales" END) AS "q2_total", AVG(CASE WHEN "s"."quarter" = 2 THEN "s"."sales" END) AS "q2_avg", SUM(CASE WHEN "s"."quarter" = 3 THEN "s"."sales" END) AS "q3_total", AVG(CASE WHEN "s"."quarter" = 3 THEN "s"."sales" END) AS "q3_avg", SUM(CASE WHEN "s"."quarter" = 4 THEN "s"."sales" END) AS "q4_total", AVG(CASE WHEN "s"."quarter" = 4 THEN "s"."sales" END) AS "q4_avg" FROM (SELECT + "sales"."year" AS "year", + "sales"."quarter" AS "quarter", + "sales"."sales" AS "sales" + FROM "sales" AS "sales") AS "s" GROUP BY "s"."year")"_0" + """, + """ + SELECT year, q1_total, q1_avg, q2_total, q2_avg, q3_total, q3_avg, q4_total, q4_avg + FROM (SELECT year, quarter, sales FROM sales) AS s + PIVOT (sum(sales) AS total, avg(sales) AS avg + FOR quarter + IN (1 AS q1, 2 AS q2, 3 AS q3, 4 AS q4)) + """, + ), + ] + for title, exasol_sql, dbx_sql in test_cases: + with self.subTest(clause=title): + schema = { + "sales": {"year": "INT", "quarter": "INT", "region": "STRING", "sales": "INT"} + } + expr = parse_one(dbx_sql, read="databricks") + optimize_expr = optimize(expr, schema) + + transpile = optimize_expr.sql(dialect="exasol") + + self.assertEqual( + parse_one(transpile, read="exasol").sql(dialect="exasol"), + parse_one(exasol_sql, read="exasol").sql(dialect="exasol"), + )