44
55from sqlglot import exp
66from sqlglot .dialects .dialect import Dialect , DialectType
7- from sqlglot .helper import name_sequence
7+ from sqlglot .helper import name_sequence , seq_get
88from sqlglot .optimizer .normalize_identifiers import normalize_identifiers
99from sqlglot .optimizer .scope import Scope , traverse_scope
1010
@@ -18,6 +18,7 @@ def qualify_tables(
1818 catalog : t .Optional [str | exp .Identifier ] = None ,
1919 on_qualify : t .Optional [t .Callable [[exp .Expression ], None ]] = None ,
2020 dialect : DialectType = None ,
21+ canonicalize : bool = False ,
2122) -> E :
2223 """
2324 Rewrite sqlglot AST to have fully qualified tables. Join constructs such as
@@ -39,13 +40,15 @@ def qualify_tables(
3940 catalog: Catalog name
4041 on_qualify: Callback after a table has been qualified.
4142 dialect: The dialect to parse catalog and schema into.
43+ canonicalize: Whether to use canonical aliases (_0, _1, ...) for all sources
44+ instead of preserving table names. Defaults to False.
4245
4346 Returns:
4447 The qualified expression.
4548 """
4649 dialect = Dialect .get_or_raise (dialect )
4750
48- alias_sequence = name_sequence ("_q_" )
51+ alias_sequence = name_sequence ("_" if canonicalize else "_q_" )
4952
5053 def next_alias_name () -> str :
5154 return normalize_identifiers (alias_sequence (), dialect = dialect ).name
@@ -74,6 +77,32 @@ def _qualify(table: exp.Table) -> None:
7477 if isinstance (node , exp .Table ) and node .name not in cte_names :
7578 _qualify (node )
7679
80+ canonical_aliases : t .Dict [str , str ] = {}
81+
def _set_alias(
    expression: exp.Expression,
    target_alias: t.Optional[str] = None,
    scope: t.Optional[Scope] = None,
    normalize: bool = False,
) -> None:
    """
    Attach an alias to `expression`, generating one when necessary.

    In canonical mode a fresh sequence alias is always assigned, and the mapping
    from the previous name (the existing alias, else `target_alias`, else "") to
    the new alias is recorded in `canonical_aliases`. Otherwise an alias is only
    assigned when the expression does not already have a named one.

    Args:
        expression: The expression whose "alias" arg is set.
        target_alias: Preferred alias name; a generated name is used when it is empty.
        scope: If given, its unnamed source is renamed to the new alias.
        normalize: Whether to normalize the chosen name (non-canonical mode only).
    """
    table_alias = expression.args.get("alias") or exp.TableAlias()
    current_name = table_alias.name

    # Outside canonical mode, an expression that is already aliased is left untouched
    if not canonicalize and current_name:
        return

    if canonicalize:
        chosen_name = next_alias_name()
        # Remember the old name so that column references can be amended later
        canonical_aliases[current_name or target_alias or ""] = chosen_name
    else:
        chosen_name = target_alias or next_alias_name()
        if normalize:
            chosen_name = normalize_identifiers(chosen_name, dialect=dialect).name

    table_alias.set("this", exp.to_identifier(chosen_name))
    expression.set("alias", table_alias)

    if scope:
        scope.rename_source(None, chosen_name)
