Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 36 additions & 1 deletion sqlglot/dialects/duckdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -2451,6 +2451,21 @@ def hexstring_sql(

return self.sql(to_hex)

def datetrunc_sql(self, expression: exp.DateTrunc) -> str:
unit = unit_to_str(expression)
date = expression.this
result = self.func("DATE_TRUNC", unit, date)

if expression.args.get("input_type_preserved"):
if not date.type:
from sqlglot.optimizer.annotate_types import annotate_types

date = annotate_types(date, dialect=self.dialect)

if date.type and date.is_type(*exp.DataType.TEMPORAL_TYPES):
return self.sql(exp.Cast(this=result, to=date.type))
return result

def timestamptrunc_sql(self, expression: exp.TimestampTrunc) -> str:
unit = unit_to_str(expression)
zone = expression.args.get("zone")
Expand All @@ -2465,7 +2480,27 @@ def timestamptrunc_sql(self, expression: exp.TimestampTrunc) -> str:
result_sql = self.func("DATE_TRUNC", unit, timestamp)
return self.sql(exp.AtTimeZone(this=result_sql, zone=zone))

return self.func("DATE_TRUNC", unit, timestamp)
result = self.func("DATE_TRUNC", unit, timestamp)
if expression.args.get("input_type_preserved"):
if not timestamp.type:
from sqlglot.optimizer.annotate_types import annotate_types

timestamp = annotate_types(timestamp, dialect=self.dialect)

if timestamp.type and timestamp.is_type(
exp.DataType.Type.TIME, exp.DataType.Type.TIMETZ
):
dummy_date = exp.Cast(
this=exp.Literal.string("1970-01-01"),
to=exp.DataType(this=exp.DataType.Type.DATE),
)
date_time = exp.Add(this=dummy_date, expression=timestamp)
result = self.func("DATE_TRUNC", unit, date_time)
return self.sql(exp.Cast(this=result, to=timestamp.type))

if timestamp.type and timestamp.is_type(*exp.DataType.TEMPORAL_TYPES):
return self.sql(exp.Cast(this=result, to=timestamp.type))
return result

def trim_sql(self, expression: exp.Trim) -> str:
expression.this.replace(_cast_to_varchar(expression.this))
Expand Down
10 changes: 8 additions & 2 deletions sqlglot/dialects/snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
groupconcat_sql,
)
from sqlglot.generator import unsupported_args
from sqlglot.helper import find_new_name, flatten, is_int, seq_get
from sqlglot.helper import find_new_name, flatten, is_date_unit, is_int, seq_get
from sqlglot.optimizer.scope import build_scope, find_all_in_scope
from sqlglot.tokens import TokenType
from sqlglot.typing.snowflake import EXPRESSION_METADATA
Expand Down Expand Up @@ -259,7 +259,13 @@ def _parse(self: Snowflake.Parser) -> exp.Show:

def _date_trunc_to_time(args: t.List) -> exp.DateTrunc | exp.TimestampTrunc:
trunc = date_trunc_to_time(args)
trunc.set("unit", map_date_part(trunc.args["unit"]))
unit = map_date_part(trunc.args["unit"])
trunc.set("unit", unit)
is_time_input = trunc.this.is_type(exp.DataType.Type.TIME, exp.DataType.Type.TIMETZ)
if (isinstance(trunc, exp.TimestampTrunc) and is_date_unit(unit) or is_time_input) or (
isinstance(trunc, exp.DateTrunc) and not is_date_unit(unit)
):
trunc.set("input_type_preserved", True)
return trunc


Expand Down
4 changes: 2 additions & 2 deletions sqlglot/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -6509,7 +6509,7 @@ class DateDiff(Func, TimeUnit):


class DateTrunc(Func):
arg_types = {"unit": True, "this": True, "zone": False}
arg_types = {"unit": True, "this": True, "zone": False, "input_type_preserved": False}

def __init__(self, **args):
# Across most dialects it's safe to unabbreviate the unit (e.g. 'Q' -> 'QUARTER') except Oracle
Expand Down Expand Up @@ -6669,7 +6669,7 @@ class TimestampDiff(Func, TimeUnit):


