Skip to content

Commit 80466f1

Browse files
authored
fix: expand util - align normalization behaviour with lazy and non-lazy source providers. (#4874)
* Fix sources usage * Format
1 parent 0be0b82 commit 80466f1

File tree

2 files changed

+10
-31
lines changed

2 files changed

+10
-31
lines changed

sqlglot/expressions.py

Lines changed: 7 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
"""
1212

1313
from __future__ import annotations
14+
1415
import datetime
1516
import math
1617
import numbers
@@ -37,6 +38,7 @@
3738

3839
if t.TYPE_CHECKING:
3940
from typing_extensions import Self
41+
4042
from sqlglot._typing import E, Lit
4143
from sqlglot.dialects.dialect import DialectType
4244

@@ -8494,7 +8496,7 @@ def _replace_placeholders(node: Expression, args, **kwargs) -> Expression:
84948496

84958497
def expand(
84968498
expression: Expression,
8497-
sources: t.Dict[str, Query] | t.Callable[[str], t.Optional[Query]],
8499+
sources: t.Dict[str, Query | t.Callable[[], Query]],
84988500
dialect: DialectType = None,
84998501
copy: bool = True,
85008502
) -> Expression:
@@ -8517,27 +8519,17 @@ def expand(
85178519
Returns:
85188520
The transformed expression.
85198521
"""
8522+
normalized_sources = {normalize_table_name(k, dialect=dialect): v for k, v in sources.items()}
85208523
# Create a query provider based on the sources parameter
8521-
if callable(sources):
8522-
get_source = sources
8523-
else:
8524-
# Pre-normalize table names in sources dictionary for consistent lookups
8525-
normalized_sources = {
8526-
normalize_table_name(k, dialect=dialect): v for k, v in sources.items()
8527-
}
8528-
8529-
def _get_source(name: str) -> t.Optional[Query]:
8530-
return normalized_sources.get(name)
8531-
8532-
get_source = _get_source
85338524

85348525
def _expand(node: Expression):
85358526
if isinstance(node, Table):
85368527
name = normalize_table_name(node, dialect=dialect)
8537-
source = get_source(name)
8528+
source = normalized_sources.get(name)
85388529
if source:
85398530
# Create a subquery with the same alias (or table name if no alias)
8540-
subquery = source.subquery(node.alias or name)
8531+
parsed_source = source() if callable(source) else source
8532+
subquery = parsed_source.subquery(node.alias or name)
85418533
subquery.comments = [f"source: {name}"]
85428534
# Continue expanding within the subquery
85438535
return subquery.transform(_expand, copy=False)

tests/test_expressions.py

Lines changed: 3 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,9 @@
1-
import sys
21
import datetime
32
import math
4-
import typing as t
3+
import sys
54
import unittest
65

76
from sqlglot import ParseError, alias, exp, parse_one
8-
from sqlglot.expressions import Query, normalize_table_name
97

108

119
class TestExpressions(unittest.TestCase):
@@ -280,22 +278,10 @@ def test_expand(self):
280278
)
281279

282280
def test_expand_with_lazy_source_provider(self):
283-
class DynamicSourceProvider:
284-
def __init__(self, dialect: str = "spark"):
285-
self._sources = {normalize_table_name("`a-b`.`c`", dialect): "select 1"}
286-
287-
def get(self, name: str) -> t.Optional[Query]:
288-
query_sql = self._sources.get(name)
289-
290-
if query_sql:
291-
return parse_one(query_sql)
292-
293-
dynamic_source_provider = DynamicSourceProvider()
294-
295281
self.assertEqual(
296282
exp.expand(
297283
parse_one('select * from "a-b"."C" AS a'),
298-
lambda name: dynamic_source_provider.get(name),
284+
{"`a-b`.c": lambda: parse_one("select 1", dialect="spark")},
299285
dialect="spark",
300286
).sql(),
301287
"SELECT * FROM (SELECT 1) AS a /* source: a-b.c */",
@@ -862,6 +848,7 @@ def test_properties_from_dict(self):
862848

863849
def test_convert(self):
864850
from collections import namedtuple
851+
865852
import pytz
866853

867854
PointTuple = namedtuple("Point", ["x", "y"])

0 commit comments

Comments
 (0)