Skip to content

Commit 60e26b8

Browse files
authored
Fix(hive)!: improve transpilability of GET_JSON_OBJECT by parsing json path (#4980)
1 parent cb20038 commit 60e26b8

File tree

3 files changed: 21 additions and 10 deletions.

sqlglot/dialects/databricks.py

Lines changed: 0 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,5 @@
11
from __future__ import annotations
22

3-
import typing as t
43

54
from sqlglot import exp, transforms, jsonpath
65
from sqlglot.dialects.dialect import (
@@ -13,13 +12,6 @@
1312
from sqlglot.tokens import TokenType
1413

1514

16-
def _build_json_extract(args: t.List) -> exp.JSONExtract:
17-
# Transform GET_JSON_OBJECT(expr, '$.<path>') -> expr:<path>
18-
this = args[0]
19-
path = args[1].name.lstrip("$.")
20-
return exp.JSONExtract(this=this, expression=path)
21-
22-
2315
def _jsonextract_sql(
2416
self: Databricks.Generator, expression: exp.JSONExtract | exp.JSONExtractScalar
2517
) -> str:
@@ -46,7 +38,6 @@ class Parser(Spark.Parser):
4638
"DATE_ADD": build_date_delta(exp.DateAdd),
4739
"DATEDIFF": build_date_delta(exp.DateDiff),
4840
"DATE_DIFF": build_date_delta(exp.DateDiff),
49-
"GET_JSON_OBJECT": _build_json_extract,
5041
"TO_DATE": build_formatted_time(exp.TsOrDsToDate, "databricks"),
5142
}
5243

sqlglot/dialects/hive.py

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -307,7 +307,9 @@ class Parser(parser.Parser):
307307
"FIRST": _build_with_ignore_nulls(exp.First),
308308
"FIRST_VALUE": _build_with_ignore_nulls(exp.FirstValue),
309309
"FROM_UNIXTIME": build_formatted_time(exp.UnixToStr, "hive", True),
310-
"GET_JSON_OBJECT": exp.JSONExtractScalar.from_arg_list,
310+
"GET_JSON_OBJECT": lambda args, dialect: exp.JSONExtractScalar(
311+
this=seq_get(args, 0), expression=dialect.to_json_path(seq_get(args, 1))
312+
),
311313
"LAST": _build_with_ignore_nulls(exp.Last),
312314
"LAST_VALUE": _build_with_ignore_nulls(exp.LastValue),
313315
"MAP": parser.build_var_map,

tests/dialects/test_databricks.py

Lines changed: 18 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -68,6 +68,24 @@ def test_databricks(self):
6868
"FROM_UTC_TIMESTAMP(CAST(x AS TIMESTAMP), tz)",
6969
)
7070

71+
self.validate_all(
72+
"SELECT c1:item[1].price",
73+
read={
74+
"spark": "SELECT GET_JSON_OBJECT(c1, '$.item[1].price')",
75+
},
76+
write={
77+
"databricks": "SELECT c1:item[1].price",
78+
"spark": "SELECT GET_JSON_OBJECT(c1, '$.item[1].price')",
79+
},
80+
)
81+
82+
self.validate_all(
83+
"SELECT GET_JSON_OBJECT(c1, '$.item[1].price')",
84+
write={
85+
"databricks": "SELECT c1:item[1].price",
86+
"spark": "SELECT GET_JSON_OBJECT(c1, '$.item[1].price')",
87+
},
88+
)
7189
self.validate_all(
7290
"CREATE TABLE foo (x INT GENERATED ALWAYS AS (YEAR(y)))",
7391
write={

0 commit comments

Comments
 (0)