Skip to content

Commit 7f27c44

Browse files
support transpilation of TRY_TO_BINARY from Snowflake to DuckDB
1 parent 9382ebd commit 7f27c44

File tree

3 files changed

+53
-25
lines changed

3 files changed

+53
-25
lines changed

sqlglot/dialects/duckdb.py

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1583,29 +1583,41 @@ def randstr_sql(self: DuckDB.Generator, expression: exp.Randstr) -> str:
15831583

15841584
def tobinary_sql(self: DuckDB.Generator, expression: exp.ToBinary) -> str:
    """
    Generate SQL for TO_BINARY / TRY_TO_BINARY.

    When the expression is typed as BINARY, the format argument (treated as
    'HEX' when absent) selects the function that is emitted:
        - 'HEX':    TO_BINARY('48454C50', 'HEX')    → UNHEX('48454C50')
        - 'UTF-8':  TO_BINARY('TEST', 'UTF-8')      → ENCODE('TEST')
        - 'BASE64': TO_BINARY('SEVMUA==', 'BASE64') → FROM_BASE64('SEVMUA==')

    When the "safe" arg is set (TRY_TO_BINARY), the generated call is wrapped
    in TRY():
        - TRY_TO_BINARY('invalid', 'HEX') → TRY(UNHEX('invalid'))
    and an unrecognized format produces NULL. Without "safe", an unrecognized
    format is reported through self.unsupported and TO_BINARY(value) is
    emitted.

    Expressions that are not typed as BINARY are emitted as TO_BINARY(value).
    """
    operand = expression.this
    safe = expression.args.get("safe", False)
    format_node = expression.args.get("format")
    fmt = format_node.name.upper() if format_node else "HEX"

    if not expression.is_type(exp.DataType.Type.BINARY):
        # Fallback, which needs to be updated if we want to support
        # transpilation from dialects other than Snowflake
        return self.func("TO_BINARY", operand)

    # Format name -> function emitted for that format
    decoders = {"UTF-8": "ENCODE", "BASE64": "FROM_BASE64", "HEX": "UNHEX"}
    func_name = decoders.get(fmt)

    if func_name is None:
        if safe:
            return self.sql(exp.null())
        self.unsupported(f"format {fmt} is not supported")
        return self.func("TO_BINARY", operand)

    converted = self.func(func_name, operand)

    # Wrap with TRY() for TRY_TO_BINARY
    return self.func("TRY", converted) if safe else converted

tests/dialects/test_duckdb.py

Lines changed: 0 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -724,23 +724,6 @@ def test_duckdb(self):
724724
"duckdb": "CREATE TABLE IF NOT EXISTS t (cola INT, colb TEXT)",
725725
},
726726
)
727-
728-
expr = self.parse_one("TO_BINARY('48454C50', 'HEX')", dialect="snowflake")
729-
annotated = annotate_types(expr, dialect="snowflake")
730-
self.assertEqual(annotated.sql("duckdb"), "UNHEX('48454C50')")
731-
732-
expr = self.parse_one("TO_BINARY('48454C50')", dialect="snowflake")
733-
annotated = annotate_types(expr, dialect="snowflake")
734-
self.assertEqual(annotated.sql("duckdb"), "UNHEX('48454C50')")
735-
736-
expr = self.parse_one("TO_BINARY('TEST', 'UTF-8')", dialect="snowflake")
737-
annotated = annotate_types(expr, dialect="snowflake")
738-
self.assertEqual(annotated.sql("duckdb"), "ENCODE('TEST')")
739-
740-
expr = self.parse_one("TO_BINARY('SEVMUA==', 'BASE64')", dialect="snowflake")
741-
annotated = annotate_types(expr, dialect="snowflake")
742-
self.assertEqual(annotated.sql("duckdb"), "FROM_BASE64('SEVMUA==')")
743-
744727
self.validate_all(
745728
"[0, 1, 2]",
746729
read={

tests/dialects/test_snowflake.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4253,6 +4253,39 @@ def test_round(self):
42534253
},
42544254
)
42554255

4256+
def test_transpile_to_binary(self):
    """Check the DuckDB SQL generated for type-annotated Snowflake TO_BINARY / TRY_TO_BINARY calls."""
    # (Snowflake input, expected DuckDB output); the cases without a format
    # argument expect the same output as the explicit 'HEX' cases.
    cases = (
        ("TO_BINARY('48454C50', 'HEX')", "UNHEX('48454C50')"),
        ("TO_BINARY('48454C50')", "UNHEX('48454C50')"),
        ("TO_BINARY('TEST', 'UTF-8')", "ENCODE('TEST')"),
        ("TO_BINARY('SEVMUA==', 'BASE64')", "FROM_BASE64('SEVMUA==')"),
        ("TRY_TO_BINARY('48454C50', 'HEX')", "TRY(UNHEX('48454C50'))"),
        ("TRY_TO_BINARY('48454C50')", "TRY(UNHEX('48454C50'))"),
        ("TRY_TO_BINARY('Hello', 'UTF-8')", "TRY(ENCODE('Hello'))"),
        ("TRY_TO_BINARY('SGVsbG8=', 'BASE64')", "TRY(FROM_BASE64('SGVsbG8='))"),
    )

    for snowflake_sql, duckdb_sql in cases:
        parsed = self.parse_one(snowflake_sql, dialect="snowflake")
        # The expression is annotated before generation, as in every case here
        typed = annotate_types(parsed, dialect="snowflake")
        self.assertEqual(typed.sql("duckdb"), duckdb_sql)
4288+
42564289
def test_transpile_bitwise_ops(self):
42574290
# Binary bitwise operations
42584291
expr = self.parse_one("SELECT BITOR(x'FF', x'0F')", dialect="snowflake")

0 commit comments

Comments
 (0)