Skip to content

Commit 886f85b

Browse files
authored
Fix(optimizer)!: pass dialect to ensure_schema (#5100)
1 parent 222dafd commit 886f85b

File tree

5 files changed

+58
-14
lines changed

5 files changed

+58
-14
lines changed

sqlglot/optimizer/annotate_types.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ def annotate_types(
3232
schema: t.Optional[t.Dict | Schema] = None,
3333
annotators: t.Optional[AnnotatorsType] = None,
3434
coerces_to: t.Optional[t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]]] = None,
35-
dialect: t.Optional[DialectType] = None,
35+
dialect: DialectType = None,
3636
) -> E:
3737
"""
3838
Infers the types of an expression, annotating its AST accordingly.
@@ -55,9 +55,9 @@ def annotate_types(
5555
The expression annotated with types.
5656
"""
5757

58-
schema = ensure_schema(schema)
58+
schema = ensure_schema(schema, dialect=dialect)
5959

60-
return TypeAnnotator(schema, annotators, coerces_to, dialect=dialect).annotate(expression)
60+
return TypeAnnotator(schema, annotators, coerces_to).annotate(expression)
6161

6262

6363
def _coerce_date_literal(l: exp.Expression, unit: t.Optional[exp.Expression]) -> exp.DataType.Type:
@@ -182,10 +182,9 @@ def __init__(
182182
annotators: t.Optional[AnnotatorsType] = None,
183183
coerces_to: t.Optional[t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]]] = None,
184184
binary_coercions: t.Optional[BinaryCoercions] = None,
185-
dialect: t.Optional[DialectType] = None,
186185
) -> None:
187186
self.schema = schema
188-
self.annotators = annotators or Dialect.get_or_raise(dialect).ANNOTATORS
187+
self.annotators = annotators or Dialect.get_or_raise(schema.dialect).ANNOTATORS
189188
self.coerces_to = coerces_to or self.COERCES_TO
190189
self.binary_coercions = binary_coercions or self.BINARY_COERCIONS
191190

sqlglot/optimizer/isolate_table_selects.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,32 @@
1+
from __future__ import annotations
2+
3+
import typing as t
4+
15
from sqlglot import alias, exp
26
from sqlglot.errors import OptimizeError
37
from sqlglot.optimizer.scope import traverse_scope
48
from sqlglot.schema import ensure_schema
59

10+
if t.TYPE_CHECKING:
11+
from sqlglot._typing import E
12+
from sqlglot.schema import Schema
13+
from sqlglot.dialects.dialect import DialectType
614

7-
def isolate_table_selects(expression, schema=None):
8-
schema = ensure_schema(schema)
15+
16+
def isolate_table_selects(
17+
expression: E,
18+
schema: t.Optional[t.Dict | Schema] = None,
19+
dialect: DialectType = None,
20+
) -> E:
21+
schema = ensure_schema(schema, dialect=dialect)
922

1023
for scope in traverse_scope(expression):
1124
if len(scope.selected_sources) == 1:
1225
continue
1326

