Skip to content

Commit bc061e0

Browse files
committed
Feat(snowflake): improve transpilation of unnested object lookup
1 parent e83b974 commit bc061e0

File tree

3 files changed

+51
-1
lines changed

3 files changed

+51
-1
lines changed

sqlglot/dialects/snowflake.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
)
3232
from sqlglot.generator import unsupported_args
3333
from sqlglot.helper import flatten, is_float, is_int, seq_get
34+
from sqlglot.optimizer.scope import find_all_in_scope
3435
from sqlglot.tokens import TokenType
3536

3637
if t.TYPE_CHECKING:
@@ -333,6 +334,34 @@ def _json_extract_value_array_sql(
333334
return self.func("TRANSFORM", json_extract, transform_lambda)
334335

335336

337+
def _eliminate_dot_variant_lookup(expression: exp.Expression) -> exp.Expression:
    """Rewrite dot lookups on unnested values into bracket lookups.

    For a `Select`, every `Unnest` in its scope whose alias is a `TableAlias`
    with no `this` and exactly one column contributes that column's name to a
    set of aliases. Each `Column` in the same scope whose table qualifier is
    one of those names is then replaced in place by a `Bracket` indexing the
    qualifier with the column name as a string literal, e.g. `_u.foo` becomes
    `_u['foo']`.

    Args:
        expression: The expression to transform. Anything other than a
            `Select` is returned untouched.

    Returns:
        The same `expression` object, with matching columns replaced in place.
    """
    if not isinstance(expression, exp.Select):
        return expression

    # This transformation is used to facilitate transpilation of BigQuery `UNNEST` operations
    # to Snowflake. It should not affect roundtrip because `Unnest` nodes cannot be produced
    # by Snowflake's parser.
    #
    # Additionally, at the time of writing this, BigQuery is the only dialect that produces a
    # `TableAlias` node that only fills `columns` and not `this`, due to `UNNEST_COLUMN_ONLY`.
    column_only_aliases = set()
    for unnest_node in find_all_in_scope(expression, exp.Unnest):
        alias = unnest_node.args.get("alias")

        # Only aliases that name a single column and no table are of interest.
        if not isinstance(alias, exp.TableAlias) or alias.this:
            continue

        alias_columns = alias.columns
        if len(alias_columns) == 1:
            column_only_aliases.add(alias_columns[0].name)

    if not column_only_aliases:
        return expression

    for column in find_all_in_scope(expression, exp.Column):
        if column.table not in column_only_aliases:
            continue

        # The table qualifier becomes the indexed value and the column name the string key.
        lookup_key = exp.Literal.string(column.name)
        column.replace(exp.Bracket(this=column.args["table"], expressions=[lookup_key]))

    return expression
363+
364+
336365
class Snowflake(Dialect):
337366
# https://docs.snowflake.com/en/sql-reference/identifiers-syntax
338367
NORMALIZATION_STRATEGY = NormalizationStrategy.UPPERCASE
@@ -1089,6 +1118,7 @@ class Generator(generator.Generator):
10891118
transforms.explode_projection_to_unnest(),
10901119
transforms.eliminate_semi_and_anti_joins,
10911120
_transform_generate_date_array,
1121+
_eliminate_dot_variant_lookup,
10921122
]
10931123
),
10941124
exp.SHA: rename_func("SHA1"),

sqlglot/optimizer/scope.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -846,12 +846,14 @@ def walk_in_scope(expression, bfs=True, prune=None):
846846

847847
if node is expression:
848848
continue
849+
849850
if (
850851
isinstance(node, exp.CTE)
851852
or (
852853
isinstance(node.parent, (exp.From, exp.Join, exp.Subquery))
853-
and (_is_derived_table(node) or isinstance(node, exp.UDTF))
854+
and _is_derived_table(node)
854855
)
856+
or (isinstance(node.parent, exp.UDTF) and isinstance(node, exp.Query))
855857
or isinstance(node, exp.UNWRAPPED_QUERIES)
856858
):
857859
crossed_scope_boundary = True

tests/dialects/test_snowflake.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -318,6 +318,24 @@ def test_snowflake(self):
318318
"SELECT * FROM xxx, yyy, zzz",
319319
)
320320

321+
self.validate_all(
322+
"SELECT _u['foo'], bar, baz FROM TABLE(FLATTEN(INPUT => [OBJECT_CONSTRUCT('foo', 'x', 'bars', ['y', 'z'], 'bazs', ['w'])])) AS _t0(seq, key, path, index, _u, this), TABLE(FLATTEN(INPUT => _u['bars'])) AS _t1(seq, key, path, index, bar, this), TABLE(FLATTEN(INPUT => _u['bazs'])) AS _t2(seq, key, path, index, baz, this)",
323+
read={
324+
"bigquery": "SELECT _u.foo, bar, baz FROM UNNEST([struct('x' AS foo, ['y', 'z'] AS bars, ['w'] AS bazs)]) AS _u, UNNEST(_u.bars) AS bar, UNNEST(_u.bazs) AS baz",
325+
},
326+
)
327+
self.validate_all(
328+
"SELECT _u, _u['foo'], _u['bar'] FROM TABLE(FLATTEN(INPUT => [OBJECT_CONSTRUCT('foo', 'x', 'bar', 'y')])) AS _t0(seq, key, path, index, _u, this)",
329+
read={
330+
"bigquery": "select _u, _u.foo, _u.bar from unnest([struct('x' as foo, 'y' AS bar)]) as _u",
331+
},
332+
)
333+
self.validate_all(
334+
"SELECT _u['foo'][0].bar FROM TABLE(FLATTEN(INPUT => [OBJECT_CONSTRUCT('foo', [OBJECT_CONSTRUCT('bar', 1)])])) AS _t0(seq, key, path, index, _u, this)",
335+
read={
336+
"bigquery": "select _u.foo[0].bar from unnest([struct([struct(1 as bar)] as foo)]) as _u",
337+
},
338+
)
321339
self.validate_all(
322340
"CREATE TABLE test_table (id NUMERIC NOT NULL AUTOINCREMENT)",
323341
write={

0 commit comments

Comments
 (0)