Skip to content

Commit 4d3904e

Browse files
authored
fix(spark): Support DB's TIMESTAMP_DIFF (#4373)
1 parent 702fe31 commit 4d3904e

File tree

3 files changed

+15
-5
lines changed

3 files changed

+15
-5
lines changed

sqlglot/dialects/databricks.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
date_delta_sql,
88
build_date_delta,
99
timestamptrunc_sql,
10-
timestampdiff_sql,
1110
)
1211
from sqlglot.dialects.spark import Spark
1312
from sqlglot.tokens import TokenType
@@ -46,7 +45,6 @@ class Parser(Spark.Parser):
4645
"DATE_ADD": build_date_delta(exp.DateAdd),
4746
"DATEDIFF": build_date_delta(exp.DateDiff),
4847
"DATE_DIFF": build_date_delta(exp.DateDiff),
49-
"TIMESTAMPDIFF": build_date_delta(exp.TimestampDiff),
5048
"GET_JSON_OBJECT": _build_json_extract,
5149
}
5250

@@ -75,8 +73,6 @@ class Generator(Spark.Generator):
7573
exp.Mul(this=e.expression, expression=exp.Literal.number(-1)),
7674
e.this,
7775
),
78-
exp.DatetimeDiff: timestampdiff_sql,
79-
exp.TimestampDiff: timestampdiff_sql,
8076
exp.DatetimeTrunc: timestamptrunc_sql(),
8177
exp.Select: transforms.preprocess(
8278
[

sqlglot/dialects/spark.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import typing as t
44

55
from sqlglot import exp
6-
from sqlglot.dialects.dialect import rename_func, unit_to_var
6+
from sqlglot.dialects.dialect import rename_func, unit_to_var, timestampdiff_sql, build_date_delta
77
from sqlglot.dialects.hive import _build_with_ignore_nulls
88
from sqlglot.dialects.spark2 import Spark2, temporary_storage_provider, _build_as_cast
99
from sqlglot.helper import ensure_list, seq_get
@@ -108,6 +108,7 @@ class Parser(Spark2.Parser):
108108
"DATE_ADD": _build_dateadd,
109109
"DATEADD": _build_dateadd,
110110
"TIMESTAMPADD": _build_dateadd,
111+
"TIMESTAMPDIFF": build_date_delta(exp.TimestampDiff),
111112
"DATEDIFF": _build_datediff,
112113
"DATE_DIFF": _build_datediff,
113114
"TIMESTAMP_LTZ": _build_as_cast("TIMESTAMP_LTZ"),
@@ -167,6 +168,8 @@ class Generator(Spark2.Generator):
167168
exp.StartsWith: rename_func("STARTSWITH"),
168169
exp.TsOrDsAdd: _dateadd_sql,
169170
exp.TimestampAdd: _dateadd_sql,
171+
exp.DatetimeDiff: timestampdiff_sql,
172+
exp.TimestampDiff: timestampdiff_sql,
170173
exp.TryCast: lambda self, e: (
171174
self.trycast_sql(e) if e.args.get("safe") else self.cast_sql(e)
172175
),

tests/dialects/test_spark.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -754,6 +754,17 @@ def test_spark(self):
754754
},
755755
)
756756

757+
self.validate_all(
758+
"SELECT TIMESTAMPDIFF(MONTH, foo, bar)",
759+
read={
760+
"databricks": "SELECT TIMESTAMPDIFF(MONTH, foo, bar)",
761+
},
762+
write={
763+
"spark": "SELECT TIMESTAMPDIFF(MONTH, foo, bar)",
764+
"databricks": "SELECT TIMESTAMPDIFF(MONTH, foo, bar)",
765+
},
766+
)
767+
757768
def test_bool_or(self):
758769
self.validate_all(
759770
"SELECT a, LOGICAL_OR(b) FROM table GROUP BY a",

0 commit comments

Comments
 (0)