Skip to content

Commit e7b5d6f

Browse files
feat(snowflake)!: support transpilation of TRY_TO_BINARY from Snowflake to DuckDB (#6629)
* support transpilation of TRY_TO_BINARY from Snowflake to DuckDB * update tests
1 parent c17878a commit e7b5d6f

File tree

3 files changed

+57
-30
lines changed

3 files changed

+57
-30
lines changed

sqlglot/dialects/duckdb.py

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1686,29 +1686,41 @@ def zipf_sql(self: DuckDB.Generator, expression: exp.Zipf) -> str:
16861686

16871687
def tobinary_sql(self: DuckDB.Generator, expression: exp.ToBinary) -> str:
16881688
"""
1689-
TO_BINARY(value, format) transpilation if the return type is BINARY:
1689+
TO_BINARY and TRY_TO_BINARY transpilation:
16901690
- 'HEX': TO_BINARY('48454C50', 'HEX') → UNHEX('48454C50')
16911691
- 'UTF-8': TO_BINARY('TEST', 'UTF-8') → ENCODE('TEST')
16921692
- 'BASE64': TO_BINARY('SEVMUA==', 'BASE64') → FROM_BASE64('SEVMUA==')
16931693
1694-
format can be 'HEX', 'UTF-8' or 'BASE64'
1695-
return type can be either VARCHAR or BINARY
1694+
For TRY_TO_BINARY (safe=True), wrap with TRY():
1695+
- 'HEX': TRY_TO_BINARY('invalid', 'HEX') → TRY(UNHEX('invalid'))
16961696
"""
16971697
value = expression.this
16981698
format_arg = expression.args.get("format")
1699+
is_safe = expression.args.get("safe")
16991700

17001701
fmt = "HEX"
17011702
if format_arg:
17021703
fmt = format_arg.name.upper()
17031704

17041705
if expression.is_type(exp.DataType.Type.BINARY):
17051706
if fmt == "UTF-8":
1706-
return self.func("ENCODE", value)
1707-
if fmt == "BASE64":
1708-
return self.func("FROM_BASE64", value)
1707+
result = self.func("ENCODE", value)
1708+
elif fmt == "BASE64":
1709+
result = self.func("FROM_BASE64", value)
1710+
elif fmt == "HEX":
1711+
result = self.func("UNHEX", value)
1712+
else:
1713+
if is_safe:
1714+
return self.sql(exp.null())
1715+
else:
1716+
self.unsupported(f"format {fmt} is not supported")
1717+
result = self.func("TO_BINARY", value)
1718+
1719+
# Wrap with TRY() for TRY_TO_BINARY
1720+
if is_safe:
1721+
result = self.func("TRY", result)
17091722

1710-
# Hex
1711-
return self.func("UNHEX", value)
1723+
return result
17121724

17131725
# Fallback, which needs to be updated if want to support transpilation from other dialects than Snowflake
17141726
return self.func("TO_BINARY", value)

tests/dialects/test_duckdb.py

Lines changed: 0 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -724,23 +724,6 @@ def test_duckdb(self):
724724
"duckdb": "CREATE TABLE IF NOT EXISTS t (cola INT, colb TEXT)",
725725
},
726726
)
727-
728-
expr = self.parse_one("TO_BINARY('48454C50', 'HEX')", dialect="snowflake")
729-
annotated = annotate_types(expr, dialect="snowflake")
730-
self.assertEqual(annotated.sql("duckdb"), "UNHEX('48454C50')")
731-
732-
expr = self.parse_one("TO_BINARY('48454C50')", dialect="snowflake")
733-
annotated = annotate_types(expr, dialect="snowflake")
734-
self.assertEqual(annotated.sql("duckdb"), "UNHEX('48454C50')")
735-
736-
expr = self.parse_one("TO_BINARY('TEST', 'UTF-8')", dialect="snowflake")
737-
annotated = annotate_types(expr, dialect="snowflake")
738-
self.assertEqual(annotated.sql("duckdb"), "ENCODE('TEST')")
739-
740-
expr = self.parse_one("TO_BINARY('SEVMUA==', 'BASE64')", dialect="snowflake")
741-
annotated = annotate_types(expr, dialect="snowflake")
742-
self.assertEqual(annotated.sql("duckdb"), "FROM_BASE64('SEVMUA==')")
743-
744727
self.validate_all(
745728
"[0, 1, 2]",
746729
read={

tests/dialects/test_snowflake.py

Lines changed: 37 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -353,8 +353,6 @@ def test_snowflake(self):
353353
self.validate_identity("TO_DECFLOAT('1,234.56', '999,999.99')")
354354
self.validate_identity("TRY_TO_DECFLOAT('123.456')")
355355
self.validate_identity("TRY_TO_DECFLOAT('1,234.56', '999,999.99')")
356-
self.validate_identity("TRY_TO_BINARY('48656C6C6F')")
357-
self.validate_identity("TRY_TO_BINARY('48656C6C6F', 'HEX')")
358356
self.validate_all(
359357
"TRY_TO_BOOLEAN('true')",
360358
write={
@@ -507,9 +505,6 @@ def test_snowflake(self):
507505
self.validate_identity("SELECT a, exclude, b FROM xxx")
508506
self.validate_identity("SELECT ARRAY_SORT(x, TRUE, FALSE)")
509507
self.validate_identity("SELECT BOOLXOR_AGG(col) FROM tbl")
510-
self.validate_identity("SELECT TO_BINARY('C2')")
511-
self.validate_identity("SELECT TO_BINARY('C2', 'HEX')")
512-
self.validate_identity("SELECT TO_BINARY('café', 'UTF-8')")
513508
self.validate_identity(
514509
"SELECT PERCENTILE_DISC(0.9) WITHIN GROUP (ORDER BY col) OVER (PARTITION BY category)"
515510
)
@@ -4326,6 +4321,43 @@ def test_round(self):
43264321
},
43274322
)
43284323

4324+
def test_to_binary(self):
4325+
expr = self.validate_identity("TO_BINARY('48454C50', 'HEX')")
4326+
annotated = annotate_types(expr, dialect="snowflake")
4327+
self.assertEqual(annotated.sql("duckdb"), "UNHEX('48454C50')")
4328+
4329+
expr = self.validate_identity("TO_BINARY('48454C50')")
4330+
annotated = annotate_types(expr, dialect="snowflake")
4331+
self.assertEqual(annotated.sql("duckdb"), "UNHEX('48454C50')")
4332+
4333+
expr = self.validate_identity("TO_BINARY('TEST', 'UTF-8')")
4334+
annotated = annotate_types(expr, dialect="snowflake")
4335+
self.assertEqual(annotated.sql("duckdb"), "ENCODE('TEST')")
4336+
4337+
expr = self.validate_identity("TO_BINARY('SEVMUA==', 'BASE64')")
4338+
annotated = annotate_types(expr, dialect="snowflake")
4339+
self.assertEqual(annotated.sql("duckdb"), "FROM_BASE64('SEVMUA==')")
4340+
4341+
expr = self.validate_identity("TRY_TO_BINARY('48454C50', 'HEX')")
4342+
annotated = annotate_types(expr, dialect="snowflake")
4343+
self.assertEqual(annotated.sql("duckdb"), "TRY(UNHEX('48454C50'))")
4344+
4345+
expr = self.validate_identity("TRY_TO_BINARY('48454C50')")
4346+
annotated = annotate_types(expr, dialect="snowflake")
4347+
self.assertEqual(annotated.sql("duckdb"), "TRY(UNHEX('48454C50'))")
4348+
4349+
expr = self.validate_identity("TRY_TO_BINARY('Hello', 'UTF-8')")
4350+
annotated = annotate_types(expr, dialect="snowflake")
4351+
self.assertEqual(annotated.sql("duckdb"), "TRY(ENCODE('Hello'))")
4352+
4353+
expr = self.validate_identity("TRY_TO_BINARY('SGVsbG8=', 'BASE64')")
4354+
annotated = annotate_types(expr, dialect="snowflake")
4355+
self.assertEqual(annotated.sql("duckdb"), "TRY(FROM_BASE64('SGVsbG8='))")
4356+
4357+
expr = self.validate_identity("TRY_TO_BINARY('Hello', 'UTF-16')")
4358+
annotated = annotate_types(expr, dialect="snowflake")
4359+
self.assertEqual(annotated.sql("duckdb"), "NULL")
4360+
43294361
def test_transpile_bitwise_ops(self):
43304362
# Binary bitwise operations
43314363
expr = self.parse_one("SELECT BITOR(x'FF', x'0F')", dialect="snowflake")

0 commit comments

Comments
 (0)