Skip to content

Commit d71dec6

Browse files
updates
1 parent a458621 commit d71dec6

File tree

2 files changed

+58
-20
lines changed

2 files changed

+58
-20
lines changed

sqlglot/dialects/duckdb.py

Lines changed: 56 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,15 @@
4848
from sqlglot.tokens import TokenType
4949
from sqlglot.parser import binary_range_parser
5050

51+
# Bitwise operations that return binary when given binary input
52+
BITWISE_BINARY_OPS = (
53+
exp.BitwiseLeftShift,
54+
exp.BitwiseRightShift,
55+
exp.BitwiseAnd,
56+
exp.BitwiseOr,
57+
exp.BitwiseXor,
58+
)
59+
5160
# Regex to detect time zones in timestamps of the form [+|-]TT[:tt]
5261
# The pattern matches timezone offsets that appear after the time portion
5362
TIMEZONE_PATTERN = re.compile(r":\d{2}.*?[+\-]\d{2}(?::\d{2})?")
@@ -527,6 +536,8 @@ def _cast_to_bit(arg: exp.Expression) -> exp.Expression:
527536

528537
if isinstance(arg, exp.HexString):
529538
arg = exp.Unhex(this=exp.Literal.string(arg.this))
539+
elif isinstance(arg, exp.Cast) and isinstance(arg.this, exp.HexString):
540+
arg = exp.Unhex(this=exp.Literal.string(arg.this.this))
530541

531542
return exp.cast(arg, exp.DataType.Type.BIT)
532543

@@ -691,30 +702,60 @@ def _boolxor_agg_sql(self: DuckDB.Generator, expression: exp.BoolxorAgg) -> str:
691702
)
692703

693704

694-
def _prepare_bitshift_for_duckdb(expression: exp.Expression) -> exp.Expression:
705+
def _bitshift_sql(
706+
self: DuckDB.Generator, expression: exp.BitwiseLeftShift | exp.BitwiseRightShift
707+
) -> str:
695708
"""
696-
Transform bitwise shift expressions for DuckDB by injecting INT128 casts.
709+
Transform bitshift expressions for DuckDB by injecting BIT/INT128 casts.
697710
698711
DuckDB's bitwise shift operators don't work with BLOB/BINARY types, so we cast
699-
them to INT128 for integer arithmetic.
712+
them to BIT for the operation, then cast the result back to the original type.
700713
701714
Note: Assumes type annotation has been applied with the source dialect.
702715
"""
703-
# Unwrap BINARY/VARBINARY casts that DuckDB can't handle
704-
if isinstance(expression.this, exp.Cast) and _is_binary(expression.this):
705-
expression.this.replace(expression.this.this)
716+
operator = "<<" if isinstance(expression, exp.BitwiseLeftShift) else ">>"
717+
original_type = None
718+
this = expression.this
706719

707-
# Check if the input is a BLOB/BINARY type (using is_type) or requires INT128
708-
# If so, cast to INT128 for integer arithmetic
709-
if _is_binary(expression.this) or expression.args.get("requires_int128"):
710-
expression.this.replace(exp.cast(expression.this, exp.DataType.Type.INT128))
720+
# Check if input is binary:
721+
# 1. Direct binary type annotation on expression.this
722+
# 2. Chained bitshift where inner operation's input is binary
723+
# 3. CAST to binary type (e.g., X'FF'::BINARY)
724+
is_binary_input = (
725+
_is_binary(this)
726+
or (isinstance(this.this, exp.Expression) and _is_binary(this.this))
727+
or (isinstance(this, exp.Cast) and _is_binary(this))
728+
)
729+
730+
# Deal with binary separately, remember the original type, cast back later, etc.
731+
if is_binary_input:
732+
original_type = this.to if isinstance(this, exp.Cast) else exp.DataType.build("BLOB")
733+
734+
# for chained binary operators
735+
if isinstance(this, exp.Binary):
736+
expression.set("this", exp.cast(this, exp.DataType.Type.BIT))
737+
else:
738+
expression.set("this", _cast_to_bit(this))
739+
740+
# Remove the flag for binary otherwise the final cast will get wrapped in an extra INT128 cast
741+
expression.args.pop("requires_int128")
742+
743+
# cast to INT128 if required (e.g. coming from Snowflake)
744+
elif expression.args.get("requires_int128"):
745+
this.replace(exp.cast(this, exp.DataType.Type.INT128))
746+
747+
result_sql = self.binary(expression, operator)
711748

712749
# Wrap in parentheses if parent is a bitwise operator to "fix" DuckDB precedence issue
713750
# DuckDB parses: a << b | c << d as (a << b | c) << d
714-
if isinstance(expression.parent, (exp.BitwiseAnd, exp.BitwiseOr, exp.BitwiseXor)):
715-
return exp.paren(expression, copy=False)
751+
if isinstance(expression.parent, exp.Binary):
752+
result_sql = self.sql(exp.Paren(this=result_sql))
753+
754+
# Cast the result back to the original type
755+
if original_type:
756+
result_sql = self.sql(exp.Cast(this=result_sql, to=original_type))
716757

717-
return expression
758+
return result_sql
718759

719760

720761
def _scale_rounding_sql(
@@ -1256,10 +1297,10 @@ class Generator(generator.Generator):
12561297
),
12571298
exp.BitwiseAnd: lambda self, e: self._bitwise_op(e, "&"),
12581299
exp.BitwiseAndAgg: _bitwise_agg_sql,
1259-
exp.BitwiseLeftShift: transforms.preprocess([_prepare_bitshift_for_duckdb]),
1300+
exp.BitwiseLeftShift: _bitshift_sql,
12601301
exp.BitwiseOr: lambda self, e: self._bitwise_op(e, "|"),
12611302
exp.BitwiseOrAgg: _bitwise_agg_sql,
1262-
exp.BitwiseRightShift: transforms.preprocess([_prepare_bitshift_for_duckdb]),
1303+
exp.BitwiseRightShift: _bitshift_sql,
12631304
exp.BitwiseXorAgg: _bitwise_agg_sql,
12641305
exp.CommentColumnConstraint: no_comment_column_constraint_sql,
12651306
exp.CosineDistance: rename_func("LIST_COSINE_DISTANCE"),

tests/dialects/test_snowflake.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1711,7 +1711,7 @@ def test_snowflake(self):
17111711
)
17121712
self.validate_all(
17131713
"SELECT BITSHIFTLEFT(CAST(255 AS BINARY), 4)",
1714-
write={"duckdb": "SELECT CAST(255 AS INT128) << 4"},
1714+
write={"duckdb": "SELECT CAST(CAST(CAST(255 AS BLOB) AS BIT) << 4 AS BLOB)"},
17151715
)
17161716
self.validate_all(
17171717
"SELECT BITSHIFTLEFT(X'FF', 4)",
@@ -1722,10 +1722,7 @@ def test_snowflake(self):
17221722
"SELECT BITSHIFTRIGHT(1024, 4)",
17231723
write={"duckdb": "SELECT CAST(1024 AS INT128) >> 4"},
17241724
)
1725-
self.validate_all(
1726-
"SELECT BITSHIFTRIGHT(CAST(255 AS BINARY), 4)",
1727-
write={"duckdb": "SELECT CAST(255 AS INT128) >> 4"},
1728-
)
1725+
17291726
self.validate_all(
17301727
"SELECT BITSHIFTRIGHT(X'FF', 4)",
17311728
write={"duckdb": "SELECT CAST(255 AS INT128) >> 4"},

0 commit comments

Comments
 (0)