class TimestampTrunc(Func, TimeUnit):
arg_types = {"this": True, "unit": True, "zone": False}
arg_types = {"this": True, "unit": True, "zone": False, "input_type_preserved": False}


class TimeSlice(Func, TimeUnit):
Expand Down
49 changes: 49 additions & 0 deletions tests/dialects/test_snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -2668,7 +2668,56 @@ def test_timestamps(self):
self.validate_identity("DATEADD(y, 5, x)", "DATEADD(YEAR, 5, x)")
self.validate_identity("DATE_PART(yyy, x)", "DATE_PART(YEAR, x)")
self.validate_identity("DATE_TRUNC(yr, x)", "DATE_TRUNC('YEAR', x)")
self.validate_all(
"DATE_TRUNC('YEAR', CAST('2024-06-15' AS DATE))",
write={
"snowflake": "DATE_TRUNC('YEAR', CAST('2024-06-15' AS DATE))",
"duckdb": "DATE_TRUNC('YEAR', CAST('2024-06-15' AS DATE))",
},
)
self.validate_all(
"DATE_TRUNC('HOUR', CAST('2026-01-01 00:00:00' AS TIMESTAMP))",
write={
"snowflake": "DATE_TRUNC('HOUR', CAST('2026-01-01 00:00:00' AS TIMESTAMP))",
"duckdb": "DATE_TRUNC('HOUR', CAST('2026-01-01 00:00:00' AS TIMESTAMP))",
},
)
# Snowflake's DATE_TRUNC return type matches type of the expresison
# DuckDB's DATE_TRUNC return type matches type of granularity part.
# In Snowflake --> DuckDB, DATE_TRUNC(date_part, timestamp) should be cast to timestamp to preserve Snowflake behavior.
self.validate_all(
"DATE_TRUNC(YEAR, TIMESTAMP '2026-01-01 00:00:00')",
write={
"snowflake": "DATE_TRUNC('YEAR', CAST('2026-01-01 00:00:00' AS TIMESTAMP))",
"duckdb": "CAST(DATE_TRUNC('YEAR', CAST('2026-01-01 00:00:00' AS TIMESTAMP)) AS TIMESTAMP)",
},
)
self.validate_all(
"DATE_TRUNC(MONTH, CAST('2024-06-15 14:23:45' AS TIMESTAMPTZ))",
write={
"snowflake": "DATE_TRUNC('MONTH', CAST('2024-06-15 14:23:45' AS TIMESTAMPTZ))",
"duckdb": "CAST(DATE_TRUNC('MONTH', CAST('2024-06-15 14:23:45' AS TIMESTAMPTZ)) AS TIMESTAMPTZ)",
},
)

# In Snowflake --> DuckDB, DATE_TRUNC(time_part, date) should be cast to date to preserve Snowflake behavior.
self.validate_all(
"DATE_TRUNC('HOUR', CAST('2026-01-01' AS DATE))",
write={
"snowflake": "DATE_TRUNC('HOUR', CAST('2026-01-01' AS DATE))",
"duckdb": "CAST(DATE_TRUNC('HOUR', CAST('2026-01-01' AS DATE)) AS DATE)",
},
)

# DuckDB does not support DATE_TRUNC(time_part, time), so we add a dummy date to generate DATE_TRUNC(time_part, date) --> DATE in DuckDB
# Then it is casted to a time (HH:MM:SS) to match Snowflake.
self.validate_all(
"DATE_TRUNC('HOUR', CAST('14:23:45.123456' AS TIME))",
write={
"snowflake": "DATE_TRUNC('HOUR', CAST('14:23:45.123456' AS TIME))",
"duckdb": "CAST(DATE_TRUNC('HOUR', CAST('1970-01-01' AS DATE) + CAST('14:23:45.123456' AS TIME)) AS TIME)",
},
)
self.validate_identity("TO_DATE('12345')").assert_is(exp.Anonymous)

self.validate_identity(
Expand Down