|
48 | 48 | from sqlglot.tokens import TokenType |
49 | 49 | from sqlglot.parser import binary_range_parser |
50 | 50 |
|
| 51 | +# Bitwise operations that return binary when given binary input |
| 52 | +BITWISE_BINARY_OPS = ( |
| 53 | + exp.BitwiseLeftShift, |
| 54 | + exp.BitwiseRightShift, |
| 55 | + exp.BitwiseAnd, |
| 56 | + exp.BitwiseOr, |
| 57 | + exp.BitwiseXor, |
| 58 | +) |
| 59 | + |
51 | 60 | # Regex to detect time zones in timestamps of the form [+|-]TT[:tt] |
52 | 61 | # The pattern matches timezone offsets that appear after the time portion |
53 | 62 | TIMEZONE_PATTERN = re.compile(r":\d{2}.*?[+\-]\d{2}(?::\d{2})?") |
@@ -527,6 +536,8 @@ def _cast_to_bit(arg: exp.Expression) -> exp.Expression: |
527 | 536 |
|
528 | 537 | if isinstance(arg, exp.HexString): |
529 | 538 | arg = exp.Unhex(this=exp.Literal.string(arg.this)) |
| 539 | + elif isinstance(arg, exp.Cast) and isinstance(arg.this, exp.HexString): |
| 540 | + arg = exp.Unhex(this=exp.Literal.string(arg.this.this)) |
530 | 541 |
|
531 | 542 | return exp.cast(arg, exp.DataType.Type.BIT) |
532 | 543 |
|
@@ -691,30 +702,60 @@ def _boolxor_agg_sql(self: DuckDB.Generator, expression: exp.BoolxorAgg) -> str: |
691 | 702 | ) |
692 | 703 |
|
693 | 704 |
|
694 | | -def _prepare_bitshift_for_duckdb(expression: exp.Expression) -> exp.Expression: |
| 705 | +def _bitshift_sql( |
| 706 | + self: DuckDB.Generator, expression: exp.BitwiseLeftShift | exp.BitwiseRightShift |
| 707 | +) -> str: |
695 | 708 | """ |
696 | | - Transform bitwise shift expressions for DuckDB by injecting INT128 casts. |
| 709 | + Transform bitshift expressions for DuckDB by injecting BIT/INT128 casts. |
697 | 710 |
|
698 | 711 | DuckDB's bitwise shift operators don't work with BLOB/BINARY types, so we cast |
699 | | - them to INT128 for integer arithmetic. |
| 712 | + them to BIT for the operation, then cast the result back to the original type. |
700 | 713 |
|
701 | 714 | Note: Assumes type annotation has been applied with the source dialect. |
702 | 715 | """ |
703 | | - # Unwrap BINARY/VARBINARY casts that DuckDB can't handle |
704 | | - if isinstance(expression.this, exp.Cast) and _is_binary(expression.this): |
705 | | - expression.this.replace(expression.this.this) |
| 716 | + operator = "<<" if isinstance(expression, exp.BitwiseLeftShift) else ">>" |
| 717 | + original_type = None |
| 718 | + this = expression.this |
706 | 719 |
|
707 | | - # Check if the input is a BLOB/BINARY type (using is_type) or requires INT128 |
708 | | - # If so, cast to INT128 for integer arithmetic |
709 | | - if _is_binary(expression.this) or expression.args.get("requires_int128"): |
710 | | - expression.this.replace(exp.cast(expression.this, exp.DataType.Type.INT128)) |
| 720 | + # Check if input is binary: |
| 721 | + # 1. Direct binary type annotation on expression.this |
| 722 | + # 2. Chained bitshift where inner operation's input is binary |
| 723 | + # 3. CAST to binary type (e.g., X'FF'::BINARY) |
| 724 | + is_binary_input = ( |
| 725 | + _is_binary(this) |
| 726 | + or (isinstance(this.this, exp.Expression) and _is_binary(this.this)) |
| 727 | + or (isinstance(this, exp.Cast) and _is_binary(this)) |
| 728 | + ) |
| 729 | + |
| 730 | + # Deal with binary separately, remember the original type, cast back later, etc. |
| 731 | + if is_binary_input: |
| 732 | + original_type = this.to if isinstance(this, exp.Cast) else exp.DataType.build("BLOB") |
| 733 | + |
| 734 | + # for chained binary operators |
| 735 | + if isinstance(this, exp.Binary): |
| 736 | + expression.set("this", exp.cast(this, exp.DataType.Type.BIT)) |
| 737 | + else: |
| 738 | + expression.set("this", _cast_to_bit(this)) |
| 739 | + |
| 740 | + # Remove the flag for binary otherwise the final cast will get wrapped in an extra INT128 cast |
| 741 | + expression.args.pop("requires_int128") |
| 742 | + |
| 743 | + # cast to INT128 if required (e.g. coming from Snowflake) |
| 744 | + elif expression.args.get("requires_int128"): |
| 745 | + this.replace(exp.cast(this, exp.DataType.Type.INT128)) |
| 746 | + |
| 747 | + result_sql = self.binary(expression, operator) |
711 | 748 |
|
712 | 749 | # Wrap in parentheses if parent is a bitwise operator to "fix" DuckDB precedence issue |
713 | 750 | # DuckDB parses: a << b | c << d as (a << b | c) << d |
714 | | - if isinstance(expression.parent, (exp.BitwiseAnd, exp.BitwiseOr, exp.BitwiseXor)): |
715 | | - return exp.paren(expression, copy=False) |
| 751 | + if isinstance(expression.parent, exp.Binary): |
| 752 | + result_sql = self.sql(exp.Paren(this=result_sql)) |
| 753 | + |
| 754 | + # Cast the result back to the original type |
| 755 | + if original_type: |
| 756 | + result_sql = self.sql(exp.Cast(this=result_sql, to=original_type)) |
716 | 757 |
|
717 | | - return expression |
| 758 | + return result_sql |
718 | 759 |
|
719 | 760 |
|
720 | 761 | def _scale_rounding_sql( |
@@ -1256,10 +1297,10 @@ class Generator(generator.Generator): |
1256 | 1297 | ), |
1257 | 1298 | exp.BitwiseAnd: lambda self, e: self._bitwise_op(e, "&"), |
1258 | 1299 | exp.BitwiseAndAgg: _bitwise_agg_sql, |
1259 | | - exp.BitwiseLeftShift: transforms.preprocess([_prepare_bitshift_for_duckdb]), |
| 1300 | + exp.BitwiseLeftShift: _bitshift_sql, |
1260 | 1301 | exp.BitwiseOr: lambda self, e: self._bitwise_op(e, "|"), |
1261 | 1302 | exp.BitwiseOrAgg: _bitwise_agg_sql, |
1262 | | - exp.BitwiseRightShift: transforms.preprocess([_prepare_bitshift_for_duckdb]), |
| 1303 | + exp.BitwiseRightShift: _bitshift_sql, |
1263 | 1304 | exp.BitwiseXorAgg: _bitwise_agg_sql, |
1264 | 1305 | exp.CommentColumnConstraint: no_comment_column_constraint_sql, |
1265 | 1306 | exp.CosineDistance: rename_func("LIST_COSINE_DISTANCE"), |
|
0 commit comments