-
Notifications
You must be signed in to change notification settings - Fork 1k
Feat(optimizer)!: canonicalize table aliases #6369
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
dbbc375
7bb18a4
fe061e5
42a4b69
d1ce89b
ebb4aa2
7fc2c6d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -4,7 +4,7 @@ | |
|
|
||
| from sqlglot import exp | ||
| from sqlglot.dialects.dialect import Dialect, DialectType | ||
| from sqlglot.helper import name_sequence | ||
| from sqlglot.helper import name_sequence, seq_get | ||
| from sqlglot.optimizer.normalize_identifiers import normalize_identifiers | ||
| from sqlglot.optimizer.scope import Scope, traverse_scope | ||
|
|
||
|
|
@@ -18,6 +18,7 @@ def qualify_tables( | |
| catalog: t.Optional[str | exp.Identifier] = None, | ||
| on_qualify: t.Optional[t.Callable[[exp.Expression], None]] = None, | ||
| dialect: DialectType = None, | ||
| canonicalize_table_aliases: bool = False, | ||
| ) -> E: | ||
| """ | ||
| Rewrite sqlglot AST to have fully qualified tables. Join constructs such as | ||
|
|
@@ -39,16 +40,14 @@ def qualify_tables( | |
| catalog: Catalog name | ||
| on_qualify: Callback after a table has been qualified. | ||
| dialect: The dialect to parse catalog and schema into. | ||
| canonicalize_table_aliases: Whether to use canonical aliases (_0, _1, ...) for all sources | ||
| instead of preserving table names. Defaults to False. | ||
|
|
||
| Returns: | ||
| The qualified expression. | ||
| """ | ||
| dialect = Dialect.get_or_raise(dialect) | ||
|
|
||
| alias_sequence = name_sequence("_q_") | ||
|
|
||
| def next_alias_name() -> str: | ||
| return normalize_identifiers(alias_sequence(), dialect=dialect).name | ||
| next_alias_name = name_sequence("_") | ||
|
|
||
| if db := db or None: | ||
| db = exp.parse_identifier(db, dialect=dialect) | ||
|
|
@@ -74,7 +73,38 @@ def _qualify(table: exp.Table) -> None: | |
| if isinstance(node, exp.Table) and node.name not in cte_names: | ||
| _qualify(node) | ||
|
|
||
| def _set_alias( | ||
| expression: exp.Expression, | ||
| canonical_aliases: t.Dict[str, str], | ||
| target_alias: t.Optional[str] = None, | ||
| scope: t.Optional[Scope] = None, | ||
| normalize: bool = False, | ||
| ) -> None: | ||
| alias = expression.args.get("alias") or exp.TableAlias() | ||
|
|
||
| if canonicalize_table_aliases: | ||
| new_alias_name = next_alias_name() | ||
| canonical_aliases[alias.name or target_alias or ""] = new_alias_name | ||
| elif not alias.name: | ||
| new_alias_name = target_alias or next_alias_name() | ||
| if normalize and target_alias: | ||
| new_alias_name = normalize_identifiers(new_alias_name, dialect=dialect).name | ||
tobymao marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| else: | ||
| return | ||
|
|
||
| # Auto-generated aliases (_0, _1, ...) are quoted in order to be valid across all dialects | ||
| quoted = True if canonicalize_table_aliases or not target_alias else None | ||
|
|
||
| alias.set("this", exp.to_identifier(new_alias_name, quoted=quoted)) | ||
| expression.set("alias", alias) | ||
|
|
||
| if scope: | ||
| scope.rename_source(None, new_alias_name) | ||
|
|
||
| for scope in traverse_scope(expression): | ||
| local_columns = scope.local_columns | ||
| canonical_aliases: t.Dict[str, str] = {} | ||
|
Comment on lines
+105
to
+106
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Auto-generated alias mappings are constructed per scope to respect lexical scoping. |
||
|
|
||
| for query in scope.subqueries: | ||
| subquery = query.parent | ||
| if isinstance(subquery, exp.Subquery): | ||
|
|
@@ -88,61 +118,48 @@ def _qualify(table: exp.Table) -> None: | |
| derived_table.this.replace(exp.select("*").from_(unnested.copy(), copy=False)) | ||
| derived_table.this.set("joins", joins) | ||
|
|
||
| if not derived_table.args.get("alias"): | ||
| alias = next_alias_name() | ||
| derived_table.set("alias", exp.TableAlias(this=exp.to_identifier(alias))) | ||
| scope.rename_source(None, alias) | ||
|
|
||
| pivots = derived_table.args.get("pivots") | ||
| if pivots and not pivots[0].alias: | ||
| pivots[0].set("alias", exp.TableAlias(this=exp.to_identifier(next_alias_name()))) | ||
| _set_alias(derived_table, canonical_aliases, scope=scope) | ||
| if pivot := seq_get(derived_table.args.get("pivots") or [], 0): | ||
| _set_alias(pivot, canonical_aliases) | ||
|
|
||
| table_aliases = {} | ||
|
|
||
| for name, source in scope.sources.items(): | ||
| if isinstance(source, exp.Table): | ||
| pivots = source.args.get("pivots") | ||
| if not source.alias: | ||
| # Don't add the pivot's alias to the pivoted table, use the table's name instead | ||
| if pivots and pivots[0].alias == name: | ||
| name = source.name | ||
|
|
||
| # Mutates the source by attaching an alias to it | ||
| normalized_alias = normalize_identifiers( | ||
| name or source.name or alias_sequence(), dialect=dialect | ||
| ) | ||
| exp.alias_(source, normalized_alias, copy=False, table=True) | ||
|
|
||
| table_aliases[".".join(p.name for p in source.parts)] = exp.to_identifier( | ||
| source.alias | ||
| # When the name is empty, it means that we have a non-table source, e.g. a pivoted cte | ||
| is_real_table_source = bool(name) | ||
|
|
||
| if pivot := seq_get(source.args.get("pivots") or [], 0): | ||
| name = source.name | ||
|
|
||
| _set_alias( | ||
| source, | ||
| canonical_aliases, | ||
| target_alias=name or source.name or None, | ||
| normalize=True, | ||
| ) | ||
|
|
||
| if pivots: | ||
| pivot = pivots[0] | ||
| if not pivot.alias: | ||
| pivot_alias = normalize_identifiers( | ||
| source.alias if pivot.unpivot else alias_sequence(), | ||
| dialect=dialect, | ||
| ) | ||
| pivot.set("alias", exp.TableAlias(this=exp.to_identifier(pivot_alias))) | ||
| source_fqn = ".".join(p.name for p in source.parts) | ||
| table_aliases[source_fqn] = source.args["alias"].this.copy() | ||
|
|
||
| if pivot: | ||
| target_alias = source.alias if pivot.unpivot else None | ||
| _set_alias(pivot, canonical_aliases, target_alias=target_alias, normalize=True) | ||
|
|
||
| # This case corresponds to a pivoted CTE, we don't want to qualify that | ||
| if isinstance(scope.sources.get(source.alias_or_name), Scope): | ||
| continue | ||
|
|
||
| _qualify(source) | ||
| if is_real_table_source: | ||
| _qualify(source) | ||
|
|
||
| if on_qualify: | ||
| on_qualify(source) | ||
| if on_qualify: | ||
| on_qualify(source) | ||
| elif isinstance(source, Scope) and source.is_udtf: | ||
| udtf = source.expression | ||
| table_alias = udtf.args.get("alias") or exp.TableAlias( | ||
| this=exp.to_identifier(next_alias_name()) | ||
| ) | ||
| udtf.set("alias", table_alias) | ||
| _set_alias(udtf := source.expression, canonical_aliases) | ||
|
|
||
| table_alias = udtf.args["alias"] | ||
|
|
||
| if not table_alias.name: | ||
| table_alias.set("this", exp.to_identifier(next_alias_name())) | ||
| if isinstance(udtf, exp.Values) and not table_alias.columns: | ||
| column_aliases = [ | ||
| normalize_identifiers(i, dialect=dialect) | ||
|
|
@@ -152,9 +169,11 @@ def _qualify(table: exp.Table) -> None: | |
|
|
||
| for table in scope.tables: | ||
| if not table.alias and isinstance(table.parent, (exp.From, exp.Join)): | ||
| exp.alias_(table, table.name, copy=False, table=True) | ||
| _set_alias(table, canonical_aliases, target_alias=table.name) | ||
|
|
||
| for column in local_columns: | ||
|
Collaborator
Author
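There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Note this change vs. how we previously iterated over `scope.columns` (the loop now iterates over `local_columns`). Before making this change, the test case I added in [the test reference and the remainder of this sentence were lost in extraction] |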
||
| table = column.table | ||
|
|
||
| for column in scope.columns: | ||
| if column.db: | ||
| table_alias = table_aliases.get(".".join(p.name for p in column.parts[0:-1])) | ||
|
|
||
|
|
@@ -163,5 +182,12 @@ def _qualify(table: exp.Table) -> None: | |
| column.set(p, None) | ||
|
|
||
| column.set("table", table_alias.copy()) | ||
| elif ( | ||
| canonical_aliases | ||
| and table | ||
| and (canonical_table := canonical_aliases.get(table, "")) != column.table | ||
| ): | ||
| # Amend existing aliases, e.g. t.c -> _0.c if t is aliased to _0 | ||
| column.set("table", exp.to_identifier(canonical_table, quoted=True)) | ||
|
|
||
| return expression | ||
Uh oh!
There was an error while loading. Please reload this page.