diff --git a/sqlglot/dialects/bigquery.py b/sqlglot/dialects/bigquery.py
index 99225ef85d..ba61cd1649 100644
--- a/sqlglot/dialects/bigquery.py
+++ b/sqlglot/dialects/bigquery.py
@@ -545,6 +545,9 @@ class Parser(parser.Parser):
             "EDIT_DISTANCE": _build_levenshtein,
             "FORMAT_DATE": _build_format_time(exp.TsOrDsToDate),
             "GENERATE_ARRAY": exp.GenerateSeries.from_arg_list,
+            "GREATEST": lambda args: exp.Greatest(
+                this=seq_get(args, 0), expressions=args[1:], return_null_if_any_null=True
+            ),
             "JSON_EXTRACT_SCALAR": _build_extract_json_with_default_path(exp.JSONExtractScalar),
             "JSON_EXTRACT_ARRAY": _build_extract_json_with_default_path(exp.JSONExtractArray),
             "JSON_EXTRACT_STRING_ARRAY": _build_extract_json_with_default_path(exp.JSONValueArray),
diff --git a/sqlglot/dialects/duckdb.py b/sqlglot/dialects/duckdb.py
index 2df0f33963..564c0517e8 100644
--- a/sqlglot/dialects/duckdb.py
+++ b/sqlglot/dialects/duckdb.py
@@ -141,6 +141,11 @@ def _build_make_timestamp(args: t.List) -> exp.Expression:
     )
 
 
+def _build_greatest(args: t.List) -> exp.Greatest:
+    """Build GREATEST with all arguments properly distributed."""
+    return exp.Greatest(this=seq_get(args, 0), expressions=args[1:])
+
+
 def _show_parser(*args: t.Any, **kwargs: t.Any) -> t.Callable[[DuckDB.Parser], exp.Show]:
     def _parse(self: DuckDB.Parser) -> exp.Show:
         return self._parse_show_duckdb(*args, **kwargs)
@@ -405,6 +410,29 @@ def _initcap_sql(self: DuckDB.Generator, expression: exp.Initcap) -> str:
     return _build_capitalization_sql(self, this_sql, escaped_delimiters_sql)
 
 
+def _greatest_sql(self: DuckDB.Generator, expression: exp.Greatest) -> str:
+    """
+    Handle GREATEST function with dialect-aware NULL behavior.
+
+    - If return_null_if_any_null=True (BigQuery-style): return NULL if any argument is NULL
+    - If return_null_if_any_null=False (DuckDB/PostgreSQL-style): ignore NULLs, return greatest non-NULL value
+    """
+    # Get all arguments
+    all_args = [expression.this] + (expression.expressions or [])
+    greatest_sql = self.func("GREATEST", *all_args)
+
+    if expression.args.get("return_null_if_any_null"):
+        # BigQuery behavior: NULL if any argument is NULL
+        case_expr = exp.case().when(
+            exp.or_(*[arg.is_(exp.null()) for arg in all_args], copy=False), exp.null(), copy=False
+        )
+        case_expr.set("default", greatest_sql)
+        return self.sql(case_expr)
+
+    # DuckDB/PostgreSQL behavior: native GREATEST ignores NULLs; the SQL is already rendered
+    return greatest_sql
+
+
 class DuckDB(Dialect):
     NULL_ORDERING = "nulls_are_last"
     SUPPORTS_USER_DEFINED_TYPES = True
@@ -519,6 +547,7 @@ class Parser(parser.Parser):
 
         FUNCTIONS = {
             **parser.Parser.FUNCTIONS,
+            "GREATEST": _build_greatest,
             "ANY_VALUE": lambda args: exp.IgnoreNulls(this=exp.AnyValue.from_arg_list(args)),
             "ARRAY_REVERSE_SORT": _build_sort_array_desc,
             "ARRAY_SORT": exp.SortArray.from_arg_list,
@@ -866,6 +895,7 @@ class Generator(generator.Generator):
             exp.EuclideanDistance: rename_func("LIST_DISTANCE"),
             exp.GenerateDateArray: _generate_datetime_array_sql,
             exp.GenerateTimestampArray: _generate_datetime_array_sql,
+            exp.Greatest: _greatest_sql,
             exp.GroupConcat: lambda self, e: groupconcat_sql(self, e, within_group=False),
             exp.Explode: rename_func("UNNEST"),
             exp.IntDiv: lambda self, e: self.binary(e, "//"),
diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py
index 92ee5fac9a..cf1dcfceac 100644
--- a/sqlglot/expressions.py
+++ b/sqlglot/expressions.py
@@ -6694,7 +6694,13 @@ class Getbit(Func):
 
 
 class Greatest(Func):
-    arg_types = {"this": True, "expressions": False}
+    arg_types = {"this": True, "expressions": False, "return_null_if_any_null": False}
     is_var_len_args = True
+
+    @classmethod
+    def from_arg_list(cls, args):
+        # Func.from_arg_list assigns the variadic tail to the LAST arg_types key, which is now
+        # the boolean flag instead of "expressions", so distribute the arguments explicitly.
+        return cls(this=args[0] if args else None, expressions=args[1:])
 
 
diff --git a/tests/dialects/test_bigquery.py b/tests/dialects/test_bigquery.py
index 3faccc5732..08229ddb02 100644
--- a/tests/dialects/test_bigquery.py
+++ b/tests/dialects/test_bigquery.py
@@ -2048,6 +2048,14 @@ def test_bigquery(self):
             },
         )
 
+        self.validate_all(
+            "SELECT GREATEST(1, NULL, 3)",
+            write={
+                "duckdb": "SELECT CASE WHEN 1 IS NULL OR NULL IS NULL OR 3 IS NULL THEN NULL ELSE GREATEST(1, NULL, 3) END",
+                "bigquery": "SELECT GREATEST(1, NULL, 3)",
+            },
+        )
+
     def test_errors(self):
         with self.assertRaises(ParseError):
             self.parse_one("SELECT * FROM a - b.c.d2")
diff --git a/tests/dialects/test_duckdb.py b/tests/dialects/test_duckdb.py
index 07e4fd18d8..d52d48c460 100644
--- a/tests/dialects/test_duckdb.py
+++ b/tests/dialects/test_duckdb.py
@@ -1238,6 +1238,11 @@ def test_duckdb(self):
         self.validate_identity(
             "SELECT CAST(TRIM(CAST(CAST('***apple***' AS BLOB) AS TEXT), CAST(CAST('*' AS BLOB) AS TEXT)) AS BLOB) AS result"
         )
+        self.validate_identity("SELECT GREATEST(1.0, 2.5, NULL, 3.7)")
+        self.validate_all(
+            "SELECT GREATEST(1, 2, 3)",
+            read={"postgres": "SELECT GREATEST(1, 2, 3)"},
+        )
 
     def test_array_index(self):
         with self.assertLogs(helper_logger) as cm: