Skip to content

Commit 0be0b82

Browse files
authored
Add a query provider to the expand function, enabling lazy loading of queries during AST transformations. (#4872)
* add a query provider to the expand function to load queries lazily * change laziness implementation to use callables as source providers to simplify the usage * fix linting issues
1 parent d748e53 commit 0be0b82

File tree

2 files changed

+43
-5
lines changed

2 files changed

+43
-5
lines changed

sqlglot/expressions.py

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8494,7 +8494,7 @@ def _replace_placeholders(node: Expression, args, **kwargs) -> Expression:
84948494

84958495
def expand(
84968496
expression: Expression,
8497-
sources: t.Dict[str, Query],
8497+
sources: t.Dict[str, Query] | t.Callable[[str], t.Optional[Query]],
84988498
dialect: DialectType = None,
84998499
copy: bool = True,
85008500
) -> Expression:
@@ -8510,22 +8510,36 @@ def expand(
85108510
85118511
Args:
85128512
expression: The expression to expand.
8513-
sources: A dictionary of name to Queries.
8514-
dialect: The dialect of the sources dict.
8513+
sources: A dict of name to query or a callable that provides a query on demand.
8514+
dialect: The dialect of the sources dict or the callable.
85158515
copy: Whether to copy the expression during transformation. Defaults to True.
85168516
85178517
Returns:
85188518
The transformed expression.
85198519
"""
8520-
sources = {normalize_table_name(k, dialect=dialect): v for k, v in sources.items()}
8520+
# Create a query provider based on the sources parameter
8521+
if callable(sources):
8522+
get_source = sources
8523+
else:
8524+
# Pre-normalize table names in sources dictionary for consistent lookups
8525+
normalized_sources = {
8526+
normalize_table_name(k, dialect=dialect): v for k, v in sources.items()
8527+
}
8528+
8529+
def _get_source(name: str) -> t.Optional[Query]:
8530+
return normalized_sources.get(name)
8531+
8532+
get_source = _get_source
85218533

85228534
def _expand(node: Expression):
85238535
if isinstance(node, Table):
85248536
name = normalize_table_name(node, dialect=dialect)
8525-
source = sources.get(name)
8537+
source = get_source(name)
85268538
if source:
8539+
# Create a subquery with the same alias (or table name if no alias)
85278540
subquery = source.subquery(node.alias or name)
85288541
subquery.comments = [f"source: {name}"]
8542+
# Continue expanding within the subquery
85298543
return subquery.transform(_expand, copy=False)
85308544
return node
85318545

tests/test_expressions.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
import sys
22
import datetime
33
import math
4+
import typing as t
45
import unittest
56

67
from sqlglot import ParseError, alias, exp, parse_one
8+
from sqlglot.expressions import Query, normalize_table_name
79

810

911
class TestExpressions(unittest.TestCase):
@@ -277,6 +279,28 @@ def test_expand(self):
277279
"SELECT * FROM (SELECT 1) AS a /* source: a-b.c */",
278280
)
279281

282+
def test_expand_with_lazy_source_provider(self):
283+
class DynamicSourceProvider:
284+
def __init__(self, dialect: str = "spark"):
285+
self._sources = {normalize_table_name("`a-b`.`c`", dialect): "select 1"}
286+
287+
def get(self, name: str) -> t.Optional[Query]:
288+
query_sql = self._sources.get(name)
289+
290+
if query_sql:
291+
return parse_one(query_sql)
292+
293+
dynamic_source_provider = DynamicSourceProvider()
294+
295+
self.assertEqual(
296+
exp.expand(
297+
parse_one('select * from "a-b"."C" AS a'),
298+
lambda name: dynamic_source_provider.get(name),
299+
dialect="spark",
300+
).sql(),
301+
"SELECT * FROM (SELECT 1) AS a /* source: a-b.c */",
302+
)
303+
280304
def test_replace_placeholders(self):
281305
self.assertEqual(
282306
exp.replace_placeholders(

0 commit comments

Comments
 (0)