Skip to content

Commit 9f96195

Browse files
feat(snowflake): Add BITSHIFTLEFT/BITSHIFTRIGHT transpilation to DuckDB with INT128 casts and precedence fixes
1 parent d322fa6 commit 9f96195

File tree

4 files changed

+89
-27
lines changed

4 files changed

+89
-27
lines changed

sqlglot/dialects/duckdb.py

Lines changed: 33 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -642,34 +642,47 @@ def _prepare_bitshift_for_duckdb(expression: exp.Expression) -> exp.Expression:
642642
"""
643643
Transform bitwise shift expressions for DuckDB by injecting INT128 casts.
644644
645-
Modifies the AST in-place to:
646-
1. Unwrap BINARY/BLOB casts and replace with INT128 cast
647-
2. Mark HexString literals as integers (sets is_integer flag for proper rendering)
648-
3. Apply INT128 cast if requires_int128=True (from Snowflake)
645+
DuckDB's bitwise shift operators don't work with BLOB/BINARY types, so we cast
646+
them to INT128 for integer arithmetic.
649647
650-
After transformation, falls back to the base binary generator.
648+
Note: Assumes type annotation has been applied with the source dialect.
651649
"""
652-
needs_int128 = False
653-
654-
# Check if left operand is a BINARY/BLOB cast - unwrap it
650+
# Unwrap BINARY/VARBINARY casts that DuckDB can't handle
655651
if isinstance(expression.this, exp.Cast) and _is_binary(expression.this):
656-
# Replace Cast(value, BINARY) with just the inner value
657652
expression.this.replace(expression.this.this)
658-
needs_int128 = True
659-
660-
# Mark HexString as integer so DuckDB renders it as decimal instead of BLOB
661-
if isinstance(expression.this, exp.HexString):
662-
expression.this.set("is_integer", True)
663-
needs_int128 = True
664653

665-
# Apply INT128 cast if needed
666-
if needs_int128 or expression.args.get("requires_int128"):
654+
# Check if the input is a BLOB/BINARY type (using is_type) or requires INT128
655+
# If so, cast to INT128 for integer arithmetic
656+
if _is_binary(expression.this) or expression.args.get("requires_int128"):
667657
expression.this.replace(exp.cast(expression.this, exp.DataType.Type.INT128))
668658

659+
# Wrap in parentheses if parent is a bitwise operator to "fix" DuckDB precedence issue
660+
# DuckDB parses: a << b | c << d as (a << b | c) << d
661+
if isinstance(expression.parent, (exp.BitwiseAnd, exp.BitwiseOr, exp.BitwiseXor)):
662+
return exp.paren(expression, copy=False)
663+
669664
return expression
670665

671666

672-
def _floor_sql(self: DuckDB.Generator, expression: exp.Floor) -> str:
667+
def _scale_rounding_sql(
668+
self: DuckDB.Generator,
669+
expression: exp.Expression,
670+
rounding_func: type[exp.Expression],
671+
) -> str | None:
672+
"""
673+
Handle scale parameter transformation for rounding functions.
674+
675+
DuckDB doesn't support the scale parameter for certain functions (e.g., FLOOR, CEIL),
676+
so we transform: FUNC(x, n) to ROUND(FUNC(x * 10^n) / 10^n, n)
677+
678+
Args:
679+
self: The DuckDB generator instance
680+
expression: The expression to transform (must have 'this', 'decimals', and 'to' args)
681+
rounding_func: The rounding function class to use in the transformation
682+
683+
Returns:
684+
The transformed SQL string if decimals parameter exists, None otherwise
685+
"""
673686
decimals = expression.args.get("decimals")
674687

675688
if decimals is None or expression.args.get("to") is not None:
@@ -1190,8 +1203,10 @@ class Generator(generator.Generator):
11901203
),
11911204
exp.BitwiseAnd: lambda self, e: self._bitwise_op(e, "&"),
11921205
exp.BitwiseAndAgg: _bitwise_agg_sql,
1206+
exp.BitwiseLeftShift: transforms.preprocess([_prepare_bitshift_for_duckdb]),
11931207
exp.BitwiseOr: lambda self, e: self._bitwise_op(e, "|"),
11941208
exp.BitwiseOrAgg: _bitwise_agg_sql,
1209+
exp.BitwiseRightShift: transforms.preprocess([_prepare_bitshift_for_duckdb]),
11951210
exp.BitwiseXorAgg: _bitwise_agg_sql,
11961211
exp.CommentColumnConstraint: no_comment_column_constraint_sql,
11971212
exp.CosineDistance: rename_func("LIST_COSINE_DISTANCE"),

sqlglot/dialects/snowflake.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -163,13 +163,16 @@ def _builder(args: t.List) -> B | exp.Anonymous:
163163
)
164164
return exp.Anonymous(this=name, expressions=args)
165165

