Skip to content

Commit b809a2b

Browse files
committed
Feat(optimizer): canonicalize table aliases
1 parent 8201062 commit b809a2b

File tree

3 files changed

+145
-54
lines changed

3 files changed

+145
-54
lines changed

sqlglot/optimizer/qualify.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ def qualify(
3030
validate_qualify_columns: bool = True,
3131
quote_identifiers: bool = True,
3232
identify: bool = True,
33+
canonicalize_table_aliases: bool = False,
3334
on_qualify: t.Optional[t.Callable[[exp.Expression], None]] = None,
3435
) -> exp.Expression:
3536
"""
@@ -62,6 +63,8 @@ def qualify(
6263
This step is necessary to ensure correctness for case sensitive queries.
6364
But this flag is provided in case this step is performed at a later time.
6465
identify: If True, quote all identifiers, else only necessary ones.
66+
canonicalize_table_aliases: Whether to use canonical aliases (_0, _1, ...) for all sources
67+
instead of preserving table names.
6568
on_qualify: Callback after a table has been qualified.
6669
6770
Returns:
@@ -81,6 +84,7 @@ def qualify(
8184
catalog=catalog,
8285
dialect=dialect,
8386
on_qualify=on_qualify,
87+
canonicalize=canonicalize_table_aliases,
8488
)
8589

8690
if isolate_tables:

sqlglot/optimizer/qualify_tables.py

Lines changed: 70 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
from sqlglot import exp
66
from sqlglot.dialects.dialect import Dialect, DialectType
7-
from sqlglot.helper import name_sequence
7+
from sqlglot.helper import name_sequence, seq_get
88
from sqlglot.optimizer.normalize_identifiers import normalize_identifiers
99
from sqlglot.optimizer.scope import Scope, traverse_scope
1010

@@ -18,6 +18,7 @@ def qualify_tables(
1818
catalog: t.Optional[str | exp.Identifier] = None,
1919
on_qualify: t.Optional[t.Callable[[exp.Expression], None]] = None,
2020
dialect: DialectType = None,
21+
canonicalize: bool = False,
2122
) -> E:
2223
"""
2324
Rewrite sqlglot AST to have fully qualified tables. Join constructs such as
@@ -39,13 +40,15 @@ def qualify_tables(
3940
catalog: Catalog name
4041
on_qualify: Callback after a table has been qualified.
4142
dialect: The dialect to parse catalog and schema into.
43+
canonicalize: Whether to use canonical aliases (_0, _1, ...) for all sources
44+
instead of preserving table names. Defaults to False.
4245
4346
Returns:
4447
The qualified expression.
4548
"""
4649
dialect = Dialect.get_or_raise(dialect)
4750

48-
alias_sequence = name_sequence("_q_")
51+
alias_sequence = name_sequence("_" if canonicalize else "_q_")
4952

5053
def next_alias_name() -> str:
5154
return normalize_identifiers(alias_sequence(), dialect=dialect).name
@@ -74,6 +77,32 @@ def _qualify(table: exp.Table) -> None:
7477
if isinstance(node, exp.Table) and node.name not in cte_names:
7578
_qualify(node)
7679

80+
canonical_aliases: t.Dict[str, str] = {}
81+
82+
def _set_alias(
83+
expression: exp.Expression,
84+
target_alias: t.Optional[str] = None,
85+
scope: t.Optional[Scope] = None,
86+
normalize: bool = False,
87+
) -> None:
88+
alias = expression.args.get("alias") or exp.TableAlias()
89+
90+
if canonicalize:
91+
new_alias_name = next_alias_name()
92+
canonical_aliases[alias.name or target_alias or ""] = new_alias_name
93+
elif not alias.name:
94+
new_alias_name = target_alias or next_alias_name()
95+
if normalize:
96+
new_alias_name = normalize_identifiers(new_alias_name, dialect=dialect).name
97+
else:
98+
return
99+
100+
alias.set("this", exp.to_identifier(new_alias_name))
101+
expression.set("alias", alias)
102+
103+
if scope:
104+
scope.rename_source(None, new_alias_name)
105+
77106
for scope in traverse_scope(expression):
78107
for derived_table in scope.derived_tables:
79108
unnested = derived_table.unnest()
@@ -83,78 +112,57 @@ def _qualify(table: exp.Table) -> None:
83112
derived_table.this.replace(exp.select("*").from_(unnested.copy(), copy=False))
84113
derived_table.this.set("joins", joins)
85114

86-
if not derived_table.args.get("alias"):
87-
alias = next_alias_name()
88-
derived_table.set("alias", exp.TableAlias(this=exp.to_identifier(alias)))
89-
scope.rename_source(None, alias)
90-
91-
pivots = derived_table.args.get("pivots")
92-
if pivots and not pivots[0].alias:
93-
pivots[0].set("alias", exp.TableAlias(this=exp.to_identifier(next_alias_name())))
115+
_set_alias(derived_table, scope=scope)
116+
if pivot := seq_get(derived_table.args.get("pivots") or [], 0):
117+
_set_alias(pivot)
94118

95119
table_aliases = {}
96120

97121
for name, source in scope.sources.items():
98122
if isinstance(source, exp.Table):
99-
pivots = source.args.get("pivots")
100-
if not source.alias:
101-
# Don't add the pivot's alias to the pivoted table, use the table's name instead
102-
if pivots and pivots[0].alias == name:
103-
name = source.name
104-
105-
# Mutates the source by attaching an alias to it
106-
normalized_alias = normalize_identifiers(
107-
name or source.name or alias_sequence(), dialect=dialect
108-
)
109-
exp.alias_(source, normalized_alias, copy=False, table=True)
110-
111-
table_aliases[".".join(p.name for p in source.parts)] = exp.to_identifier(
112-
source.alias
113-
)
114-
115-
if pivots:
116-
pivot = pivots[0]
117-
if not pivot.alias:
118-
pivot_alias = normalize_identifiers(
119-
source.alias if pivot.unpivot else alias_sequence(),
120-
dialect=dialect,
121-
)
122-
pivot.set("alias", exp.TableAlias(this=exp.to_identifier(pivot_alias)))
123+
# When the name is empty, it means that we have a non-table source, e.g. a pivoted Cte
124+
is_real_table_source = bool(name)
125+
126+
if pivot := seq_get(source.args.get("pivots") or [], 0):
127+
name = source.name
128+
129+
_set_alias(source, target_alias=name or source.name or None, normalize=True)
130+
131+
source_fqn = ".".join(p.name for p in source.parts)
132+
table_aliases[source_fqn] = exp.to_identifier(source.alias)
133+
134+
if pivot:
135+
target_alias = source.alias if pivot.unpivot else None
136+
_set_alias(pivot, target_alias=target_alias, normalize=True)
123137

124138
# This case corresponds to a pivoted CTE, we don't want to qualify that
125139
if isinstance(scope.sources.get(source.alias_or_name), Scope):
126140
continue
127141

128-
_qualify(source)
142+
if is_real_table_source:
143+
_qualify(source)
129144

130-
if on_qualify:
131-
on_qualify(source)
145+
if on_qualify:
146+
on_qualify(source)
132147
elif isinstance(source, Scope) and source.is_udtf:
133-
udtf = source.expression
134-
table_alias = udtf.args.get("alias") or exp.TableAlias(
135-
this=exp.to_identifier(next_alias_name())
136-
)
137-
udtf.set("alias", table_alias)
138-
139-
if not table_alias.name:
140-
table_alias.set("this", exp.to_identifier(next_alias_name()))
148+
_set_alias(udtf := source.expression)
149+
150+
table_alias = udtf.args["alias"]
151+
141152
if isinstance(udtf, exp.Values) and not table_alias.columns:
142153
column_aliases = [
143154
normalize_identifiers(i, dialect=dialect)
144155
for i in dialect.generate_values_aliases(udtf)
145156
]
146157
table_alias.set("columns", column_aliases)
147-
else:
148-
for node in scope.walk():
149-
if (
150-
isinstance(node, exp.Table)
151-
and not node.alias
152-
and isinstance(node.parent, (exp.From, exp.Join))
153-
):
154-
# Mutates the table by attaching an alias to it
155-
exp.alias_(node, node.name, copy=False, table=True)
158+
159+
for table in scope.tables:
160+
if not table.alias and isinstance(table.parent, (exp.From, exp.Join)):
161+
_set_alias(table, target_alias=table.name)
156162

157163
for column in scope.columns:
164+
table = column.table
165+
158166
if column.db:
159167
table_alias = table_aliases.get(".".join(p.name for p in column.parts[0:-1]))
160168

@@ -163,5 +171,13 @@ def _qualify(table: exp.Table) -> None:
163171
column.set(p, None)
164172

165173
column.set("table", table_alias.copy())
174+
elif (
175+
canonical_aliases
176+
and table
177+
and (canonical_table := canonical_aliases.get(table, "")) != column.table
178+
):
179+
# Amend existing aliases, e.g. t.c -> _0.c if t is aliased to _0
180+
column.set("table", exp.to_identifier(canonical_table))
181+
pass
166182

167183
return expression

tests/test_optimizer.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -240,6 +240,77 @@ def test_isolate_table_selects(self):
240240
)
241241

242242
def test_qualify_tables(self):
243+
self.assertEqual(
244+
optimizer.qualify_tables.qualify_tables(
245+
parse_one("SELECT * FROM t"),
246+
db="db",
247+
catalog="c",
248+
canonicalize=True,
249+
).sql(),
250+
"SELECT * FROM c.db.t AS _0",
251+
)
252+
253+
self.assertEqual(
254+
optimizer.qualify_tables.qualify_tables(
255+
parse_one("SELECT * FROM t1 JOIN t2 ON t1.id = t2.id"),
256+
db="db",
257+
catalog="c",
258+
canonicalize=True,
259+
).sql(),
260+
"SELECT * FROM c.db.t1 AS _0 JOIN c.db.t2 AS _1 ON _0.id = _1.id",
261+
)
262+
263+
self.assertEqual(
264+
optimizer.qualify_tables.qualify_tables(
265+
parse_one("SELECT * FROM db1.users JOIN db2.users ON db1.users.id = db2.users.id"),
266+
catalog="c",
267+
canonicalize=True,
268+
).sql(),
269+
"SELECT * FROM c.db1.users AS _0 JOIN c.db2.users AS _1 ON _0.id = _1.id",
270+
)
271+
272+
self.assertEqual(
273+
optimizer.qualify_tables.qualify_tables(
274+
parse_one("WITH cte AS (SELECT * FROM t) SELECT * FROM cte"),
275+
db="db",
276+
catalog="c",
277+
canonicalize=True,
278+
).sql(),
279+
"WITH cte AS (SELECT * FROM c.db.t AS _0) SELECT * FROM cte AS _1",
280+
)
281+
282+
self.assertEqual(
283+
optimizer.qualify_tables.qualify_tables(
284+
parse_one("SELECT * FROM (SELECT * FROM t)"),
285+
db="db",
286+
catalog="c",
287+
canonicalize=True,
288+
).sql(),
289+
"SELECT * FROM (SELECT * FROM c.db.t AS _0) AS _1",
290+
)
291+
292+
self.assertEqual(
293+
optimizer.qualify_tables.qualify_tables(
294+
parse_one("SELECT * FROM t1, (SELECT * FROM t2) AS sub, t3"),
295+
db="db",
296+
catalog="c",
297+
canonicalize=True,
298+
).sql(),
299+
"SELECT * FROM c.db.t1 AS _2, (SELECT * FROM c.db.t2 AS _0) AS _1, c.db.t3 AS _3",
300+
)
301+
302+
self.assertEqual(
303+
optimizer.qualify_tables.qualify_tables(
304+
parse_one(
305+
"WITH cte AS (SELECT * FROM t) SELECT * FROM cte PIVOT(SUM(c) FOR v IN ('x', 'y'))"
306+
),
307+
db="db",
308+
catalog="c",
309+
canonicalize=True,
310+
).sql(),
311+
"WITH cte AS (SELECT * FROM c.db.t AS _0) SELECT * FROM cte AS _1 PIVOT(SUM(c) FOR v IN ('x', 'y')) AS _2",
312+
)
313+
243314
self.assertEqual(
244315
optimizer.qualify.qualify(
245316
parse_one("WITH tesT AS (SELECT * FROM t1) SELECT * FROM test", "bigquery"),

0 commit comments

Comments
 (0)