Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 46 additions & 10 deletions sqlglot/dialects/duckdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -769,6 +769,48 @@ def _boolxor_agg_sql(self: DuckDB.Generator, expression: exp.BoolxorAgg) -> str:
)


def _bitshift_sql(
    self: DuckDB.Generator, expression: exp.BitwiseLeftShift | exp.BitwiseRightShift
) -> str:
    """
    Render a bitwise left/right shift for DuckDB, adding BIT or INT128 casts as needed.

    DuckDB's shift operators do not work on BLOB/BINARY operands, so a binary left
    operand is cast to BIT for the shift and the shifted result is cast back to the
    operand's original type (the target of an explicit cast, otherwise BLOB).
    A non-binary left operand is cast to INT128 when the expression carries the
    `requires_int128` flag.

    If the left operand has no type yet, it is annotated here on demand using the
    generator's dialect.

    Args:
        expression: The shift expression to generate. Its left operand may be
            replaced in place by a cast.

    Returns:
        The DuckDB SQL for the shift, parenthesized when nested under another
        binary expression.
    """
    if isinstance(expression, exp.BitwiseLeftShift):
        shift_op = "<<"
    else:
        shift_op = ">>"

    operand = expression.this

    # The left operand may arrive without a type; annotate it on demand
    if not operand.type:
        from sqlglot.optimizer.annotate_types import annotate_types

        operand = annotate_types(operand, dialect=self.dialect)

    # Type to cast the shifted value back to; only set for binary operands
    restore_type = None

    if _is_binary(operand):
        # An explicit cast tells us the exact type to restore, otherwise fall back to BLOB
        if isinstance(operand, exp.Cast):
            restore_type = operand.to
        else:
            restore_type = exp.DataType.build("BLOB")

        expression.set("this", exp.cast(operand, exp.DataType.Type.BIT))
    elif expression.args.get("requires_int128"):
        operand.replace(exp.cast(operand, exp.DataType.Type.INT128))

    shifted = self.binary(expression, shift_op)

    # DuckDB parses `a << b | c << d` as `(a << b | c) << d`, so a shift nested
    # under another binary expression is parenthesized explicitly
    if isinstance(expression.parent, exp.Binary):
        shifted = self.sql(exp.Paren(this=shifted))

    if not restore_type:
        return shifted

    # Binary operand: convert the BIT result back to the remembered type
    return self.sql(exp.Cast(this=shifted, to=restore_type))


def _scale_rounding_sql(
self: DuckDB.Generator,
expression: exp.Expression,
Expand Down Expand Up @@ -1358,8 +1400,10 @@ class Generator(generator.Generator):
),
exp.BitwiseAnd: lambda self, e: self._bitwise_op(e, "&"),
exp.BitwiseAndAgg: _bitwise_agg_sql,
exp.BitwiseLeftShift: _bitshift_sql,
exp.BitwiseOr: lambda self, e: self._bitwise_op(e, "|"),
exp.BitwiseOrAgg: _bitwise_agg_sql,
exp.BitwiseRightShift: _bitshift_sql,
exp.BitwiseXorAgg: _bitwise_agg_sql,
exp.CommentColumnConstraint: no_comment_column_constraint_sql,
exp.Corr: lambda self, e: self._corr_sql(e),
Expand Down Expand Up @@ -2451,16 +2495,8 @@ def format_sql(self, expression: exp.Format) -> str:
def hexstring_sql(
self, expression: exp.HexString, binary_function_repr: t.Optional[str] = None
) -> str:
from_hex = super().hexstring_sql(expression, binary_function_repr="FROM_HEX")

if expression.args.get("is_integer"):
return from_hex

# `from_hex` has transpiled x'ABCD' (BINARY) to DuckDB's '\xAB\xCD' (BINARY)
# `to_hex` & CASTing transforms it to "ABCD" (BINARY) to match representation
to_hex = exp.cast(self.func("TO_HEX", from_hex), exp.DataType.Type.BLOB)

