Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions sqlglot/optimizer/qualify.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ def qualify(
validate_qualify_columns: bool = True,
quote_identifiers: bool = True,
identify: bool = True,
canonicalize_table_aliases: bool = False,
on_qualify: t.Optional[t.Callable[[exp.Expression], None]] = None,
) -> exp.Expression:
"""
Expand Down Expand Up @@ -62,6 +63,8 @@ def qualify(
This step is necessary to ensure correctness for case sensitive queries.
But this flag is provided in case this step is performed at a later time.
identify: If True, quote all identifiers, else only necessary ones.
canonicalize_table_aliases: Whether to use canonical aliases (_0, _1, ...) for all sources
instead of preserving table names.
on_qualify: Callback after a table has been qualified.

Returns:
Expand All @@ -81,6 +84,7 @@ def qualify(
catalog=catalog,
dialect=dialect,
on_qualify=on_qualify,
canonicalize_table_aliases=canonicalize_table_aliases,
)

if isolate_tables:
Expand Down
122 changes: 74 additions & 48 deletions sqlglot/optimizer/qualify_tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from sqlglot import exp
from sqlglot.dialects.dialect import Dialect, DialectType
from sqlglot.helper import name_sequence
from sqlglot.helper import name_sequence, seq_get
from sqlglot.optimizer.normalize_identifiers import normalize_identifiers
from sqlglot.optimizer.scope import Scope, traverse_scope

Expand All @@ -18,6 +18,7 @@ def qualify_tables(
catalog: t.Optional[str | exp.Identifier] = None,
on_qualify: t.Optional[t.Callable[[exp.Expression], None]] = None,
dialect: DialectType = None,
canonicalize_table_aliases: bool = False,
) -> E:
"""
Rewrite sqlglot AST to have fully qualified tables. Join constructs such as
Expand All @@ -39,16 +40,14 @@ def qualify_tables(
catalog: Catalog name
on_qualify: Callback after a table has been qualified.
dialect: The dialect to parse catalog and schema into.
canonicalize_table_aliases: Whether to use canonical aliases (_0, _1, ...) for all sources
instead of preserving table names. Defaults to False.

Returns:
The qualified expression.
"""
dialect = Dialect.get_or_raise(dialect)

alias_sequence = name_sequence("_q_")

def next_alias_name() -> str:
return normalize_identifiers(alias_sequence(), dialect=dialect).name
next_alias_name = name_sequence("_")

if db := db or None:
db = exp.parse_identifier(db, dialect=dialect)
Expand All @@ -74,7 +73,38 @@ def _qualify(table: exp.Table) -> None:
if isinstance(node, exp.Table) and node.name not in cte_names:
_qualify(node)

def _set_alias(
expression: exp.Expression,
canonical_aliases: t.Dict[str, str],
target_alias: t.Optional[str] = None,
scope: t.Optional[Scope] = None,
normalize: bool = False,
) -> None:
alias = expression.args.get("alias") or exp.TableAlias()

if canonicalize_table_aliases:
new_alias_name = next_alias_name()
canonical_aliases[alias.name or target_alias or ""] = new_alias_name
elif not alias.name:
new_alias_name = target_alias or next_alias_name()
if normalize and target_alias:
new_alias_name = normalize_identifiers(new_alias_name, dialect=dialect).name
else:
return

# Auto-generated aliases (_1, _2, ...) are quoted in order to be valid across all dialects
quoted = True if canonicalize_table_aliases or not target_alias else None

alias.set("this", exp.to_identifier(new_alias_name, quoted=quoted))
expression.set("alias", alias)

if scope:
scope.rename_source(None, new_alias_name)

for scope in traverse_scope(expression):
local_columns = scope.local_columns
canonical_aliases: t.Dict[str, str] = {}
Comment on lines +105 to +106
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Auto-generated alias mappings are constructed per-scope to respect lexical scoping.


for query in scope.subqueries:
subquery = query.parent
if isinstance(subquery, exp.Subquery):
Expand All @@ -88,61 +118,48 @@ def _qualify(table: exp.Table) -> None:
derived_table.this.replace(exp.select("*").from_(unnested.copy(), copy=False))
derived_table.this.set("joins", joins)

if not derived_table.args.get("alias"):
alias = next_alias_name()
derived_table.set("alias", exp.TableAlias(this=exp.to_identifier(alias)))
scope.rename_source(None, alias)

pivots = derived_table.args.get("pivots")
if pivots and not pivots[0].alias:
pivots[0].set("alias", exp.TableAlias(this=exp.to_identifier(next_alias_name())))
_set_alias(derived_table, canonical_aliases, scope=scope)
if pivot := seq_get(derived_table.args.get("pivots") or [], 0):
_set_alias(pivot, canonical_aliases)

table_aliases = {}

for name, source in scope.sources.items():
if isinstance(source, exp.Table):
pivots = source.args.get("pivots")
if not source.alias:
# Don't add the pivot's alias to the pivoted table, use the table's name instead
if pivots and pivots[0].alias == name:
name = source.name

# Mutates the source by attaching an alias to it
normalized_alias = normalize_identifiers(
name or source.name or alias_sequence(), dialect=dialect
)
exp.alias_(source, normalized_alias, copy=False, table=True)

table_aliases[".".join(p.name for p in source.parts)] = exp.to_identifier(
source.alias
# When the name is empty, it means that we have a non-table source, e.g. a pivoted cte
is_real_table_source = bool(name)

if pivot := seq_get(source.args.get("pivots") or [], 0):
name = source.name

_set_alias(
source,
canonical_aliases,
target_alias=name or source.name or None,
normalize=True,
)

if pivots:
pivot = pivots[0]
if not pivot.alias:
pivot_alias = normalize_identifiers(
source.alias if pivot.unpivot else alias_sequence(),
dialect=dialect,
)
pivot.set("alias", exp.TableAlias(this=exp.to_identifier(pivot_alias)))
source_fqn = ".".join(p.name for p in source.parts)
table_aliases[source_fqn] = source.args["alias"].this.copy()

if pivot:
target_alias = source.alias if pivot.unpivot else None
_set_alias(pivot, canonical_aliases, target_alias=target_alias, normalize=True)

# This case corresponds to a pivoted CTE, we don't want to qualify that
if isinstance(scope.sources.get(source.alias_or_name), Scope):
continue

_qualify(source)
if is_real_table_source:
_qualify(source)

if on_qualify:
on_qualify(source)
if on_qualify:
on_qualify(source)
elif isinstance(source, Scope) and source.is_udtf:
udtf = source.expression
table_alias = udtf.args.get("alias") or exp.TableAlias(
this=exp.to_identifier(next_alias_name())
)
udtf.set("alias", table_alias)
_set_alias(udtf := source.expression, canonical_aliases)

table_alias = udtf.args["alias"]

if not table_alias.name:
table_alias.set("this", exp.to_identifier(next_alias_name()))
if isinstance(udtf, exp.Values) and not table_alias.columns:
column_aliases = [
normalize_identifiers(i, dialect=dialect)
Expand All @@ -152,9 +169,11 @@ def _qualify(table: exp.Table) -> None:

for table in scope.tables:
if not table.alias and isinstance(table.parent, (exp.From, exp.Join)):
exp.alias_(table, table.name, copy=False, table=True)
_set_alias(table, canonical_aliases, target_alias=table.name)

for column in local_columns:
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note this change vs how we previously iterated over scope.columns. If we visited external columns in this loop, we would mutate their source too early, without having had the chance to alias outer scopes' sources.

Before making this change, the test case I added in qualify_tables.sql failed due to having columns qualified with "".

table = column.table

for column in scope.columns:
if column.db:
table_alias = table_aliases.get(".".join(p.name for p in column.parts[0:-1]))

Expand All @@ -163,5 +182,12 @@ def _qualify(table: exp.Table) -> None:
column.set(p, None)

column.set("table", table_alias.copy())
elif (
canonical_aliases
and table
and (canonical_table := canonical_aliases.get(table, "")) != column.table
):
# Amend existing aliases, e.g. t.c -> _0.c if t is aliased to _0
column.set("table", exp.to_identifier(canonical_table, quoted=True))

return expression
21 changes: 17 additions & 4 deletions sqlglot/optimizer/scope.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ def clear_cache(self):
self._selected_sources = None
self._columns = None
self._external_columns = None
self._local_columns = None
self._join_hints = None
self._pivots = None
self._references = None
Expand Down Expand Up @@ -372,8 +373,7 @@ def external_columns(self):
Columns that appear to reference sources in outer scopes.

Returns:
list[exp.Column]: Column instances that don't reference
sources in the current scope.
list[exp.Column]: Column instances that don't reference sources in the current scope.
"""
if self._external_columns is None:
if isinstance(self.expression, exp.SetOperation):
Expand All @@ -383,12 +383,25 @@ def external_columns(self):
self._external_columns = [
c
for c in self.columns
if c.table not in self.selected_sources
and c.table not in self.semi_or_anti_join_tables
if c.table not in self.sources and c.table not in self.semi_or_anti_join_tables
]

return self._external_columns

@property
def local_columns(self):
"""
Columns in this scope that are not external.

Returns:
list[exp.Column]: Column instances that reference sources in the current scope.
"""
if self._local_columns is None:
external_columns = set(self.external_columns)
self._local_columns = [c for c in self.columns if c not in external_columns]

return self._local_columns

@property
def unqualified_columns(self):
"""
Expand Down
Loading