1427
for _, source in scope.selected_sources.values():
28+
assert source.parent
29+
1530
if (
1631
not isinstance(source, exp.Table)
1732
or not schema.column_names(source)

sqlglot/optimizer/pushdown_projections.py

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
from __future__ import annotations
2+
3+
import typing as t
14
from collections import defaultdict
25

36
from sqlglot import alias, exp
@@ -7,6 +10,11 @@
710
from sqlglot.errors import OptimizeError
811
from sqlglot.helper import seq_get
912

13+
if t.TYPE_CHECKING:
14+
from sqlglot._typing import E
15+
from sqlglot.schema import Schema
16+
from sqlglot.dialects.dialect import DialectType
17+
1018
# Sentinel value that means an outer query selecting ALL columns
1119
SELECT_ALL = object()
1220

@@ -16,7 +24,12 @@ def default_selection(is_agg: bool) -> exp.Alias:
1624
return alias(exp.Max(this=exp.Literal.number(1)) if is_agg else "1", "_")
1725

1826

19-
def pushdown_projections(expression, schema=None, remove_unused_selections=True):
27+
def pushdown_projections(
28+
expression: E,
29+
schema: t.Optional[t.Dict | Schema] = None,
30+
remove_unused_selections: bool = True,
31+
dialect: DialectType = None,
32+
) -> E:
2033
"""
2134
Rewrite sqlglot AST to remove unused columns projections.
2235
@@ -34,9 +47,9 @@ def pushdown_projections(expression, schema=None, remove_unused_selections=True)
3447
sqlglot.Expression: optimized expression
3548
"""
3649
# Map of Scope to all columns being selected by outer queries.
37-
schema = ensure_schema(schema)
38-
source_column_alias_count = {}
39-
referenced_columns = defaultdict(set)
50+
schema = ensure_schema(schema, dialect=dialect)
51+
source_column_alias_count: t.Dict[exp.Expression | Scope, int] = {}
52+
referenced_columns: t.DefaultDict[Scope, t.Set[str | object]] = defaultdict(set)
4053

4154
# We build the scope tree (which is traversed in DFS postorder), then iterate
4255
# over the result in reverse order. This should ensure that the set of selected
@@ -69,12 +82,12 @@ def pushdown_projections(expression, schema=None, remove_unused_selections=True)
6982
if scope.expression.args.get("by_name"):
7083
referenced_columns[right] = referenced_columns[left]
7184
else:
72-
referenced_columns[right] = [
85+
referenced_columns[right] = {
7386
right.expression.selects[i].alias_or_name
7487
for i, select in enumerate(left.expression.selects)
7588
if SELECT_ALL in parent_selections
7689
or select.alias_or_name in parent_selections
77-
]
90+
}
7891

7992
if isinstance(scope.expression, exp.Select):
8093
if remove_unused_selections:

sqlglot/optimizer/qualify_columns.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ def qualify_columns(
2323
expand_stars: bool = True,
2424
infer_schema: t.Optional[bool] = None,
2525
allow_partial_qualification: bool = False,
26+
dialect: DialectType = None,
2627
) -> exp.Expression:
2728
"""
2829
Rewrite sqlglot AST to have fully qualified columns.
@@ -50,7 +51,7 @@ def qualify_columns(
5051
Notes:
5152
- Currently only handles a single PIVOT or UNPIVOT operator
5253
"""
53-
schema = ensure_schema(schema)
54+
schema = ensure_schema(schema, dialect=dialect)
5455
annotator = TypeAnnotator(schema)
5556
infer_schema = schema.empty if infer_schema is None else infer_schema
5657
dialect = Dialect.get_or_raise(schema.dialect)

tests/test_optimizer.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1552,3 +1552,19 @@ def gen_expr(depth: int) -> exp.Expression:
15521552
self.assertEqual(4, normalization_distance(gen_expr(2), max_=100))
15531553
self.assertEqual(18, normalization_distance(gen_expr(3), max_=100))
15541554
self.assertEqual(110, normalization_distance(gen_expr(10), max_=100))
1555+
1556+
def test_manually_annotate_snowflake(self):
1557+
dialect = "snowflake"
1558+
schema = {
1559+
"SCHEMA": {
1560+
"TBL": {"COL": "INT", "col2": "VARCHAR"},
1561+
}
1562+
}
1563+
example_query = 'SELECT * FROM "SCHEMA"."TBL"'
1564+
1565+
expression = parse_one(example_query, dialect=dialect)
1566+
qual = optimizer.qualify.qualify(expression, schema=schema, dialect=dialect)
1567+
annotated = optimizer.annotate_types.annotate_types(qual, schema=schema, dialect=dialect)
1568+
1569+
self.assertTrue(annotated.selects[0].is_type("INT"))
1570+
self.assertTrue(annotated.selects[1].is_type("VARCHAR"))

0 commit comments

Comments
 (0)