Skip to content

Commit 400ea54

Browse files
authored
Feat!: ensure JSON_FORMAT type is JSON when targeting Presto (#4968)
1 parent 7bc5a21 commit 400ea54

File tree

16 files changed

+66
-30
lines changed

16 files changed

+66
-30
lines changed

sqlglot/dialects/bigquery.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -800,7 +800,7 @@ def _parse_unnest(self, with_alias: bool = True) -> t.Optional[exp.Unnest]:
800800
if unnest_expr:
801801
from sqlglot.optimizer.annotate_types import annotate_types
802802

803-
unnest_expr = annotate_types(unnest_expr)
803+
unnest_expr = annotate_types(unnest_expr, dialect=self.dialect)
804804

805805
# Unnesting a nested array (i.e array of structs) explodes the top-level struct fields,
806806
# in contrast to other dialects such as DuckDB which flattens only the array by default
@@ -1227,7 +1227,7 @@ def bracket_sql(self, expression: exp.Bracket) -> str:
12271227
if arg.type is None:
12281228
from sqlglot.optimizer.annotate_types import annotate_types
12291229

1230-
arg = annotate_types(arg)
1230+
arg = annotate_types(arg, dialect=self.dialect)
12311231

12321232
if arg.type and arg.type.this in exp.DataType.TEXT_TYPES:
12331233
# BQ doesn't support bracket syntax with string values for structs

sqlglot/dialects/dialect.py

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1303,7 +1303,9 @@ def no_timestamp_sql(self: Generator, expression: exp.Timestamp) -> str:
13031303
if not zone:
13041304
from sqlglot.optimizer.annotate_types import annotate_types
13051305

1306-
target_type = annotate_types(expression).type or exp.DataType.Type.TIMESTAMP
1306+
target_type = (
1307+
annotate_types(expression, dialect=self.dialect).type or exp.DataType.Type.TIMESTAMP
1308+
)
13071309
return self.sql(exp.cast(expression.this, target_type))
13081310
if zone.name.lower() in TIMEZONES:
13091311
return self.sql(
@@ -1870,7 +1872,7 @@ def build_timetostr_or_tochar(args: t.List, dialect: Dialect) -> exp.TimeToStr |
18701872
if this and not this.type:
18711873
from sqlglot.optimizer.annotate_types import annotate_types
18721874

1873-
annotate_types(this)
1875+
annotate_types(this, dialect=dialect)
18741876
if this.is_type(*exp.DataType.TEMPORAL_TYPES):
18751877
dialect_name = dialect.__class__.__name__.lower()
18761878
return build_formatted_time(exp.TimeToStr, dialect_name, default=True)(args)

sqlglot/dialects/duckdb.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1008,7 +1008,7 @@ def bracket_sql(self, expression: exp.Bracket) -> str:
10081008
if not this.type:
10091009
from sqlglot.optimizer.annotate_types import annotate_types
10101010

1011-
this = annotate_types(this)
1011+
this = annotate_types(this, dialect=self.dialect)
10121012

10131013
if this.is_type(exp.DataType.Type.MAP):
10141014
bracket = f"({bracket})[1]"
@@ -1042,7 +1042,7 @@ def length_sql(self, expression: exp.Length) -> str:
10421042
if not arg.type:
10431043
from sqlglot.optimizer.annotate_types import annotate_types
10441044

1045-
arg = annotate_types(arg)
1045+
arg = annotate_types(arg, dialect=self.dialect)
10461046

10471047
if arg.is_type(*exp.DataType.TEXT_TYPES):
10481048
return self.func("LENGTH", arg)

sqlglot/dialects/postgres.py

Lines changed: 6 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -41,6 +41,9 @@
4141
from sqlglot.parser import binary_range_parser
4242
from sqlglot.tokens import TokenType
4343

44+
if t.TYPE_CHECKING:
45+
from sqlglot.dialects.dialect import DialectType
46+
4447

4548
DATE_DIFF_FACTOR = {
4649
"MICROSECOND": " * 1000000",
@@ -191,7 +194,7 @@ def _generate(self: Postgres.Generator, expression: JSON_EXTRACT_TYPE) -> str:
191194
return _generate
192195

193196

194-
def _build_regexp_replace(args: t.List) -> exp.RegexpReplace:
197+
def _build_regexp_replace(args: t.List, dialect: DialectType = None) -> exp.RegexpReplace:
195198
# The signature of REGEXP_REPLACE is:
196199
# regexp_replace(source, pattern, replacement [, start [, N ]] [, flags ])
197200
#
@@ -204,7 +207,7 @@ def _build_regexp_replace(args: t.List) -> exp.RegexpReplace:
204207
if not last.type or last.is_type(exp.DataType.Type.UNKNOWN, exp.DataType.Type.NULL):
205208
from sqlglot.optimizer.annotate_types import annotate_types
206209

207-
last = annotate_types(last)
210+
last = annotate_types(last, dialect=dialect)
208211

209212
if last.is_type(*exp.DataType.TEXT_TYPES):
210213
regexp_replace = exp.RegexpReplace.from_arg_list(args[:-1])
@@ -657,7 +660,7 @@ def unnest_sql(self, expression: exp.Unnest) -> str:
657660

658661
from sqlglot.optimizer.annotate_types import annotate_types
659662

660-
this = annotate_types(arg)
663+
this = annotate_types(arg, dialect=self.dialect)
661664
if this.is_type("array<json>"):
662665
while isinstance(this, exp.Cast):
663666
this = this.this

sqlglot/dialects/presto.py

Lines changed: 20 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -332,6 +332,9 @@ class Parser(parser.Parser):
332332
"FROM_UTF8": lambda args: exp.Decode(
333333
this=seq_get(args, 0), replace=seq_get(args, 1), charset=exp.Literal.string("utf-8")
334334
),
335+
"JSON_FORMAT": lambda args: exp.JSONFormat(
336+
this=seq_get(args, 0), options=seq_get(args, 1), is_json=True
337+
),
335338
"LEVENSHTEIN_DISTANCE": exp.Levenshtein.from_arg_list,
336339
"NOW": exp.CurrentTimestamp.from_arg_list,
337340
"REGEXP_EXTRACT": build_regexp_extract(exp.RegexpExtract),
@@ -582,13 +585,27 @@ class Generator(generator.Generator):
582585
"with",
583586
}
584587

588+
def jsonformat_sql(self, expression: exp.JSONFormat) -> str:
589+
this = expression.this
590+
is_json = expression.args.get("is_json")
591+
592+
if this and not (is_json or this.type):
593+
from sqlglot.optimizer.annotate_types import annotate_types
594+
595+
this = annotate_types(this, dialect=self.dialect)
596+
597+
if not (is_json or this.is_type(exp.DataType.Type.JSON)):
598+
this.replace(exp.cast(this, exp.DataType.Type.JSON))
599+
600+
return self.function_fallback_sql(expression)
601+
585602
def md5_sql(self, expression: exp.MD5) -> str:
586603
this = expression.this
587604

588605
if not this.type:
589606
from sqlglot.optimizer.annotate_types import annotate_types
590607

591-
this = annotate_types(this)
608+
this = annotate_types(this, dialect=self.dialect)
592609

593610
if this.is_type(*exp.DataType.TEXT_TYPES):
594611
this = exp.Encode(this=this, charset=exp.Literal.string("utf-8"))
@@ -630,6 +647,7 @@ def bracket_sql(self, expression: exp.Bracket) -> str:
630647
expression.this,
631648
expression.expressions,
632649
1 - expression.args.get("offset", 0),
650+
dialect=self.dialect,
633651
),
634652
0,
635653
),
@@ -639,7 +657,7 @@ def bracket_sql(self, expression: exp.Bracket) -> str:
639657
def struct_sql(self, expression: exp.Struct) -> str:
640658
from sqlglot.optimizer.annotate_types import annotate_types
641659

