diff --git a/sqlglot/dialects/duckdb.py b/sqlglot/dialects/duckdb.py index d1890f81b2..f6eb683607 100644 --- a/sqlglot/dialects/duckdb.py +++ b/sqlglot/dialects/duckdb.py @@ -2451,6 +2451,21 @@ def hexstring_sql( return self.sql(to_hex) + def datetrunc_sql(self, expression: exp.DateTrunc) -> str: + unit = unit_to_str(expression) + date = expression.this + result = self.func("DATE_TRUNC", unit, date) + + if expression.args.get("input_type_preserved"): + if not date.type: + from sqlglot.optimizer.annotate_types import annotate_types + + date = annotate_types(date, dialect=self.dialect) + + if date.type and date.is_type(*exp.DataType.TEMPORAL_TYPES): + return self.sql(exp.Cast(this=result, to=date.type)) + return result + def timestamptrunc_sql(self, expression: exp.TimestampTrunc) -> str: unit = unit_to_str(expression) zone = expression.args.get("zone") @@ -2465,7 +2480,27 @@ def timestamptrunc_sql(self, expression: exp.TimestampTrunc) -> str: result_sql = self.func("DATE_TRUNC", unit, timestamp) return self.sql(exp.AtTimeZone(this=result_sql, zone=zone)) - return self.func("DATE_TRUNC", unit, timestamp) + result = self.func("DATE_TRUNC", unit, timestamp) + if expression.args.get("input_type_preserved"): + if not timestamp.type: + from sqlglot.optimizer.annotate_types import annotate_types + + timestamp = annotate_types(timestamp, dialect=self.dialect) + + if timestamp.type and timestamp.is_type( + exp.DataType.Type.TIME, exp.DataType.Type.TIMETZ + ): + dummy_date = exp.Cast( + this=exp.Literal.string("1970-01-01"), + to=exp.DataType(this=exp.DataType.Type.DATE), + ) + date_time = exp.Add(this=dummy_date, expression=timestamp) + result = self.func("DATE_TRUNC", unit, date_time) + return self.sql(exp.Cast(this=result, to=timestamp.type)) + + if timestamp.type and timestamp.is_type(*exp.DataType.TEMPORAL_TYPES): + return self.sql(exp.Cast(this=result, to=timestamp.type)) + return result def trim_sql(self, expression: exp.Trim) -> str: expression.this.replace(_cast_to_varchar(expression.this)) diff --git a/sqlglot/dialects/snowflake.py b/sqlglot/dialects/snowflake.py index 9432771c8a..0127cf6155 100644 --- a/sqlglot/dialects/snowflake.py +++ b/sqlglot/dialects/snowflake.py @@ -32,7 +32,7 @@ groupconcat_sql, ) from sqlglot.generator import unsupported_args -from sqlglot.helper import find_new_name, flatten, is_int, seq_get +from sqlglot.helper import find_new_name, flatten, is_date_unit, is_int, seq_get from sqlglot.optimizer.scope import build_scope, find_all_in_scope from sqlglot.tokens import TokenType from sqlglot.typing.snowflake import EXPRESSION_METADATA @@ -259,7 +259,13 @@ def _parse(self: Snowflake.Parser) -> exp.Show: def _date_trunc_to_time(args: t.List) -> exp.DateTrunc | exp.TimestampTrunc: trunc = date_trunc_to_time(args) - trunc.set("unit", map_date_part(trunc.args["unit"])) + unit = map_date_part(trunc.args["unit"]) + trunc.set("unit", unit) + is_time_input = trunc.this.is_type(exp.DataType.Type.TIME, exp.DataType.Type.TIMETZ) + if (isinstance(trunc, exp.TimestampTrunc) and is_date_unit(unit) or is_time_input) or ( + isinstance(trunc, exp.DateTrunc) and not is_date_unit(unit) + ): + trunc.set("input_type_preserved", True) return trunc diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py index bcc74a86dd..ab8a4aa0a8 100644 --- a/sqlglot/expressions.py +++ b/sqlglot/expressions.py @@ -6509,7 +6509,7 @@ class DateDiff(Func, TimeUnit): class DateTrunc(Func): - arg_types = {"unit": True, "this": True, "zone": False} + arg_types = {"unit": True, "this": True, "zone": False, "input_type_preserved": False} def __init__(self, **args): # Across most dialects it's safe to unabbreviate the unit (e.g. 'Q' -> 'QUARTER') except Oracle @@ -6669,7 +6669,7 @@ class TimestampDiff(Func, TimeUnit): class TimestampTrunc(Func, TimeUnit): - arg_types = {"this": True, "unit": True, "zone": False} + arg_types = {"this": True, "unit": True, "zone": False, "input_type_preserved": False} class TimeSlice(Func, TimeUnit): diff --git a/tests/dialects/test_snowflake.py b/tests/dialects/test_snowflake.py index f276954fb9..9ae7a4281a 100644 --- a/tests/dialects/test_snowflake.py +++ b/tests/dialects/test_snowflake.py @@ -2668,7 +2668,56 @@ def test_timestamps(self): self.validate_identity("DATEADD(y, 5, x)", "DATEADD(YEAR, 5, x)") self.validate_identity("DATE_PART(yyy, x)", "DATE_PART(YEAR, x)") self.validate_identity("DATE_TRUNC(yr, x)", "DATE_TRUNC('YEAR', x)") + self.validate_all( + "DATE_TRUNC('YEAR', CAST('2024-06-15' AS DATE))", + write={ + "snowflake": "DATE_TRUNC('YEAR', CAST('2024-06-15' AS DATE))", + "duckdb": "DATE_TRUNC('YEAR', CAST('2024-06-15' AS DATE))", + }, + ) + self.validate_all( + "DATE_TRUNC('HOUR', CAST('2026-01-01 00:00:00' AS TIMESTAMP))", + write={ + "snowflake": "DATE_TRUNC('HOUR', CAST('2026-01-01 00:00:00' AS TIMESTAMP))", + "duckdb": "DATE_TRUNC('HOUR', CAST('2026-01-01 00:00:00' AS TIMESTAMP))", + }, + ) + # Snowflake's DATE_TRUNC return type matches type of the expresison + # DuckDB's DATE_TRUNC return type matches type of granularity part. + # In Snowflake --> DuckDB, DATE_TRUNC(date_part, timestamp) should be cast to timestamp to preserve Snowflake behavior. + self.validate_all( + "DATE_TRUNC(YEAR, TIMESTAMP '2026-01-01 00:00:00')", + write={ + "snowflake": "DATE_TRUNC('YEAR', CAST('2026-01-01 00:00:00' AS TIMESTAMP))", + "duckdb": "CAST(DATE_TRUNC('YEAR', CAST('2026-01-01 00:00:00' AS TIMESTAMP)) AS TIMESTAMP)", + }, + ) + self.validate_all( + "DATE_TRUNC(MONTH, CAST('2024-06-15 14:23:45' AS TIMESTAMPTZ))", + write={ + "snowflake": "DATE_TRUNC('MONTH', CAST('2024-06-15 14:23:45' AS TIMESTAMPTZ))", + "duckdb": "CAST(DATE_TRUNC('MONTH', CAST('2024-06-15 14:23:45' AS TIMESTAMPTZ)) AS TIMESTAMPTZ)", + }, + ) + + # In Snowflake --> DuckDB, DATE_TRUNC(time_part, date) should be cast to date to preserve Snowflake behavior. + self.validate_all( + "DATE_TRUNC('HOUR', CAST('2026-01-01' AS DATE))", + write={ + "snowflake": "DATE_TRUNC('HOUR', CAST('2026-01-01' AS DATE))", + "duckdb": "CAST(DATE_TRUNC('HOUR', CAST('2026-01-01' AS DATE)) AS DATE)", + }, + ) + # DuckDB does not support DATE_TRUNC(time_part, time), so we add a dummy date to generate DATE_TRUNC(time_part, date) --> DATE in DuckDB + # Then it is casted to a time (HH:MM:SS) to match Snowflake. + self.validate_all( + "DATE_TRUNC('HOUR', CAST('14:23:45.123456' AS TIME))", + write={ + "snowflake": "DATE_TRUNC('HOUR', CAST('14:23:45.123456' AS TIME))", + "duckdb": "CAST(DATE_TRUNC('HOUR', CAST('1970-01-01' AS DATE) + CAST('14:23:45.123456' AS TIME)) AS TIME)", + }, + ) self.validate_identity("TO_DATE('12345')").assert_is(exp.Anonymous) self.validate_identity(