Skip to content

Commit a458621

Browse files
feat(snowflake): Add BITSHIFTLEFT/BITSHIFTRIGHT transpilation to DuckDB with INT128 casts and precedence fixes
1 parent 870d600 commit a458621

File tree

5 files changed

+112
-3
lines changed

5 files changed

+112
-3
lines changed

sqlglot/dialects/duckdb.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -691,6 +691,32 @@ def _boolxor_agg_sql(self: DuckDB.Generator, expression: exp.BoolxorAgg) -> str:
691691
)
692692

693693

694+
def _prepare_bitshift_for_duckdb(expression: exp.Expression) -> exp.Expression:
695+
"""
696+
Transform bitwise shift expressions for DuckDB by injecting INT128 casts.
697+
698+
DuckDB's bitwise shift operators don't work with BLOB/BINARY types, so we cast
699+
them to INT128 for integer arithmetic.
700+
701+
Note: Assumes type annotation has been applied with the source dialect.
702+
"""
703+
# Unwrap BINARY/VARBINARY casts that DuckDB can't handle
704+
if isinstance(expression.this, exp.Cast) and _is_binary(expression.this):
705+
expression.this.replace(expression.this.this)
706+
707+
# Check if the input is a BLOB/BINARY type (using is_type) or requires INT128
708+
# If so, cast to INT128 for integer arithmetic
709+
if _is_binary(expression.this) or expression.args.get("requires_int128"):
710+
expression.this.replace(exp.cast(expression.this, exp.DataType.Type.INT128))
711+
712+
# Wrap in parentheses if parent is a bitwise operator to "fix" DuckDB precedence issue
713+
# DuckDB parses: a << b | c << d as (a << b | c) << d
714+
if isinstance(expression.parent, (exp.BitwiseAnd, exp.BitwiseOr, exp.BitwiseXor)):
715+
return exp.paren(expression, copy=False)
716+
717+
return expression
718+
719+
694720
def _scale_rounding_sql(
695721
self: DuckDB.Generator,
696722
expression: exp.Expression,
@@ -1230,8 +1256,10 @@ class Generator(generator.Generator):
12301256
),
12311257
exp.BitwiseAnd: lambda self, e: self._bitwise_op(e, "&"),
12321258
exp.BitwiseAndAgg: _bitwise_agg_sql,
1259+
exp.BitwiseLeftShift: transforms.preprocess([_prepare_bitshift_for_duckdb]),
12331260
exp.BitwiseOr: lambda self, e: self._bitwise_op(e, "|"),
12341261
exp.BitwiseOrAgg: _bitwise_agg_sql,
1262+
exp.BitwiseRightShift: transforms.preprocess([_prepare_bitshift_for_duckdb]),
12351263
exp.BitwiseXorAgg: _bitwise_agg_sql,
12361264
exp.CommentColumnConstraint: no_comment_column_constraint_sql,
12371265
exp.CosineDistance: rename_func("LIST_COSINE_DISTANCE"),

sqlglot/dialects/snowflake.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,16 @@ def _builder(args: t.List) -> B | exp.Anonymous:
163163
)
164164
return exp.Anonymous(this=name, expressions=args)
165165

166-
return binary_from_function(expr_type)(args)
166+
result = binary_from_function(expr_type)(args)
167+
168+
# Snowflake specifies INT128 for bitwise shifts
169+
if expr_type in (exp.BitwiseLeftShift, exp.BitwiseRightShift):
170+
result.set("requires_int128", True)
171+
# Mark HexStrings as integers for proper rendering in target dialects
172+
for hexstr in result.find_all(exp.HexString):
173+
hexstr.set("is_integer", True)
174+
175+
return result
167176

168177
return _builder
169178

sqlglot/expressions.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5143,15 +5143,15 @@ class BitwiseAnd(Binary):
51435143

51445144

51455145
class BitwiseLeftShift(Binary):
5146-
pass
5146+
arg_types = {"this": True, "expression": True, "requires_int128": False}
51475147

51485148

51495149
class BitwiseOr(Binary):
51505150
arg_types = {"this": True, "expression": True, "padside": False}
51515151

51525152

51535153
class BitwiseRightShift(Binary):
5154-
pass
5154+
arg_types = {"this": True, "expression": True, "requires_int128": False}
51555155

51565156

51575157
class BitwiseXor(Binary):