642-
expression = annotate_types(expression)
660+
expression = annotate_types(expression, dialect=self.dialect)
643661
values: t.List[str] = []
644662
schema: t.List[str] = []
645663
unknown_type = False

sqlglot/dialects/snowflake.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1172,7 +1172,7 @@ def trycast_sql(self, expression: exp.TryCast) -> str:
11721172
if value.type is None:
11731173
from sqlglot.optimizer.annotate_types import annotate_types
11741174

1175-
value = annotate_types(value)
1175+
value = annotate_types(value, dialect=self.dialect)
11761176

11771177
if value.is_type(*exp.DataType.TEXT_TYPES, exp.DataType.Type.UNKNOWN):
11781178
return super().trycast_sql(expression)

sqlglot/executor/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -65,7 +65,9 @@ def execute(
6565

6666
for column in table.columns:
6767
value = table[0][column]
68-
column_type = annotate_types(exp.convert(value)).type or type(value).__name__
68+
column_type = (
69+
annotate_types(exp.convert(value), dialect=read).type or type(value).__name__
70+
)
6971
nested_set(schema, [*keys, column], column_type)
7072

7173
schema = ensure_schema(schema, dialect=read)

sqlglot/expressions.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -6306,7 +6306,7 @@ class JSONBExtractScalar(Binary, Func):
63066306

63076307

63086308
class JSONFormat(Func):
6309-
arg_types = {"this": False, "options": False}
6309+
arg_types = {"this": False, "options": False, "is_json": False}
63106310
_sql_names = ["JSON_FORMAT"]
63116311

63126312

sqlglot/generator.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -2807,6 +2807,7 @@ def bracket_offset_expressions(
28072807
expression.this,
28082808
expression.expressions,
28092809
(index_offset or self.dialect.INDEX_OFFSET) - expression.args.get("offset", 0),
2810+
dialect=self.dialect,
28102811
)
28112812

28122813
def bracket_sql(self, expression: exp.Bracket) -> str:
@@ -4018,7 +4019,7 @@ def toarray_sql(self, expression: exp.ToArray) -> str:
40184019
if not arg.type:
40194020
from sqlglot.optimizer.annotate_types import annotate_types
40204021

4021-
arg = annotate_types(arg)
4022+
arg = annotate_types(arg, dialect=self.dialect)
40224023

40234024
if arg.is_type(exp.DataType.Type.ARRAY):
40244025
return self.sql(arg)

sqlglot/helper.py

Lines changed: 5 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -15,6 +15,7 @@
1515
if t.TYPE_CHECKING:
1616
from sqlglot import exp
1717
from sqlglot._typing import A, E, T
18+
from sqlglot.dialects.dialect import DialectType
1819
from sqlglot.expressions import Expression
1920

2021

@@ -150,6 +151,7 @@ def apply_index_offset(
150151
this: exp.Expression,
151152
expressions: t.List[E],
152153
offset: int,
154+
dialect: DialectType = None,
153155
) -> t.List[E]:
154156
"""
155157
Applies an offset to a given integer literal expression.
@@ -158,6 +160,7 @@ def apply_index_offset(
158160
this: The target of the index.
159161
expressions: The expression the offset will be applied to, wrapped in a list.
160162
offset: The offset that will be applied.
163+
dialect: the dialect of interest.
161164
162165
Returns:
163166
The original expression with the offset applied to it, wrapped in a list. If the provided
@@ -173,7 +176,7 @@ def apply_index_offset(
173176
from sqlglot.optimizer.simplify import simplify
174177

175178
if not this.type:
176-
annotate_types(this)
179+
annotate_types(this, dialect=dialect)
177180

178181
if t.cast(exp.DataType, this.type).this not in (
179182
exp.DataType.Type.UNKNOWN,
@@ -182,7 +185,7 @@ def apply_index_offset(
182185
return expressions
183186

184187
if not expression.type:
185-
annotate_types(expression)
188+
annotate_types(expression, dialect=dialect)
186189

187190
if t.cast(exp.DataType, expression.type).this in exp.DataType.INTEGER_TYPES:
188191
logger.info("Applying array index offset (%s)", offset)

0 commit comments

Comments
 (0)