diff --git a/sqlglot/dialects/duckdb.py b/sqlglot/dialects/duckdb.py index cab31571e4..70fecfb17a 100644 --- a/sqlglot/dialects/duckdb.py +++ b/sqlglot/dialects/duckdb.py @@ -563,7 +563,7 @@ def _cast_to_boolean(arg: t.Optional[exp.Expression]) -> t.Optional[exp.Expressi def _is_binary(arg: exp.Expression) -> bool: - return arg.is_type( + return isinstance(arg, exp.HexString) or arg.is_type( exp.DataType.Type.BINARY, exp.DataType.Type.VARBINARY, exp.DataType.Type.BLOB, @@ -749,6 +749,48 @@ def _boolxor_agg_sql(self: DuckDB.Generator, expression: exp.BoolxorAgg) -> str: ) +def _bitshift_sql( + self: DuckDB.Generator, expression: exp.BitwiseLeftShift | exp.BitwiseRightShift +) -> str: + """ + Transform bitshift expressions for DuckDB by injecting BIT/INT128 casts. + + DuckDB's bitwise shift operators don't work with BLOB/BINARY types, so we cast + them to BIT for the operation, then cast the result back to the original type. + + Note: Assumes type annotation has been applied with the source dialect. + """ + operator = "<<" if isinstance(expression, exp.BitwiseLeftShift) else ">>" + original_type = None + this = expression.this + + # Ensure type annotation is available for nested expressions + if not this.type: + from sqlglot.optimizer.annotate_types import annotate_types + + this = annotate_types(this, dialect=self.dialect) + + # Deal with binary separately, remember the original type, cast back later + if _is_binary(this): + original_type = this.to if isinstance(this, exp.Cast) else exp.DataType.build("BLOB") + expression.set("this", exp.cast(this, exp.DataType.Type.BIT)) + elif expression.args.get("requires_int128"): + this.replace(exp.cast(this, exp.DataType.Type.INT128)) + + result_sql = self.binary(expression, operator) + + # Wrap in parentheses if parent is a bitwise operator to "fix" DuckDB precedence issue + # DuckDB parses: a << b | c << d as (a << b | c) << d + if isinstance(expression.parent, exp.Binary): + result_sql = self.sql(exp.Paren(this=result_sql)) + + # Cast the result back to the original type + if original_type: + result_sql = self.sql(exp.Cast(this=result_sql, to=original_type)) + + return result_sql + + def _scale_rounding_sql( self: DuckDB.Generator, expression: exp.Expression, @@ -1288,8 +1330,10 @@ class Generator(generator.Generator): ), exp.BitwiseAnd: lambda self, e: self._bitwise_op(e, "&"), exp.BitwiseAndAgg: _bitwise_agg_sql, + exp.BitwiseLeftShift: _bitshift_sql, exp.BitwiseOr: lambda self, e: self._bitwise_op(e, "|"), exp.BitwiseOrAgg: _bitwise_agg_sql, + exp.BitwiseRightShift: _bitshift_sql, exp.BitwiseXorAgg: _bitwise_agg_sql, exp.CommentColumnConstraint: no_comment_column_constraint_sql, exp.CosineDistance: rename_func("LIST_COSINE_DISTANCE"), @@ -2229,16 +2273,8 @@ def format_sql(self, expression: exp.Format) -> str: def hexstring_sql( self, expression: exp.HexString, binary_function_repr: t.Optional[str] = None ) -> str: - from_hex = super().hexstring_sql(expression, binary_function_repr="FROM_HEX") - - if expression.args.get("is_integer"): - return from_hex - - # `from_hex` has transpiled x'ABCD' (BINARY) to DuckDB's '\xAB\xCD' (BINARY) - # `to_hex` & CASTing transforms it to "ABCD" (BINARY) to match representation - to_hex = exp.cast(self.func("TO_HEX", from_hex), exp.DataType.Type.BLOB) - - return self.sql(to_hex) + # UNHEX('FF') correctly produces blob \xFF in DuckDB + return super().hexstring_sql(expression, binary_function_repr="UNHEX") def timestamptrunc_sql(self, expression: exp.TimestampTrunc) -> str: unit = unit_to_str(expression) diff --git a/sqlglot/dialects/snowflake.py b/sqlglot/dialects/snowflake.py index d3e4ada469..1db062ea5c 100644 --- a/sqlglot/dialects/snowflake.py +++ b/sqlglot/dialects/snowflake.py @@ -163,7 +163,13 @@ def _builder(args: t.List) -> B | exp.Anonymous: ) return exp.Anonymous(this=name, expressions=args) - return binary_from_function(expr_type)(args) + result = binary_from_function(expr_type)(args) + + # Snowflake specifies INT128 for bitwise shifts + if expr_type in (exp.BitwiseLeftShift, exp.BitwiseRightShift): + result.set("requires_int128", True) + + return result return _builder diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py index 24f798ba7c..6477a225b6 100644 --- a/sqlglot/expressions.py +++ b/sqlglot/expressions.py @@ -5143,7 +5143,7 @@ class BitwiseAnd(Binary): class BitwiseLeftShift(Binary): - pass + arg_types = {"this": True, "expression": True, "requires_int128": False} class BitwiseOr(Binary): @@ -5151,7 +5151,7 @@ class BitwiseOr(Binary): class BitwiseRightShift(Binary): - pass + arg_types = {"this": True, "expression": True, "requires_int128": False} class BitwiseXor(Binary): diff --git a/sqlglot/typing/__init__.py b/sqlglot/typing/__init__.py index 86af31b949..1dc708de1b 100644 --- a/sqlglot/typing/__init__.py +++ b/sqlglot/typing/__init__.py @@ -42,6 +42,8 @@ for expr_type in { exp.FromBase32, exp.FromBase64, + exp.HexString, + exp.Unhex, } }, **{ diff --git a/tests/dialects/test_mysql.py b/tests/dialects/test_mysql.py index 03f3f9373f..2a6524030e 100644 --- a/tests/dialects/test_mysql.py +++ b/tests/dialects/test_mysql.py @@ -482,7 +482,7 @@ def test_hexadecimal_literal(self): "clickhouse": UnsupportedError, "databricks": "SELECT X'CC'", "drill": "SELECT 204", - "duckdb": "SELECT CAST(HEX(FROM_HEX('CC')) AS VARBINARY)", + "duckdb": "SELECT UNHEX('CC')", "hive": "SELECT 204", "mysql": "SELECT x'CC'", "oracle": "SELECT 204", @@ -503,7 +503,7 @@ def test_hexadecimal_literal(self): "clickhouse": UnsupportedError, "databricks": "SELECT X'0000CC'", "drill": "SELECT 204", - "duckdb": "SELECT CAST(HEX(FROM_HEX('0000CC')) AS VARBINARY)", + "duckdb": "SELECT UNHEX('0000CC')", "hive": "SELECT 204", "mysql": "SELECT x'0000CC'", "oracle": "SELECT 204", diff --git a/tests/dialects/test_snowflake.py b/tests/dialects/test_snowflake.py index 0a05bdf4d7..7141052c37 100644 --- a/tests/dialects/test_snowflake.py +++ b/tests/dialects/test_snowflake.py @@ -1692,10 +1692,68 @@ def test_snowflake(self): self.validate_identity("SELECT BIT_XOR(a, b)", "SELECT BITXOR(a, b)") self.validate_identity("SELECT BIT_XOR(a, b, 'LEFT')", "SELECT BITXOR(a, b, 'LEFT')") - self.validate_identity("SELECT BITSHIFTLEFT(a, 1)") - self.validate_identity("SELECT BIT_SHIFTLEFT(a, 1)", "SELECT BITSHIFTLEFT(a, 1)") - self.validate_identity("SELECT BITSHIFTRIGHT(a, 1)") - self.validate_identity("SELECT BIT_SHIFTRIGHT(a, 1)", "SELECT BITSHIFTRIGHT(a, 1)") + # duckdb has an order of operations precedence issue with bitshift and bitwise operators + self.validate_all( + "SELECT BITOR(BITSHIFTLEFT(5, 16), BITSHIFTLEFT(3, 8))", + write={"duckdb": "SELECT (CAST(5 AS INT128) << 16) | (CAST(3 AS INT128) << 8)"}, + ) + self.validate_all( + "SELECT BITAND(BITSHIFTLEFT(255, 4), BITSHIFTLEFT(15, 2))", + write={ + "snowflake": "SELECT BITAND(BITSHIFTLEFT(255, 4), BITSHIFTLEFT(15, 2))", + "duckdb": "SELECT (CAST(255 AS INT128) << 4) & (CAST(15 AS INT128) << 2)", + }, + ) + self.validate_all( + "SELECT BITSHIFTLEFT(255, 4)", + write={ + "snowflake": "SELECT BITSHIFTLEFT(255, 4)", + "duckdb": "SELECT CAST(255 AS INT128) << 4", + }, + ) + self.validate_all( + "SELECT BITSHIFTLEFT(CAST(255 AS BINARY), 4)", + write={ + "snowflake": "SELECT BITSHIFTLEFT(CAST(255 AS BINARY), 4)", + "duckdb": "SELECT CAST(CAST(CAST(255 AS BLOB) AS BIT) << 4 AS BLOB)", + }, + ) + self.validate_all( + "SELECT BITSHIFTLEFT(X'FF', 4)", + write={ + "snowflake": "SELECT BITSHIFTLEFT(x'FF', 4)", + "duckdb": "SELECT CAST(CAST(UNHEX('FF') AS BIT) << 4 AS BLOB)", + }, + ) + self.validate_all( + "SELECT BITSHIFTRIGHT(255, 4)", + write={ + "snowflake": "SELECT BITSHIFTRIGHT(255, 4)", + "duckdb": "SELECT CAST(255 AS INT128) >> 4", + }, + ) + self.validate_all( + "SELECT BITSHIFTRIGHT(X'FF', 4)", + write={ + "snowflake": "SELECT BITSHIFTRIGHT(x'FF', 4)", + "duckdb": "SELECT CAST(CAST(UNHEX('FF') AS BIT) >> 4 AS BLOB)", + }, + ) + self.validate_all( + "SELECT BITSHIFTLEFT(X'002A'::BINARY, 1)", + write={ + "snowflake": "SELECT BITSHIFTLEFT(CAST(x'002A' AS BINARY), 1)", + "duckdb": "SELECT CAST(CAST(CAST(UNHEX('002A') AS BLOB) AS BIT) << 1 AS BLOB)", + }, + ) + self.validate_all( + "SELECT BITSHIFTRIGHT(X'002A'::BINARY, 1)", + write={ + "snowflake": "SELECT BITSHIFTRIGHT(CAST(x'002A' AS BINARY), 1)", + "duckdb": "SELECT CAST(CAST(CAST(UNHEX('002A') AS BLOB) AS BIT) >> 1 AS BLOB)", + }, + ) + self.validate_all( "OCTET_LENGTH('A')", read={ @@ -2113,7 +2171,7 @@ def test_snowflake(self): "SELECT x'ABCD'", write={ "snowflake": "SELECT x'ABCD'", - "duckdb": "SELECT CAST(HEX(FROM_HEX('ABCD')) AS VARBINARY)", + "duckdb": "SELECT UNHEX('ABCD')", }, ) diff --git a/tests/fixtures/optimizer/annotate_types.sql b/tests/fixtures/optimizer/annotate_types.sql index c047337ad0..700312f1ad 100644 --- a/tests/fixtures/optimizer/annotate_types.sql +++ b/tests/fixtures/optimizer/annotate_types.sql @@ -186,3 +186,27 @@ FLOAT64; # dialect: bigquery CASE WHEN x < y THEN CAST(3.5 AS BIGNUMERIC) WHEN x > y THEN 3/10 ELSE 2 END; FLOAT64; + +# dialect: snowflake +BITSHIFTLEFT(255, 4); +INT; + +# dialect: snowflake +BITSHIFTRIGHT(1024, 2); +INT; + +# dialect: snowflake +BITSHIFTLEFT(X'FF', 4); +BINARY; + +# dialect: snowflake +BITSHIFTRIGHT(X'FF', 4); +BINARY; + +# dialect: snowflake +BITOR(BITSHIFTLEFT(5, 16), BITSHIFTLEFT(3, 8)); +INT; + +# dialect: snowflake +BITAND(BITSHIFTLEFT(255, 4), BITSHIFTLEFT(15, 2)); +INT;