Skip to content

Commit fc5800d

Browse files
feat(snowflake)!: support Snowflake to DuckDB transpilation of ZIPF (#6618)
* ZIPF transpilation for Snowflake to DuckDB * updated tests * changed s and n to force constants * improved parameter checking * tweaked to fix random issue * templated the ZIPF transpilation * templated the ZIPF transpilation
1 parent 302fda0 commit fc5800d

File tree

2 files changed

+77
-2
lines changed

2 files changed

+77
-2
lines changed

sqlglot/dialects/duckdb.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1581,6 +1581,67 @@ def randstr_sql(self: DuckDB.Generator, expression: exp.Randstr) -> str:
15811581
)
15821582
return f"({self.sql(query)})"
15831583

1584+
# Pre-parsed template for ZIPF transpilation. The :random_expr, :s and :n
# placeholders get replaced with actual parameters by zipf_sql.
#   rand    - holds the random draw, aliased r
#   weights - weight 1.0 / POWER(i, s) for every i produced by RANGE(1, n + 1)
#   cdf     - running sum of the weights (ordered by i) divided by their total, aliased p
# The outer query returns the smallest i whose p is >= r.
1585+
ZIPF_TEMPLATE: t.ClassVar[exp.Expression] = exp.maybe_parse(
1586+
"""
1587+
WITH rand AS (SELECT :random_expr AS r),
1588+
weights AS (
1589+
SELECT i, 1.0 / POWER(i, :s) AS w
1590+
FROM RANGE(1, :n + 1) AS t(i)
1591+
),
1592+
cdf AS (
1593+
SELECT i, SUM(w) OVER (ORDER BY i) / SUM(w) OVER () AS p
1594+
FROM weights
1595+
)
1596+
SELECT MIN(i)
1597+
FROM cdf
1598+
WHERE p >= (SELECT r FROM rand)
1599+
"""
1600+
)
1601+
1602+
def zipf_sql(self: DuckDB.Generator, expression: exp.Zipf) -> str:
    """
    Render Snowflake's ZIPF as a DuckDB scalar subquery (CDF-based inverse sampling).

    A copy of the pre-parsed ZIPF_TEMPLATE is taken and its `:s`, `:n` and
    `:random_expr` placeholders are swapped for copies of the matching
    expression nodes. The generated SQL is returned wrapped in parentheses.
    """
    exponent = expression.this
    element_count = expression.args.get("elementcount")
    generator = expression.args.get("gen")

    uniform: exp.Expression
    if generator and not isinstance(generator, exp.Rand):
        # Seeded generator: (ABS(HASH(seed)) % 1000000) / 1000000.0
        hashed = exp.Abs(this=exp.Anonymous(this="HASH", expressions=[generator.copy()]))
        bucket = exp.Mod(this=hashed, expression=exp.Literal.number(1000000))
        uniform = exp.Div(
            this=exp.Paren(this=bucket),
            expression=exp.Literal.number(1000000.0),
        )
    else:
        # Explicit RANDOM() or no generator at all: use RANDOM() for
        # non-deterministic output
        uniform = exp.Rand()

    # s, n are required args per Zipf.arg_types
    assert exponent is not None and element_count is not None
    substitutions: dict[str, exp.Expression] = {
        "s": exponent,
        "n": element_count,
        "random_expr": uniform,
    }

    def fill_placeholder(node: exp.Expression) -> exp.Expression:
        # Only placeholders whose name is a known key are swapped; every
        # other node is handed back untouched.
        if not isinstance(node, exp.Placeholder):
            return node
        substitute = substitutions.get(node.name)
        return node if substitute is None else substitute.copy()

    template = self.ZIPF_TEMPLATE.copy()
    filled = template.transform(fill_placeholder)
    return f"({self.sql(filled)})"
1644+
15841645
def tobinary_sql(self: DuckDB.Generator, expression: exp.ToBinary) -> str:
15851646
"""
15861647
TO_BINARY(value, format) transpilation if the return type is BINARY:

tests/dialects/test_snowflake.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -279,8 +279,22 @@ def test_snowflake(self):
279279
},
280280
)
281281

282-
self.validate_identity("SELECT ZIPF(1, 10, RANDOM())")
283-
self.validate_identity("SELECT ZIPF(2, 100, 1234)")
282+
self.validate_all(
283+
"SELECT ZIPF(1, 10, 1234)",
284+
write={
285+
"duckdb": "SELECT (WITH rand AS (SELECT (ABS(HASH(1234)) % 1000000) / 1000000.0 AS r), weights AS (SELECT i, 1.0 / POWER(i, 1) AS w FROM RANGE(1, 10 + 1) AS t(i)), cdf AS (SELECT i, SUM(w) OVER (ORDER BY i NULLS FIRST) / SUM(w) OVER () AS p FROM weights) SELECT MIN(i) FROM cdf WHERE p >= (SELECT r FROM rand))",
286+
"snowflake": "SELECT ZIPF(1, 10, 1234)",
287+
},
288+
)
289+
290+
self.validate_all(
291+
"SELECT ZIPF(2, 100, RANDOM())",
292+
write={
293+
"duckdb": "SELECT (WITH rand AS (SELECT RANDOM() AS r), weights AS (SELECT i, 1.0 / POWER(i, 2) AS w FROM RANGE(1, 100 + 1) AS t(i)), cdf AS (SELECT i, SUM(w) OVER (ORDER BY i NULLS FIRST) / SUM(w) OVER () AS p FROM weights) SELECT MIN(i) FROM cdf WHERE p >= (SELECT r FROM rand))",
294+
"snowflake": "SELECT ZIPF(2, 100, RANDOM())",
295+
},
296+
)
297+
284298
self.validate_identity("SELECT GROUPING_ID(a, b) AS g_id FROM x GROUP BY ROLLUP (a, b)")
285299
self.validate_identity("PARSE_URL('https://example.com/path')")
286300
self.validate_identity("PARSE_URL('https://example.com/path', 1)")

0 commit comments

Comments
 (0)