return self.sql(to_hex)
# UNHEX('FF') correctly produces blob \xFF in DuckDB
return super().hexstring_sql(expression, binary_function_repr="UNHEX")

def datetrunc_sql(self, expression: exp.DateTrunc) -> str:
unit = unit_to_str(expression)
Expand Down
8 changes: 7 additions & 1 deletion sqlglot/dialects/snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,13 @@ def _builder(args: t.List) -> B | exp.Anonymous:
)
return exp.Anonymous(this=name, expressions=args)

return binary_from_function(expr_type)(args)
result = binary_from_function(expr_type)(args)

# Snowflake specifies INT128 for bitwise shifts
if expr_type in (exp.BitwiseLeftShift, exp.BitwiseRightShift):
result.set("requires_int128", True)

return result

return _builder

Expand Down
4 changes: 2 additions & 2 deletions sqlglot/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -5143,15 +5143,15 @@ class BitwiseAnd(Binary):


class BitwiseLeftShift(Binary):
pass
arg_types = {"this": True, "expression": True, "requires_int128": False}


class BitwiseOr(Binary):
arg_types = {"this": True, "expression": True, "padside": False}


class BitwiseRightShift(Binary):
pass
arg_types = {"this": True, "expression": True, "requires_int128": False}


class BitwiseXor(Binary):
Expand Down
6 changes: 6 additions & 0 deletions sqlglot/typing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,6 +294,12 @@
exp.Dot: {"annotator": lambda self, e: self._annotate_dot(e)},
exp.Explode: {"annotator": lambda self, e: self._annotate_explode(e)},
exp.Extract: {"annotator": lambda self, e: self._annotate_extract(e)},
exp.HexString: {
"annotator": lambda self, e: self._set_type(
e,
exp.DataType.Type.BIGINT if e.args.get("is_integer") else exp.DataType.Type.BINARY,
)
},
exp.GenerateSeries: {
"annotator": lambda self, e: self._annotate_by_args(e, "start", "end", "step", array=True)
},
Expand Down
4 changes: 2 additions & 2 deletions tests/dialects/test_mysql.py
Original file line number Diff line number Diff line change
Expand Up @@ -483,7 +483,7 @@ def test_hexadecimal_literal(self):
"clickhouse": UnsupportedError,
"databricks": "SELECT X'CC'",
"drill": "SELECT 204",
"duckdb": "SELECT CAST(HEX(FROM_HEX('CC')) AS VARBINARY)",
"duckdb": "SELECT UNHEX('CC')",
"hive": "SELECT 204",
"mysql": "SELECT x'CC'",
"oracle": "SELECT 204",
Expand All @@ -504,7 +504,7 @@ def test_hexadecimal_literal(self):
"clickhouse": UnsupportedError,
"databricks": "SELECT X'0000CC'",
"drill": "SELECT 204",
"duckdb": "SELECT CAST(HEX(FROM_HEX('0000CC')) AS VARBINARY)",
"duckdb": "SELECT UNHEX('0000CC')",
"hive": "SELECT 204",
"mysql": "SELECT x'0000CC'",
"oracle": "SELECT 204",
Expand Down
61 changes: 56 additions & 5 deletions tests/dialects/test_snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -1758,10 +1758,61 @@ def test_snowflake(self):
self.validate_identity("SELECT BIT_XOR(a, b)", "SELECT BITXOR(a, b)")
self.validate_identity("SELECT BIT_XOR(a, b, 'LEFT')", "SELECT BITXOR(a, b, 'LEFT')")