105+
77106 for scope in traverse_scope (expression ):
78107 for derived_table in scope .derived_tables :
79108 unnested = derived_table .unnest ()
@@ -83,78 +112,57 @@ def _qualify(table: exp.Table) -> None:
83112 derived_table .this .replace (exp .select ("*" ).from_ (unnested .copy (), copy = False ))
84113 derived_table .this .set ("joins" , joins )
85114
86- if not derived_table .args .get ("alias" ):
87- alias = next_alias_name ()
88- derived_table .set ("alias" , exp .TableAlias (this = exp .to_identifier (alias )))
89- scope .rename_source (None , alias )
90-
91- pivots = derived_table .args .get ("pivots" )
92- if pivots and not pivots [0 ].alias :
93- pivots [0 ].set ("alias" , exp .TableAlias (this = exp .to_identifier (next_alias_name ())))
115+ _set_alias (derived_table , scope = scope )
116+ if pivot := seq_get (derived_table .args .get ("pivots" ) or [], 0 ):
117+ _set_alias (pivot )
94118
95119 table_aliases = {}
96120
97121 for name , source in scope .sources .items ():
98122 if isinstance (source , exp .Table ):
99- pivots = source .args .get ("pivots" )
100- if not source .alias :
101- # Don't add the pivot's alias to the pivoted table, use the table's name instead
102- if pivots and pivots [0 ].alias == name :
103- name = source .name
104-
105- # Mutates the source by attaching an alias to it
106- normalized_alias = normalize_identifiers (
107- name or source .name or alias_sequence (), dialect = dialect
108- )
109- exp .alias_ (source , normalized_alias , copy = False , table = True )
110-
111- table_aliases ["." .join (p .name for p in source .parts )] = exp .to_identifier (
112- source .alias
113- )
114-
115- if pivots :
116- pivot = pivots [0 ]
117- if not pivot .alias :
118- pivot_alias = normalize_identifiers (
119- source .alias if pivot .unpivot else alias_sequence (),
120- dialect = dialect ,
121- )
122- pivot .set ("alias" , exp .TableAlias (this = exp .to_identifier (pivot_alias )))
123+ # When the name is empty, it means that we have a non-table source, e.g. a pivoted CTE
124+ is_real_table_source = bool (name )
125+
126+ if pivot := seq_get (source .args .get ("pivots" ) or [], 0 ):
127+ name = source .name
128+
129+ _set_alias (source , target_alias = name or source .name or None , normalize = True )
130+
131+ source_fqn = "." .join (p .name for p in source .parts )
132+ table_aliases [source_fqn ] = exp .to_identifier (source .alias )
133+
134+ if pivot :
135+ target_alias = source .alias if pivot .unpivot else None
136+ _set_alias (pivot , target_alias = target_alias , normalize = True )
123137
124138 # This case corresponds to a pivoted CTE, we don't want to qualify that
125139 if isinstance (scope .sources .get (source .alias_or_name ), Scope ):
126140 continue
127141
128- _qualify (source )
142+ if is_real_table_source :
143+ _qualify (source )
129144
130- if on_qualify :
131- on_qualify (source )
145+ if on_qualify :
146+ on_qualify (source )
132147 elif isinstance (source , Scope ) and source .is_udtf :
133- udtf = source .expression
134- table_alias = udtf .args .get ("alias" ) or exp .TableAlias (
135- this = exp .to_identifier (next_alias_name ())
136- )
137- udtf .set ("alias" , table_alias )
138-
139- if not table_alias .name :
140- table_alias .set ("this" , exp .to_identifier (next_alias_name ()))
148+ _set_alias (udtf := source .expression )
149+
150+ table_alias = udtf .args ["alias" ]
151+
141152 if isinstance (udtf , exp .Values ) and not table_alias .columns :
142153 column_aliases = [
143154 normalize_identifiers (i , dialect = dialect )
144155 for i in dialect .generate_values_aliases (udtf )
145156 ]
146157 table_alias .set ("columns" , column_aliases )
147- else :
148- for node in scope .walk ():
149- if (
150- isinstance (node , exp .Table )
151- and not node .alias
152- and isinstance (node .parent , (exp .From , exp .Join ))
153- ):
154- # Mutates the table by attaching an alias to it
155- exp .alias_ (node , node .name , copy = False , table = True )
158+
159+ for table in scope .tables :
160+ if not table .alias and isinstance (table .parent , (exp .From , exp .Join )):
161+ _set_alias (table , target_alias = table .name )
156162
157163 for column in scope .columns :
164+ table = column .table
165+
158166 if column .db :
159167 table_alias = table_aliases .get ("." .join (p .name for p in column .parts [0 :- 1 ]))
160168
@@ -163,5 +171,13 @@ def _qualify(table: exp.Table) -> None:
163171 column .set (p , None )
164172
165173 column .set ("table" , table_alias .copy ())
174+ elif (
175+ canonical_aliases
176+ and table
177+ and (canonical_table := canonical_aliases .get (table , "" )) != column .table
178+ ):
179+ # Amend existing aliases, e.g. t.c -> _0.c if t is aliased to _0
180+ column .set ("table" , exp .to_identifier (canonical_table ))
181+ pass
166182
167183 return expression
0 commit comments