Skip to content

Commit 2c462a8

Browse files
committed
Fixups
1 parent 5bb1170 commit 2c462a8

File tree

3 files changed

+31
-34
lines changed

3 files changed

+31
-34
lines changed

sqlglot/dialects/duckdb.py

Lines changed: 24 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -141,11 +141,6 @@ def _build_make_timestamp(args: t.List) -> exp.Expression:
141141
)
142142

143143

144-
def _build_greatest(args: t.List) -> exp.Greatest:
145-
"""Build GREATEST with all arguments properly distributed."""
146-
return exp.Greatest(this=seq_get(args, 0), expressions=args[1:])
147-
148-
149144
def _show_parser(*args: t.Any, **kwargs: t.Any) -> t.Callable[[DuckDB.Parser], exp.Show]:
150145
def _parse(self: DuckDB.Parser) -> exp.Show:
151146
return self._parse_show_duckdb(*args, **kwargs)
@@ -410,29 +405,6 @@ def _initcap_sql(self: DuckDB.Generator, expression: exp.Initcap) -> str:
410405
return _build_capitalization_sql(self, this_sql, escaped_delimiters_sql)
411406

412407

413-
def _greatest_sql(self: DuckDB.Generator, expression: exp.Greatest) -> str:
414-
"""
415-
Handle GREATEST function with dialect-aware NULL behavior.
416-
417-
- If return_null_if_any_null=True (BigQuery-style): return NULL if any argument is NULL
418-
- If return_null_if_any_null=False (DuckDB/PostgreSQL-style): ignore NULLs, return greatest non-NULL value
419-
"""
420-
# Get all arguments
421-
all_args = [expression.this] + (expression.expressions or [])
422-
greatest_sql = self.func("GREATEST", *all_args)
423-
424-
if expression.args.get("return_null_if_any_null"):
425-
# BigQuery behavior: NULL if any argument is NULL
426-
case_expr = exp.case().when(
427-
exp.or_(*[arg.is_(exp.null()) for arg in all_args], copy=False), exp.null(), copy=False
428-
)
429-
case_expr.set("default", greatest_sql)
430-
return self.sql(case_expr)
431-
432-
# DuckDB/PostgreSQL behavior: use native GREATEST (ignores NULLs)
433-
return self.sql(greatest_sql)
434-
435-
436408
class DuckDB(Dialect):
437409
NULL_ORDERING = "nulls_are_last"
438410
SUPPORTS_USER_DEFINED_TYPES = True
@@ -547,7 +519,6 @@ class Parser(parser.Parser):
547519

548520
FUNCTIONS = {
549521
**parser.Parser.FUNCTIONS,
550-
"GREATEST": _build_greatest,
551522
"ANY_VALUE": lambda args: exp.IgnoreNulls(this=exp.AnyValue.from_arg_list(args)),
552523
"ARRAY_REVERSE_SORT": _build_sort_array_desc,
553524
"ARRAY_SORT": exp.SortArray.from_arg_list,
@@ -895,7 +866,6 @@ class Generator(generator.Generator):
895866
exp.EuclideanDistance: rename_func("LIST_DISTANCE"),
896867
exp.GenerateDateArray: _generate_datetime_array_sql,
897868
exp.GenerateTimestampArray: _generate_datetime_array_sql,
898-
exp.Greatest: _greatest_sql,
899869
exp.GroupConcat: lambda self, e: groupconcat_sql(self, e, within_group=False),
900870
exp.Explode: rename_func("UNNEST"),
901871
exp.IntDiv: lambda self, e: self.binary(e, "//"),
@@ -1133,6 +1103,30 @@ class Generator(generator.Generator):
11331103
exp.NthValue,
11341104
)
11351105