tests/dialects/test_snowflake.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1691,6 +1691,46 @@ def test_snowflake(self):
16911691
self.validate_identity("SELECT BIT_SHIFTLEFT(a, 1)", "SELECT BITSHIFTLEFT(a, 1)")
16921692
self.validate_identity("SELECT BITSHIFTRIGHT(a, 1)")
16931693
self.validate_identity("SELECT BIT_SHIFTRIGHT(a, 1)", "SELECT BITSHIFTRIGHT(a, 1)")
1694+
1695+
# duckdb has an order of operations precedence issue with bitshift and bitwise operators
1696+
self.validate_all(
1697+
"SELECT BITOR(BITSHIFTLEFT(5, 16), BITSHIFTLEFT(3, 8))",
1698+
write={"duckdb": "SELECT (CAST(5 AS INT128) << 16) | (CAST(3 AS INT128) << 8)"},
1699+
)
1700+
self.validate_all(
1701+
"SELECT BITAND(BITSHIFTLEFT(255, 4), BITSHIFTLEFT(15, 2))",
1702+
write={"duckdb": "SELECT (CAST(255 AS INT128) << 4) & (CAST(15 AS INT128) << 2)"},
1703+
)
1704+
self.validate_all(
1705+
"SELECT BITSHIFTLEFT(255, 4)",
1706+
write={"duckdb": "SELECT CAST(255 AS INT128) << 4"},
1707+
)
1708+
self.validate_all(
1709+
"SELECT BITSHIFTLEFT(255, 4)",
1710+
write={"duckdb": "SELECT CAST(255 AS INT128) << 4"},
1711+
)
1712+
self.validate_all(
1713+
"SELECT BITSHIFTLEFT(CAST(255 AS BINARY), 4)",
1714+
write={"duckdb": "SELECT CAST(255 AS INT128) << 4"},
1715+
)
1716+
self.validate_all(
1717+
"SELECT BITSHIFTLEFT(X'FF', 4)",
1718+
write={"duckdb": "SELECT CAST(255 AS INT128) << 4"},
1719+
)
1720+
1721+
self.validate_all(
1722+
"SELECT BITSHIFTRIGHT(1024, 4)",
1723+
write={"duckdb": "SELECT CAST(1024 AS INT128) >> 4"},
1724+
)
1725+
self.validate_all(
1726+
"SELECT BITSHIFTRIGHT(CAST(255 AS BINARY), 4)",
1727+
write={"duckdb": "SELECT CAST(255 AS INT128) >> 4"},
1728+
)
1729+
self.validate_all(
1730+
"SELECT BITSHIFTRIGHT(X'FF', 4)",
1731+
write={"duckdb": "SELECT CAST(255 AS INT128) >> 4"},
1732+
)
1733+
16941734
self.validate_all(
16951735
"OCTET_LENGTH('A')",
16961736
read={

tests/fixtures/optimizer/annotate_types.sql

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -186,3 +186,35 @@ FLOAT64;
186186
# dialect: bigquery
187187
CASE WHEN x < y THEN CAST(3.5 AS BIGNUMERIC) WHEN x > y THEN 3/10 ELSE 2 END;
188188
FLOAT64;
189+
190+
# dialect: snowflake
191+
BITSHIFTLEFT(255, 4);
192+
INT;
193+
194+
# dialect: snowflake
195+
BITSHIFTRIGHT(1024, 2);
196+
INT;
197+
198+
# dialect: snowflake
199+
BITSHIFTLEFT(CAST(255 AS BINARY), 4);
200+
BINARY;
201+
202+
# dialect: snowflake
203+
BITSHIFTRIGHT(CAST(255 AS BINARY), 4);
204+
BINARY;
205+
206+
# dialect: snowflake
207+
BITSHIFTLEFT(X'FF', 4);
208+
BINARY;
209+
210+
# dialect: snowflake
211+
BITSHIFTRIGHT(X'FF', 4);
212+
BINARY;
213+
214+
# dialect: snowflake
215+
BITOR(BITSHIFTLEFT(5, 16), BITSHIFTLEFT(3, 8));
216+
INT;
217+
218+
# dialect: snowflake
219+
BITAND(BITSHIFTLEFT(255, 4), BITSHIFTLEFT(15, 2));
220+
INT;

0 commit comments

Comments
 (0)