166+
result = binary_from_function(expr_type)(args)
167+
166168
# Snowflake specifies INT128 for bitwise shifts
167169
if expr_type in (exp.BitwiseLeftShift, exp.BitwiseRightShift):
168-
result = binary_from_function(expr_type)(args)
169170
result.set("requires_int128", True)
170-
return result
171+
# Mark HexStrings as integers for proper rendering in target dialects
172+
for hexstr in result.find_all(exp.HexString):
173+
hexstr.set("is_integer", True)
171174

172-
return binary_from_function(expr_type)(args)
175+
return result
173176

174177
return _builder
175178

tests/dialects/test_snowflake.py

Lines changed: 26 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1692,25 +1692,45 @@ def test_snowflake(self):
16921692
self.validate_identity("SELECT BITSHIFTRIGHT(a, 1)")
16931693
self.validate_identity("SELECT BIT_SHIFTRIGHT(a, 1)", "SELECT BITSHIFTRIGHT(a, 1)")
16941694

1695-
# Test bitshift transpilation to DuckDB with INT128 casts
1695+
# duckdb has an order of operations precedence issue with bitshift and bitwise operators
1696+
self.validate_all(
1697+
"SELECT BITOR(BITSHIFTLEFT(5, 16), BITSHIFTLEFT(3, 8))",
1698+
write={"duckdb": "SELECT (CAST(5 AS INT128) << 16) | (CAST(3 AS INT128) << 8)"},
1699+
)
1700+
self.validate_all(
1701+
"SELECT BITAND(BITSHIFTLEFT(255, 4), BITSHIFTLEFT(15, 2))",
1702+
write={"duckdb": "SELECT (CAST(255 AS INT128) << 4) & (CAST(15 AS INT128) << 2)"},
1703+
)
16961704
self.validate_all(
16971705
"SELECT BITSHIFTLEFT(255, 4)",
16981706
write={"duckdb": "SELECT CAST(255 AS INT128) << 4"},
16991707
)
17001708
self.validate_all(
1701-
"SELECT BITSHIFTRIGHT(1024, 2)",
1702-
write={"duckdb": "SELECT CAST(1024 AS INT128) >> 2"},
1709+
"SELECT BITSHIFTLEFT(255, 4)",
1710+
write={"duckdb": "SELECT CAST(255 AS INT128) << 4"},
17031711
)
1704-
# Test BINARY cast unwrapping
17051712
self.validate_all(
1706-
"SELECT BITSHIFTLEFT(X'FF'::BINARY, 4)",
1713+
"SELECT BITSHIFTLEFT(CAST(255 AS BINARY), 4)",
17071714
write={"duckdb": "SELECT CAST(255 AS INT128) << 4"},
17081715
)
1709-
# Test HexString with is_integer flag
17101716
self.validate_all(
17111717
"SELECT BITSHIFTLEFT(X'FF', 4)",
17121718
write={"duckdb": "SELECT CAST(255 AS INT128) << 4"},
17131719
)
1720+
1721+
self.validate_all(
1722+
"SELECT BITSHIFTRIGHT(1024, 4)",
1723+
write={"duckdb": "SELECT CAST(1024 AS INT128) >> 4"},
1724+
)
1725+
self.validate_all(
1726+
"SELECT BITSHIFTRIGHT(CAST(255 AS BINARY), 4)",
1727+
write={"duckdb": "SELECT CAST(255 AS INT128) >> 4"},
1728+
)
1729+
self.validate_all(
1730+
"SELECT BITSHIFTRIGHT(X'FF', 4)",
1731+
write={"duckdb": "SELECT CAST(255 AS INT128) >> 4"},
1732+
)
1733+
17141734
self.validate_all(
17151735
"OCTET_LENGTH('A')",
17161736
read={

tests/fixtures/optimizer/annotate_types.sql

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,3 +162,27 @@ INT;
162162
# dialect: snowflake
163163
BITSHIFTRIGHT(1024, 2);
164164
INT;
165+
166+
# dialect: snowflake
167+
BITSHIFTLEFT(CAST(255 AS BINARY), 4);
168+
BINARY;
169+
170+
# dialect: snowflake
171+
BITSHIFTRIGHT(CAST(255 AS BINARY), 4);
172+
BINARY;
173+
174+
# dialect: snowflake
175+
BITSHIFTLEFT(X'FF', 4);
176+
BINARY;
177+
178+
# dialect: snowflake
179+
BITSHIFTRIGHT(X'FF', 4);
180+
BINARY;
181+
182+
# dialect: snowflake
183+
BITOR(BITSHIFTLEFT(5, 16), BITSHIFTLEFT(3, 8));
184+
INT;
185+
186+
# dialect: snowflake
187+
BITAND(BITSHIFTLEFT(255, 4), BITSHIFTLEFT(15, 2));
188+
INT;

0 commit comments

Comments
 (0)