1106+
def greatest_sql(self: DuckDB.Generator, expression: exp.Greatest) -> str:
1107+
"""
1108+
Handle GREATEST function with dialect-aware NULL behavior.
1109+
1110+
- If return_null_if_any_null=True (BigQuery-style): return NULL if any argument is NULL
1111+
- If return_null_if_any_null=False (DuckDB/PostgreSQL-style): ignore NULLs, return greatest non-NULL value
1112+
"""
1113+
# Get all arguments
1114+
all_args = [expression.this, *expression.expressions]
1115+
greatest_sql = self.function_fallback_sql(expression)
1116+
1117+
if expression.args.get("return_null_if_any_null"):
1118+
# BigQuery behavior: NULL if any argument is NULL
1119+
case_expr = exp.case().when(
1120+
exp.or_(*[arg.is_(exp.null()) for arg in all_args], copy=False),
1121+
exp.null(),
1122+
copy=False,
1123+
)
1124+
case_expr.set("default", greatest_sql)
1125+
return self.sql(case_expr)
1126+
1127+
# DuckDB/PostgreSQL behavior: use native GREATEST (ignores NULLs)
1128+
return self.sql(greatest_sql)
1129+
11361130
def lambda_sql(
11371131
self, expression: exp.Lambda, arrow_sep: str = "->", wrap: bool = True
11381132
) -> str:

sqlglot/parser.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -230,6 +230,7 @@ class Parser(metaclass=_Parser):
230230
is_string=dialect.UUID_IS_STRING_TYPE or None
231231
),
232232
"GLOB": lambda args: exp.Glob(this=seq_get(args, 1), expression=seq_get(args, 0)),
233+
"GREATEST": lambda args: exp.Greatest(this=seq_get(args, 0), expressions=args[1:]),
233234
"HEX": build_hex,
234235
"JSON_EXTRACT": build_extract_json_with_path(exp.JSONExtract),
235236
"JSON_EXTRACT_SCALAR": build_extract_json_with_path(exp.JSONExtractScalar),

tests/test_expressions.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -740,10 +740,8 @@ def test_functions(self):
740740
self.assertIsInstance(parse_one("TIME_TO_TIME_STR(a)"), exp.Cast)
741741
self.assertIsInstance(parse_one("TIME_TO_UNIX(a)"), exp.TimeToUnix)
742742
self.assertIsInstance(parse_one("TIME_STR_TO_DATE(a)"), exp.TimeStrToDate)
743-
(self.assertIsInstance(parse_one("TIME_STR_TO_TIME(a)"), exp.TimeStrToTime),)
744-
self.assertIsInstance(
745-
parse_one("TIME_STR_TO_TIME(a, 'America/Los_Angeles')"), exp.TimeStrToTime
746-
)
743+
self.assertIsInstance(parse_one("TIME_STR_TO_TIME(a)"), exp.TimeStrToTime)
744+
self.assertIsInstance(parse_one("TIME_STR_TO_TIME(a, 'some_zone')"), exp.TimeStrToTime)
747745
self.assertIsInstance(parse_one("TIME_STR_TO_UNIX(a)"), exp.TimeStrToUnix)
748746
self.assertIsInstance(parse_one("TRIM(LEADING 'b' FROM 'bla')"), exp.Trim)
749747
self.assertIsInstance(parse_one("TS_OR_DS_ADD(a, 1, 'day')"), exp.TsOrDsAdd)
@@ -767,6 +765,10 @@ def test_functions(self):
767765
self.assertIsInstance(parse_one("TRANSFORM(a, b)", read="spark"), exp.Transform)
768766
self.assertIsInstance(parse_one("ADD_MONTHS(a, b)"), exp.AddMonths)
769767

768+
ast = parse_one("GREATEST(a, b, c)")
769+
self.assertIsInstance(ast.expressions, list)
770+
self.assertEqual(len(ast.expressions), 2)
771+
770772
def test_column(self):
771773
column = exp.column(exp.Star(), table="t")
772774
self.assertEqual(column.sql(), "t.*")

0 commit comments

Comments
 (0)