38 changes: 38 additions & 0 deletions sqlglot/dialects/duckdb.py
@@ -691,6 +691,42 @@ def _boolxor_agg_sql(self: DuckDB.Generator, expression: exp.BoolxorAgg) -> str:
)


def _bitshift_sql(
    self: DuckDB.Generator, expression: exp.BitwiseLeftShift | exp.BitwiseRightShift
) -> str:
    """
    Transform bitshift expressions for DuckDB by injecting BIT/INT128 casts.

    DuckDB's bitwise shift operators don't work with BLOB/BINARY types, so we cast
    them to BIT for the operation, then cast the result back to the original type.

    Note: Assumes type annotation has been applied with the source dialect.
    """
    operator = "<<" if isinstance(expression, exp.BitwiseLeftShift) else ">>"
    original_type = None
    this = expression.this

    # Handle binary operands separately: remember the original type and cast back later
    if _is_binary(this):
        original_type = this.to if isinstance(this, exp.Cast) else exp.DataType.build("BLOB")
        expression.set("this", _cast_to_bit(this))
    elif expression.args.get("requires_int128"):
        this.replace(exp.cast(this, exp.DataType.Type.INT128))

    result_sql = self.binary(expression, operator)

    # Wrap in parentheses if the parent is a binary operator to work around DuckDB's
    # precedence: DuckDB parses a << b | c << d as (a << b | c) << d
    if isinstance(expression.parent, exp.Binary):
        result_sql = self.sql(exp.Paren(this=result_sql))

    # Cast the result back to the original type
    if original_type:
        result_sql = self.sql(exp.Cast(this=result_sql, to=original_type))

    return result_sql


def _scale_rounding_sql(
    self: DuckDB.Generator,
    expression: exp.Expression,
@@ -1230,8 +1266,10 @@ class Generator(generator.Generator):
),
exp.BitwiseAnd: lambda self, e: self._bitwise_op(e, "&"),
exp.BitwiseAndAgg: _bitwise_agg_sql,
exp.BitwiseLeftShift: _bitshift_sql,
exp.BitwiseOr: lambda self, e: self._bitwise_op(e, "|"),
exp.BitwiseOrAgg: _bitwise_agg_sql,
exp.BitwiseRightShift: _bitshift_sql,
exp.BitwiseXorAgg: _bitwise_agg_sql,
exp.CommentColumnConstraint: no_comment_column_constraint_sql,
exp.CosineDistance: rename_func("LIST_COSINE_DISTANCE"),
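For context, a minimal usage sketch of the new transform end to end (assuming this branch is installed; the expected DuckDB strings mirror the new tests in tests/dialects/test_snowflake.py below):

import sqlglot

# Integer operands: Snowflake's BITSHIFTLEFT becomes a shift on an INT128 cast.
print(sqlglot.transpile("SELECT BITSHIFTLEFT(255, 4)", read="snowflake", write="duckdb")[0])
# SELECT CAST(255 AS INT128) << 4

# BINARY operands: routed through BIT for the shift, then cast back to BLOB.
print(sqlglot.transpile("SELECT BITSHIFTLEFT(CAST(255 AS BINARY), 4)", read="snowflake", write="duckdb")[0])
# SELECT CAST(CAST(CAST(255 AS BLOB) AS BIT) << 4 AS BLOB)

# Shifts nested under another bitwise operator are parenthesized to sidestep DuckDB's precedence quirk.
print(sqlglot.transpile("SELECT BITOR(BITSHIFTLEFT(5, 16), BITSHIFTLEFT(3, 8))", read="snowflake", write="duckdb")[0])
# SELECT (CAST(5 AS INT128) << 16) | (CAST(3 AS INT128) << 8)
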
11 changes: 10 additions & 1 deletion sqlglot/dialects/snowflake.py
@@ -163,7 +163,16 @@ def _builder(args: t.List) -> B | exp.Anonymous:
)
return exp.Anonymous(this=name, expressions=args)

return binary_from_function(expr_type)(args)
result = binary_from_function(expr_type)(args)

# Snowflake specifies INT128 for bitwise shifts
if expr_type in (exp.BitwiseLeftShift, exp.BitwiseRightShift):
    result.set("requires_int128", True)
    # Mark HexStrings as integers for proper rendering in target dialects
    for hexstr in result.find_all(exp.HexString):
        hexstr.set("is_integer", True)

return result

return _builder

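A short sketch of what the updated builder attaches at parse time; the flag and the HexString marking are what the DuckDB generator above consumes (assumes this branch):

import sqlglot
from sqlglot import exp

expr = sqlglot.parse_one("SELECT BITSHIFTLEFT(X'FF', 4)", read="snowflake")
shift = expr.find(exp.BitwiseLeftShift)

# The shift is flagged so integer operands get cast to INT128 in dialects like DuckDB...
assert shift.args.get("requires_int128") is True
# ...and the hex-string operand is marked as an integer so it renders as a numeric literal (255).
assert shift.find(exp.HexString).args.get("is_integer") is True
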
4 changes: 2 additions & 2 deletions sqlglot/expressions.py
@@ -5143,15 +5143,15 @@ class BitwiseAnd(Binary):


class BitwiseLeftShift(Binary):
    pass
    arg_types = {"this": True, "expression": True, "requires_int128": False}


class BitwiseOr(Binary):
    arg_types = {"this": True, "expression": True, "padside": False}


class BitwiseRightShift(Binary):
    pass
    arg_types = {"this": True, "expression": True, "requires_int128": False}


class BitwiseXor(Binary):
59 changes: 55 additions & 4 deletions tests/dialects/test_snowflake.py
@@ -1701,10 +1701,61 @@ def test_snowflake(self):
self.validate_identity("SELECT BIT_XOR(a, b)", "SELECT BITXOR(a, b)")
self.validate_identity("SELECT BIT_XOR(a, b, 'LEFT')", "SELECT BITXOR(a, b, 'LEFT')")

self.validate_identity("SELECT BITSHIFTLEFT(a, 1)")
self.validate_identity("SELECT BIT_SHIFTLEFT(a, 1)", "SELECT BITSHIFTLEFT(a, 1)")
self.validate_identity("SELECT BITSHIFTRIGHT(a, 1)")
self.validate_identity("SELECT BIT_SHIFTRIGHT(a, 1)", "SELECT BITSHIFTRIGHT(a, 1)")
# DuckDB has an operator precedence issue when bitshifts are mixed with bitwise operators
self.validate_all(
    "SELECT BITOR(BITSHIFTLEFT(5, 16), BITSHIFTLEFT(3, 8))",
    write={"duckdb": "SELECT (CAST(5 AS INT128) << 16) | (CAST(3 AS INT128) << 8)"},
)
self.validate_all(
    "SELECT BITAND(BITSHIFTLEFT(255, 4), BITSHIFTLEFT(15, 2))",
    write={
        "snowflake": "SELECT BITAND(BITSHIFTLEFT(255, 4), BITSHIFTLEFT(15, 2))",
        "duckdb": "SELECT (CAST(255 AS INT128) << 4) & (CAST(15 AS INT128) << 2)",
    },
)
self.validate_all(
    "SELECT BITSHIFTLEFT(255, 4)",
    write={
        "snowflake": "SELECT BITSHIFTLEFT(255, 4)",
        "duckdb": "SELECT CAST(255 AS INT128) << 4",
    },
)
self.validate_all(
    "SELECT BITSHIFTLEFT(CAST(255 AS BINARY), 4)",
    write={
        "snowflake": "SELECT BITSHIFTLEFT(CAST(255 AS BINARY), 4)",
        "duckdb": "SELECT CAST(CAST(CAST(255 AS BLOB) AS BIT) << 4 AS BLOB)",
    },
)
self.validate_all(
    "SELECT BITSHIFTLEFT(X'FF', 4)",
    write={
        "snowflake": "SELECT BITSHIFTLEFT(255, 4)",
        "duckdb": "SELECT CAST(255 AS INT128) << 4",
    },
)
self.validate_all(
    "SELECT BITSHIFTRIGHT(255, 4)",
    write={
        "snowflake": "SELECT BITSHIFTRIGHT(255, 4)",
        "duckdb": "SELECT CAST(255 AS INT128) >> 4",
    },
)
self.validate_all(
    "SELECT BITSHIFTRIGHT(CAST(255 AS BINARY), 4)",
    write={
        "snowflake": "SELECT BITSHIFTRIGHT(CAST(255 AS BINARY), 4)",
        "duckdb": "SELECT CAST(CAST(CAST(255 AS BLOB) AS BIT) >> 4 AS BLOB)",
    },
)
self.validate_all(
    "SELECT BITSHIFTRIGHT(X'FF', 4)",
    write={
        "snowflake": "SELECT BITSHIFTRIGHT(255, 4)",
        "duckdb": "SELECT CAST(255 AS INT128) >> 4",
    },
)
Comment on lines +1751 to +1757 (Collaborator):

This generates a BINARY in Snowflake but an INT in DuckDB. Is this intentional, given that we don't annotate the types?

sf> SELECT system$typeof(BITSHIFTRIGHT(X'FF', 4)), BITSHIFTRIGHT(X'FF', 4);
BINARY[LOB] | 0F


duckdb> SELECT CAST(255 AS INT128) >> 4;
┌─────────────────────────────┐
│ (CAST(255 AS HUGEINT) >> 4) │
│           int128            │
├─────────────────────────────┤
│             15              │
└─────────────────────────────┘

Collaborator:

What if we annotate the input w/ snowflake? I think preserving the type requires type inference to run first.
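
A hedged sketch of that suggestion, assuming annotate_types accepts a dialect argument as exercised by the optimizer fixtures: annotate the parsed expression with Snowflake first, so the BINARY result type is known before DuckDB generation (expected type per the new fixtures in tests/fixtures/optimizer/annotate_types.sql):

import sqlglot
from sqlglot import exp
from sqlglot.optimizer.annotate_types import annotate_types

expr = sqlglot.parse_one("BITSHIFTLEFT(X'FF', 4)", read="snowflake")
annotated = annotate_types(expr, dialect="snowflake")

# Snowflake infers BINARY for the hex-string operand, which is what the DuckDB
# generator needs in order to apply its BIT/BLOB round-trip instead of INT128.
assert annotated.type.is_type(exp.DataType.Type.BINARY)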


self.validate_all(
    "OCTET_LENGTH('A')",
    read={
32 changes: 32 additions & 0 deletions tests/fixtures/optimizer/annotate_types.sql
@@ -186,3 +186,35 @@ FLOAT64;
# dialect: bigquery
CASE WHEN x < y THEN CAST(3.5 AS BIGNUMERIC) WHEN x > y THEN 3/10 ELSE 2 END;
FLOAT64;

# dialect: snowflake
BITSHIFTLEFT(255, 4);
INT;

# dialect: snowflake
BITSHIFTRIGHT(1024, 2);
INT;

# dialect: snowflake
BITSHIFTLEFT(CAST(255 AS BINARY), 4);
BINARY;

# dialect: snowflake
BITSHIFTRIGHT(CAST(255 AS BINARY), 4);
BINARY;

# dialect: snowflake
BITSHIFTLEFT(X'FF', 4);
BINARY;

# dialect: snowflake
BITSHIFTRIGHT(X'FF', 4);
BINARY;

# dialect: snowflake
BITOR(BITSHIFTLEFT(5, 16), BITSHIFTLEFT(3, 8));
INT;

# dialect: snowflake
BITAND(BITSHIFTLEFT(255, 4), BITSHIFTLEFT(15, 2));
INT;