206 changes: 205 additions & 1 deletion sqlglot/dialects/exasol.py
@@ -1,8 +1,9 @@
from __future__ import annotations

import typing as t
from typing import Any
Collaborator: Why do we import Any here? It's not a pattern in SQLGlot.


from sqlglot import exp, generator, parser, tokens, transforms
from sqlglot import exp, generator, parser, tokens, transforms, Expression
Collaborator: Let's not import Expression; we can do exp.Expression.

from sqlglot.dialects.dialect import (
Dialect,
NormalizationStrategy,
@@ -17,6 +18,7 @@
no_last_day_sql,
DATE_ADD_OR_SUB,
)
from sqlglot.expressions import Paren, Tuple
from sqlglot.generator import unsupported_args
from sqlglot.helper import seq_get
from sqlglot.tokens import TokenType
@@ -564,3 +566,205 @@ def rank_sql(self, expression: exp.Rank) -> str:
if expression.args.get("expressions"):
self.unsupported("Exasol does not support arguments in RANK")
return self.func("RANK")

def table_sql(self, expression: exp.Table, sep: str = " AS ") -> str:
"""
If a table has PIVOTs attached, let the Pivot render a derived table.
"""
pivots = expression.args.get("pivots") or []
if not pivots:
return super().table_sql(expression)

if len(pivots) > 1:
self.unsupported("Multiple PIVOT clauses are not supported by Exasol")
return super().table_sql(expression)
pivot = pivots[0]
return self.sql(pivot)
Comment on lines +581 to +582
Collaborator: Why are we generating the pivot instead of the table here? That feels wrong. Shouldn't the whole pivot transpilation happen with a preprocess call or something instead of this?
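A rough sketch of the preprocess-based route the reviewer is pointing at (the transform name eliminate_pivot is hypothetical and not part of this PR; transforms.preprocess and the TRANSFORMS registration are the existing sqlglot pattern):

    from sqlglot import exp, generator, transforms

    def eliminate_pivot(expression: exp.Expression) -> exp.Expression:
        # Rewrite each pivoted relation into an equivalent derived table
        # (GROUP BY plus aggregated CASE expressions) at the AST level,
        # before generation starts.
        for pivot in expression.find_all(exp.Pivot):
            ...  # replace pivot.parent with the rewritten subquery
        return expression

    class Generator(generator.Generator):  # i.e. the dialect's Generator
        TRANSFORMS = {
            **generator.Generator.TRANSFORMS,
            exp.Select: transforms.preprocess([eliminate_pivot]),
        }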


def subquery_sql(self, expression: exp.Subquery, sep: str = " AS ") -> str:
pivots = expression.args.get("pivots") or []
Comment on lines +584 to +585
Collaborator: Do we test subqueries with pivots? Does the transformation work at any nesting level?

if not pivots:
return super().subquery_sql(expression)

if len(pivots) > 1:
self.unsupported("Multiple PIVOT clauses are not supported by Exasol")
return super().subquery_sql(expression)
pivot = pivots[0]
return self.sql(pivot)
Comment on lines +592 to +593
Collaborator: Ditto.


def pivot_sql(self, expression: exp.Pivot) -> str:
"""
Exasol does not support PIVOT, so we rewrite it.

Rewrite:
SELECT ... FROM T PIVOT (...)

Into:
SELECT ... FROM (
SELECT <group cols>,
<agg(CASE WHEN ...)>
FROM T
GROUP BY <group cols>
)
Comment on lines +603 to +608
Collaborator: This conversion is easy to mess up if you change the scopes by moving both the pivot and T inside of a new derived table. What if you have columns in the original scope already qualified with either T or the pivot's alias, if any? Have you thought about / handled cases like this?

"""

if expression.unpivot:
self.unsupported("UNPIVOT is not supported in Exasol.")
return super().pivot_sql(expression)
source_relation = expression.find_ancestor(exp.From)

if not source_relation:
return super().pivot_sql(expression)
Comment on lines +616 to +617
Collaborator: Why do we traverse ancestors up to FROM? We can do pivot.parent to get the table, afaict.
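Inside pivot_sql, the suggestion amounts to something like this (a sketch, assuming the pivot hangs off the exp.Table or exp.Subquery it decorates, which is how sqlglot attaches pivots):

    relation = expression.parent  # the exp.Table / exp.Subquery carrying this pivot
    if not isinstance(relation, (exp.Table, exp.Subquery)):
        return super().pivot_sql(expression)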


if isinstance(source_relation.this, exp.Table) or isinstance(
source_relation.this, exp.Subquery
):
Comment on lines +619 to +621
Collaborator: We can merge the instance checks together, e.g. isinstance(expr, (exp.Table, exp.Subquery)).

source_table_sql = (
f"({self.sql(source_relation.this.this)})"
if isinstance(source_relation.this, exp.Subquery)
else self.sql(source_relation.this.this)
)
source_alias_expr = source_relation.this.args.get("alias")
from_source_sql = (
f"{source_table_sql} AS {self.sql(source_alias_expr)}"
if source_alias_expr
else source_table_sql
)
source_name = (
source_alias_expr.this
if isinstance(source_alias_expr, exp.TableAlias)
else source_table_sql
Comment on lines +622 to +636
Collaborator: We should not be generating strings; we can instead build AST nodes and let the generator do the work.
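For one of the generated projections, the AST-first version could look roughly like this (the builder helpers are real sqlglot APIs; the wiring is illustrative, not the PR's code):

    from sqlglot import exp

    # Build SUM(CASE WHEN quarter = 1 THEN sales END) AS q1 as nodes rather
    # than formatting the string by hand, then let the generator render it.
    case_node = exp.case().when(exp.column("quarter").eq(1), exp.column("sales"))
    projection = exp.alias_(exp.func("SUM", case_node), "q1")
    inner = exp.select(exp.column("year"), projection).from_("sales").group_by("year")

    print(inner.sql(dialect="exasol"))
    # roughly: SELECT year, SUM(CASE WHEN quarter = 1 THEN sales END) AS q1 FROM sales GROUP BY year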

)
else:
return super().pivot_sql(expression)

aggregate_aliases = expression.expressions or []

if not aggregate_aliases:
return super().pivot_sql(expression)
Comment on lines +643 to +644
Collaborator: Can we not have aggregate expressions?


pivot_fields = expression.fields or []

if len(pivot_fields) != 1 or not isinstance(pivot_fields[0], exp.In):
return super().pivot_sql(expression)

pivot_in_condition = pivot_fields[0]
pivot_key_expr = pivot_in_condition.this
pivot_values_nodes = pivot_in_condition.expressions or []

if not pivot_values_nodes:
return super().pivot_sql(expression)

pivot_alias_expr = expression.args.get("alias")

has_pivot_alias = isinstance(pivot_alias_expr, exp.TableAlias)

def unwrap_tuple(
node: exp.Expression | None,
) -> Tuple | Expression | None | Paren | Any:
Collaborator: These type hints are confused:
  • Tuple and Paren are subclasses of Expression; we'd also never return Paren
  • Any is not used in SQLGlot

if isinstance(node, exp.Tuple):
return node
if isinstance(node, exp.Paren) and isinstance(node.this, exp.Tuple):
return node.this
return None

pivot_case_specs: list[tuple[exp.Expression, str, exp.Expression, str]] = []
pivot_output_column_names: set[str] = set()

for pivot_value in pivot_values_nodes:
pivot_value_alias = (
pivot_value.alias_or_name
if isinstance(pivot_value, exp.PivotAlias)
else pivot_value.sql(dialect=self.dialect)
)
pivot_value_expr = (
pivot_value.this if isinstance(pivot_value, exp.PivotAlias) else pivot_value
)

for aggregate_alias in aggregate_aliases:
aggregate_func_expr = aggregate_alias.this
aggregate_func_name = (
getattr(aggregate_func_expr, "key", None) or aggregate_func_expr.name
).upper()
aggregate_input_expr = aggregate_func_expr.this
aggregate_result_suffix = aggregate_alias.alias

output_column_name = (
f"{pivot_value_alias}_{aggregate_result_suffix}"
if aggregate_result_suffix and len(aggregate_aliases) > 1
else pivot_value_alias
)
pivot_case_specs.append(
(
pivot_value_expr,
aggregate_func_name,
aggregate_input_expr,
output_column_name,
)
)

pivot_output_column_names.add(output_column_name)

group_by_columns = list(expression.args.get("group") or [])
if not group_by_columns and has_pivot_alias and source_name:
outer_select = expression.find_ancestor(exp.Select)

if isinstance(outer_select, exp.Select):
for projection in outer_select.expressions or []:
projected_expr = (
projection.this if isinstance(projection, exp.Alias) else projection
)
if not isinstance(projected_expr, exp.Column):
continue

projected_column_name = projected_expr.name

if projected_column_name in pivot_output_column_names:
continue

group_by_columns.append(
exp.Column(
this=exp.to_identifier(projected_column_name, True),
table=exp.to_identifier(source_name),
)
)

group_columns_sql = [self.sql(col) for col in group_by_columns]
select_list_sql_parts: list[str] = list(group_columns_sql)

for (
pivot_value_expr,
aggregate_func_name,
aggregate_input_expr,
output_column_name,
) in pivot_case_specs:
key_tuple = unwrap_tuple(pivot_key_expr)
value_tuple = unwrap_tuple(pivot_value_expr)

if key_tuple and value_tuple:
comparisons: list[str] = []
for key_part, value_part in zip(
key_tuple.expressions or [], value_tuple.expressions or []
):
comparisons.append(f"{self.sql(key_part)} = {self.sql(value_part)}")
condition_sql = " AND ".join(comparisons)
else:
condition_sql = f"{self.sql(pivot_key_expr)} = {self.sql(pivot_value_expr)}"
aggregate_input_sql = self.sql(aggregate_input_expr)
output_column_sql = self.sql(exp.to_identifier(output_column_name, True))
case_expr_sql = f"CASE WHEN {condition_sql} THEN {aggregate_input_sql} END"
select_list_sql_parts.append(
f"{aggregate_func_name}({case_expr_sql}) AS {output_column_sql}"
)

inner_select_sql = ", ".join(select_list_sql_parts)
Comment on lines +750 to +760
Collaborator: Ditto about creating strings; we should try to create AST nodes as much as possible and delegate their generation to self.sql(...).


if group_columns_sql:
group_by_sql = ", ".join(group_columns_sql)
inner_query = (
f"SELECT {inner_select_sql} FROM {from_source_sql} GROUP BY {group_by_sql}"
)
else:
inner_query = f"SELECT {inner_select_sql} FROM {from_source_sql}"
pivot_alias_sql = f"{self.sql(pivot_alias_expr)}" if pivot_alias_expr else ""
return f"({inner_query}){pivot_alias_sql}"
108 changes: 108 additions & 0 deletions tests/dialects/test_exasol.py
@@ -1,4 +1,6 @@
from tests.dialects.test_dialect import Validator
from sqlglot import parse_one
from sqlglot.optimizer import optimize


class TestExasol(Validator):
@@ -762,3 +764,109 @@ def test_local_prefix_for_alias(self):
exasol_sql,
write={"exasol": exasol_sql, "databricks": dbx_sql},
)

def test_pivot(self):
test_cases = [
(
"Single-column pivot rewrite",
"""
SELECT
"_0"."year" AS "year",
"_0"."region" AS "region",
"_0"."q1" AS "q1",
"_0"."q2" AS "q2",
"_0"."q3" AS "q3",
"_0"."q4" AS "q4"
FROM (SELECT "sales"."year", "sales"."region", SUM(CASE WHEN "sales"."quarter" = 1 THEN "sales"."sales" END) AS "q1", SUM(CASE WHEN "sales"."quarter" = 2 THEN "sales"."sales" END) AS "q2", SUM(CASE WHEN "sales"."quarter" = 3 THEN "sales"."sales" END) AS "q3", SUM(CASE WHEN "sales"."quarter" = 4 THEN "sales"."sales" END) AS "q4" FROM "sales" AS "sales" GROUP BY "sales"."year", "sales"."region")"_0"
""",
Comment on lines +770 to +781
Collaborator: Have you verified that these work in Exasol? I can see that they're the Databricks PIVOT examples, but these queries don't work in Spark/DBX afaict.

Contributor Author: I will look into it, @VaggelisD. Also, does DBX mean Databricks?

Collaborator: That is right! Ideally your code here would work in ANSI SQL and we could reuse it in other dialects that don't support PIVOT.

"SELECT year, region, q1, q2, q3, q4 FROM sales PIVOT (sum(sales) AS sales FOR quarter IN (1 AS q1, 2 AS q2, 3 AS q3, 4 AS q4))",
),
(
"Tuple pivot rewrite (multi-column FOR)",
"""
SELECT
"_0"."year" AS "year",
"_0"."q1_east" AS "q1_east",
"_0"."q1_west" AS "q1_west",
"_0"."q2_east" AS "q2_east",
"_0"."q2_west" AS "q2_west",
"_0"."q3_east" AS "q3_east",
"_0"."q3_west" AS "q3_west",
"_0"."q4_east" AS "q4_east",
"_0"."q4_west" AS "q4_west"
FROM (SELECT "sales"."year", SUM(CASE WHEN "sales"."quarter" = 1 AND "sales"."region" = 'east' THEN "sales"."sales" END) AS "q1_east", SUM(CASE WHEN "sales"."quarter" = 1 AND "sales"."region" = 'west' THEN "sales"."sales" END) AS "q1_west", SUM(CASE WHEN "sales"."quarter" = 2 AND "sales"."region" = 'east' THEN "sales"."sales" END) AS "q2_east", SUM(CASE WHEN "sales"."quarter" = 2 AND "sales"."region" = 'west' THEN "sales"."sales" END) AS "q2_west", SUM(CASE WHEN "sales"."quarter" = 3 AND "sales"."region" = 'east' THEN "sales"."sales" END) AS "q3_east", SUM(CASE WHEN "sales"."quarter" = 3 AND "sales"."region" = 'west' THEN "sales"."sales" END) AS "q3_west", SUM(CASE WHEN "sales"."quarter" = 4 AND "sales"."region" = 'east' THEN "sales"."sales" END) AS "q4_east", SUM(CASE WHEN "sales"."quarter" = 4 AND "sales"."region" = 'west' THEN "sales"."sales" END) AS "q4_west" FROM "sales" AS "sales" GROUP BY "sales"."year")"_0"
""",
"""
SELECT year, q1_east, q1_west, q2_east, q2_west, q3_east, q3_west, q4_east, q4_west
FROM sales
PIVOT (sum(sales) AS sales
FOR (quarter, region)
IN ((1, 'east') AS q1_east, (1, 'west') AS q1_west, (2, 'east') AS q2_east, (2, 'west') AS q2_west,
(3, 'east') AS q3_east, (3, 'west') AS q3_west, (4, 'east') AS q4_east, (4, 'west') AS q4_west))
""",
),
(
"Pivot rewrite over derived table source",
"""
SELECT
"_0"."year" AS "year",
"_0"."q1" AS "q1",
"_0"."q2" AS "q2",
"_0"."q3" AS "q3",
"_0"."q4" AS "q4"
FROM (SELECT "s"."year", SUM(CASE WHEN "s"."quarter" = 1 THEN "s"."sales" END) AS "q1", SUM(CASE WHEN "s"."quarter" = 2 THEN "s"."sales" END) AS "q2", SUM(CASE WHEN "s"."quarter" = 3 THEN "s"."sales" END) AS "q3", SUM(CASE WHEN "s"."quarter" = 4 THEN "s"."sales" END) AS "q4" FROM (SELECT
"sales"."year" AS "year",
"sales"."quarter" AS "quarter",
"sales"."sales" AS "sales"
FROM "sales" AS "sales") AS "s" GROUP BY "s"."year")"_0"
""",
"""
SELECT year, q1, q2, q3, q4
FROM (SELECT year, quarter, sales FROM sales) AS s
PIVOT (sum(sales) AS sales
FOR quarter
IN (1 AS q1, 2 AS q2, 3 AS q3, 4 AS q4))
""",
),
(
"Pivot rewrite with multiple aggregates",
"""
SELECT
"_0"."year" AS "year",
"_0"."q1_total" AS "q1_total",
"_0"."q1_avg" AS "q1_avg",
"_0"."q2_total" AS "q2_total",
"_0"."q2_avg" AS "q2_avg",
"_0"."q3_total" AS "q3_total",
"_0"."q3_avg" AS "q3_avg",
"_0"."q4_total" AS "q4_total",
"_0"."q4_avg" AS "q4_avg"
FROM (SELECT "s"."year", SUM(CASE WHEN "s"."quarter" = 1 THEN "s"."sales" END) AS "q1_total", AVG(CASE WHEN "s"."quarter" = 1 THEN "s"."sales" END) AS "q1_avg", SUM(CASE WHEN "s"."quarter" = 2 THEN "s"."sales" END) AS "q2_total", AVG(CASE WHEN "s"."quarter" = 2 THEN "s"."sales" END) AS "q2_avg", SUM(CASE WHEN "s"."quarter" = 3 THEN "s"."sales" END) AS "q3_total", AVG(CASE WHEN "s"."quarter" = 3 THEN "s"."sales" END) AS "q3_avg", SUM(CASE WHEN "s"."quarter" = 4 THEN "s"."sales" END) AS "q4_total", AVG(CASE WHEN "s"."quarter" = 4 THEN "s"."sales" END) AS "q4_avg" FROM (SELECT
"sales"."year" AS "year",
"sales"."quarter" AS "quarter",
"sales"."sales" AS "sales"
FROM "sales" AS "sales") AS "s" GROUP BY "s"."year")"_0"
""",
"""
SELECT year, q1_total, q1_avg, q2_total, q2_avg, q3_total, q3_avg, q4_total, q4_avg
FROM (SELECT year, quarter, sales FROM sales) AS s
PIVOT (sum(sales) AS total, avg(sales) AS avg
FOR quarter
IN (1 AS q1, 2 AS q2, 3 AS q3, 4 AS q4))
""",
),
]
for title, exasol_sql, dbx_sql in test_cases:
with self.subTest(clause=title):
schema = {
"sales": {"year": "INT", "quarter": "INT", "region": "STRING", "sales": "INT"}
}
expr = parse_one(dbx_sql, read="databricks")
optimize_expr = optimize(expr, schema)

transpile = optimize_expr.sql(dialect="exasol")

Comment on lines +865 to +868
Collaborator: Why are we optimizing and reparsing?
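A leaner variant of what the reviewer seems to be implying would compare the generated SQL directly (a sketch; the extra parse_one round-trips above presumably only normalize whitespace in the multi-line expected strings, so exasol_sql would have to be written in the generator's single-line form for this to hold):

    expr = parse_one(dbx_sql, read="databricks")
    transpiled = optimize(expr, schema).sql(dialect="exasol")

    # No reparse of either side; assumes exasol_sql matches the generated form exactly.
    self.assertEqual(transpiled, exasol_sql)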

self.assertEqual(
parse_one(transpile, read="exasol").sql(dialect="exasol"),
parse_one(exasol_sql, read="exasol").sql(dialect="exasol"),
)