diff --git a/sqlglot/dialects/duckdb.py b/sqlglot/dialects/duckdb.py index cab31571e4..65fc9d000b 100644 --- a/sqlglot/dialects/duckdb.py +++ b/sqlglot/dialects/duckdb.py @@ -291,6 +291,10 @@ def _build_sort_array_desc(args: t.List) -> exp.Expression: return exp.SortArray(this=seq_get(args, 0), asc=exp.false()) +def _build_array_prepend(args: t.List) -> exp.Expression: + return exp.ArrayPrepend(this=seq_get(args, 1), expression=seq_get(args, 0)) + + def _build_date_diff(args: t.List) -> exp.Expression: return exp.DateDiff(this=seq_get(args, 2), expression=seq_get(args, 1), unit=seq_get(args, 0)) @@ -971,6 +975,7 @@ class Parser(parser.Parser): FUNCTIONS = { **parser.Parser.FUNCTIONS, "ANY_VALUE": lambda args: exp.IgnoreNulls(this=exp.AnyValue.from_arg_list(args)), + "ARRAY_PREPEND": _build_array_prepend, "ARRAY_REVERSE_SORT": _build_sort_array_desc, "ARRAY_SORT": exp.SortArray.from_arg_list, "BIT_AND": exp.BitwiseAndAgg.from_arg_list, @@ -995,12 +1000,14 @@ class Parser(parser.Parser): "JSON": exp.ParseJSON.from_arg_list, "JSON_EXTRACT_PATH": parser.build_extract_json_with_path(exp.JSONExtract), "JSON_EXTRACT_STRING": parser.build_extract_json_with_path(exp.JSONExtractScalar), + "LIST_APPEND": exp.ArrayAppend.from_arg_list, "LIST_CONTAINS": exp.ArrayContains.from_arg_list, "LIST_COSINE_DISTANCE": exp.CosineDistance.from_arg_list, "LIST_DISTANCE": exp.EuclideanDistance.from_arg_list, "LIST_FILTER": exp.ArrayFilter.from_arg_list, "LIST_HAS": exp.ArrayContains.from_arg_list, "LIST_HAS_ANY": exp.ArrayOverlaps.from_arg_list, + "LIST_PREPEND": _build_array_prepend, "LIST_REVERSE_SORT": _build_sort_array_desc, "LIST_SORT": exp.SortArray.from_arg_list, "LIST_TRANSFORM": exp.Transform.from_arg_list, @@ -1279,9 +1286,11 @@ class Generator(generator.Generator): [transforms.inherit_struct_field_names], generator=inline_array_unless_query, ), + exp.ArrayAppend: rename_func("LIST_APPEND"), exp.ArrayFilter: rename_func("LIST_FILTER"), exp.ArrayRemove: remove_from_array_using_filter, exp.ArraySort: _array_sort_sql, + exp.ArrayPrepend: lambda self, e: self.func("LIST_PREPEND", e.expression, e.this), exp.ArraySum: rename_func("LIST_SUM"), exp.ArrayUniqueAgg: lambda self, e: self.func( "LIST", exp.Distinct(expressions=[e.this]) diff --git a/sqlglot/dialects/postgres.py b/sqlglot/dialects/postgres.py index fd5b3666f2..5f042150a9 100644 --- a/sqlglot/dialects/postgres.py +++ b/sqlglot/dialects/postgres.py @@ -422,6 +422,9 @@ class Parser(parser.Parser): FUNCTIONS = { **parser.Parser.FUNCTIONS, + "ARRAY_PREPEND": lambda args: exp.ArrayPrepend( + this=seq_get(args, 1), expression=seq_get(args, 0) + ), "BIT_AND": exp.BitwiseAndAgg.from_arg_list, "BIT_OR": exp.BitwiseOrAgg.from_arg_list, "BIT_XOR": exp.BitwiseXorAgg.from_arg_list, @@ -613,6 +616,7 @@ class Generator(generator.Generator): exp.AnyValue: _versioned_anyvalue_sql, exp.ArrayConcat: lambda self, e: self.arrayconcat_sql(e, name="ARRAY_CAT"), exp.ArrayFilter: filter_array_using_unnest, + exp.ArrayPrepend: lambda self, e: self.func("ARRAY_PREPEND", e.expression, e.this), exp.BitwiseAndAgg: rename_func("BIT_AND"), exp.BitwiseOrAgg: rename_func("BIT_OR"), exp.BitwiseXor: lambda self, e: self.binary(e, "#"), diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py index 24f798ba7c..6564820f9d 100644 --- a/sqlglot/expressions.py +++ b/sqlglot/expressions.py @@ -6047,6 +6047,14 @@ class ArrayAny(Func): arg_types = {"this": True, "expression": True} +class ArrayAppend(Func): + arg_types = {"this": True, "expression": True} + + +class ArrayPrepend(Func): + arg_types = {"this": True, "expression": True} + + class ArrayConcat(Func): _sql_names = ["ARRAY_CONCAT", "ARRAY_CAT"] arg_types = {"this": True, "expressions": False} diff --git a/sqlglot/typing/snowflake.py b/sqlglot/typing/snowflake.py index 4620734bff..c090c5ba1c 100644 --- a/sqlglot/typing/snowflake.py +++ b/sqlglot/typing/snowflake.py @@ -241,7 +241,9 @@ def _annotate_str_to_time(self: TypeAnnotator, expression: exp.StrToTime) -> exp exp.ApproxTopKEstimate, exp.Array, exp.ArrayAgg, + exp.ArrayAppend, exp.ArrayConstructCompact, + exp.ArrayPrepend, exp.ArrayUniqueAgg, exp.ArrayUnionAgg, exp.MapKeys, diff --git a/tests/dialects/test_dialect.py b/tests/dialects/test_dialect.py index 8e1c5e0fee..ab9ecf29c9 100644 --- a/tests/dialects/test_dialect.py +++ b/tests/dialects/test_dialect.py @@ -1473,6 +1473,30 @@ def test_array(self): }, ) + self.validate_all( + "ARRAY_PREPEND(arr, x)", + read={ + "duckdb": "LIST_PREPEND(x, arr)", + "postgres": "ARRAY_PREPEND(x, arr)", + }, + write={ + "duckdb": "LIST_PREPEND(x, arr)", + "postgres": "ARRAY_PREPEND(x, arr)", + "spark": "ARRAY_PREPEND(arr, x)", + "snowflake": "ARRAY_PREPEND(arr, x)", + }, + ) + + self.validate_all( + "ARRAY_APPEND(arr, x)", + write={ + "duckdb": "LIST_APPEND(arr, x)", + "postgres": "ARRAY_APPEND(arr, x)", + "spark": "ARRAY_APPEND(arr, x)", + "snowflake": "ARRAY_APPEND(arr, x)", + }, + ) + def test_order_by(self): self.validate_identity( "SELECT c FROM t ORDER BY a, b,", diff --git a/tests/dialects/test_snowflake.py b/tests/dialects/test_snowflake.py index 0a05bdf4d7..99258c6863 100644 --- a/tests/dialects/test_snowflake.py +++ b/tests/dialects/test_snowflake.py @@ -441,6 +441,8 @@ def test_snowflake(self): self.validate_identity("SELECT TO_ARRAY(CAST(x AS ARRAY))") self.validate_identity("SELECT TO_ARRAY(CAST(['test'] AS VARIANT))") self.validate_identity("SELECT ARRAY_UNIQUE_AGG(x)") + self.validate_identity("SELECT ARRAY_APPEND([1, 2, 3], 4)") + self.validate_identity("SELECT ARRAY_PREPEND([2, 3, 4], 1)") self.validate_identity("SELECT AI_AGG(review, 'Summarize the reviews')") self.validate_identity("SELECT AI_SUMMARIZE_AGG(review)") self.validate_identity("SELECT AI_CLASSIFY('text', ['travel', 'cooking'])") diff --git a/tests/fixtures/optimizer/annotate_functions.sql b/tests/fixtures/optimizer/annotate_functions.sql index ea7ba52b56..aa4f4de11d 100644 --- a/tests/fixtures/optimizer/annotate_functions.sql +++ b/tests/fixtures/optimizer/annotate_functions.sql @@ -1744,6 +1744,14 @@ ARRAY; ARRAY_CONSTRUCT_COMPACT(1, null, 2); ARRAY; +# dialect: snowflake +ARRAY_APPEND([1, 2, 3], 4); +ARRAY; + +# dialect: snowflake +ARRAY_PREPEND([2, 3, 4], 1); +ARRAY; + # dialect: snowflake ASIN(tbl.double_col); DOUBLE;