self.validate_identity("SELECT BITSHIFTLEFT(a, 1)")
self.validate_identity("SELECT BIT_SHIFTLEFT(a, 1)", "SELECT BITSHIFTLEFT(a, 1)")
self.validate_identity("SELECT BITSHIFTRIGHT(a, 1)")
self.validate_identity("SELECT BIT_SHIFTRIGHT(a, 1)", "SELECT BITSHIFTRIGHT(a, 1)")
# duckdb has an order of operations precedence issue with bitshift and bitwise operators
self.validate_all(
"SELECT BITOR(BITSHIFTLEFT(5, 16), BITSHIFTLEFT(3, 8))",
write={"duckdb": "SELECT (CAST(5 AS INT128) << 16) | (CAST(3 AS INT128) << 8)"},
)
self.validate_all(
"SELECT BITAND(BITSHIFTLEFT(255, 4), BITSHIFTLEFT(15, 2))",
write={
"snowflake": "SELECT BITAND(BITSHIFTLEFT(255, 4), BITSHIFTLEFT(15, 2))",
"duckdb": "SELECT (CAST(255 AS INT128) << 4) & (CAST(15 AS INT128) << 2)",
},
)
self.validate_all(
"SELECT BITSHIFTLEFT(255, 4)",
write={
"snowflake": "SELECT BITSHIFTLEFT(255, 4)",
"duckdb": "SELECT CAST(255 AS INT128) << 4",
},
)
self.validate_all(
"SELECT BITSHIFTLEFT(X'FF', 4)",
write={
"snowflake": "SELECT BITSHIFTLEFT(x'FF', 4)",
"duckdb": "SELECT CAST(CAST(UNHEX('FF') AS BIT) << 4 AS BLOB)",
},
)
self.validate_all(
"SELECT BITSHIFTRIGHT(255, 4)",
write={
"snowflake": "SELECT BITSHIFTRIGHT(255, 4)",
"duckdb": "SELECT CAST(255 AS INT128) >> 4",
},
)
self.validate_all(
"SELECT BITSHIFTRIGHT(X'FF', 4)",
write={
"snowflake": "SELECT BITSHIFTRIGHT(x'FF', 4)",
"duckdb": "SELECT CAST(CAST(UNHEX('FF') AS BIT) >> 4 AS BLOB)",
},
)
self.validate_all(
"SELECT BITSHIFTLEFT(X'002A'::BINARY, 1)",
write={
"snowflake": "SELECT BITSHIFTLEFT(CAST(x'002A' AS BINARY), 1)",
"duckdb": "SELECT CAST(CAST(CAST(UNHEX('002A') AS BLOB) AS BIT) << 1 AS BLOB)",
},
)
self.validate_all(
"SELECT BITSHIFTRIGHT(X'002A'::BINARY, 1)",
write={
"snowflake": "SELECT BITSHIFTRIGHT(CAST(x'002A' AS BINARY), 1)",
"duckdb": "SELECT CAST(CAST(CAST(UNHEX('002A') AS BLOB) AS BIT) >> 1 AS BLOB)",
},
)

self.validate_all(
"OCTET_LENGTH('A')",
read={
Expand Down Expand Up @@ -2179,7 +2230,7 @@ def test_snowflake(self):
"SELECT x'ABCD'",
write={
"snowflake": "SELECT x'ABCD'",
"duckdb": "SELECT CAST(HEX(FROM_HEX('ABCD')) AS VARBINARY)",
"duckdb": "SELECT UNHEX('ABCD')",
},
)

Expand Down
24 changes: 24 additions & 0 deletions tests/fixtures/optimizer/annotate_types.sql
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,30 @@ FLOAT64;
CASE WHEN x < y THEN CAST(3.5 AS BIGNUMERIC) WHEN x > y THEN 3/10 ELSE 2 END;
FLOAT64;

# dialect: snowflake
BITSHIFTLEFT(255, 4);
INT;

# dialect: snowflake
BITSHIFTRIGHT(1024, 2);
INT;

# dialect: snowflake
BITSHIFTLEFT(X'FF', 4);
BINARY;

# dialect: snowflake
BITSHIFTRIGHT(X'FF', 4);
BINARY;

# dialect: snowflake
BITOR(BITSHIFTLEFT(5, 16), BITSHIFTLEFT(3, 8));
INT;

# dialect: snowflake
BITAND(BITSHIFTLEFT(255, 4), BITSHIFTLEFT(15, 2));
INT;

# dialect: bigquery
CAST(1 AS BIGNUMERIC) + 1.5;
BIGNUMERIC;
Expand Down