Skip to content

Commit d46969d

Browse files
fivetran-MichaelLeeMichael Lee
andauthored
chore(optimizer)!: annotate snowflake ARRAY_APPEND and ARRAY_PREPEND (#6645)
* chore(optimizer)!: annotate snowflake ARRAY_APPEND and ARRAY_PREPEND * add support for ARRAY_PREPEND * add tests for spark --------- Co-authored-by: Michael Lee <[email protected]>
1 parent 2f8ffcf commit d46969d

File tree

7 files changed

+57
-0
lines changed

7 files changed

+57
-0
lines changed

sqlglot/dialects/duckdb.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -291,6 +291,10 @@ def _build_sort_array_desc(args: t.List) -> exp.Expression:
291291
return exp.SortArray(this=seq_get(args, 0), asc=exp.false())
292292

293293

294+
def _build_array_prepend(args: t.List) -> exp.Expression:
295+
return exp.ArrayPrepend(this=seq_get(args, 1), expression=seq_get(args, 0))
296+
297+
294298
def _build_date_diff(args: t.List) -> exp.Expression:
295299
return exp.DateDiff(this=seq_get(args, 2), expression=seq_get(args, 1), unit=seq_get(args, 0))
296300

@@ -971,6 +975,7 @@ class Parser(parser.Parser):
971975
FUNCTIONS = {
972976
**parser.Parser.FUNCTIONS,
973977
"ANY_VALUE": lambda args: exp.IgnoreNulls(this=exp.AnyValue.from_arg_list(args)),
978+
"ARRAY_PREPEND": _build_array_prepend,
974979
"ARRAY_REVERSE_SORT": _build_sort_array_desc,
975980
"ARRAY_SORT": exp.SortArray.from_arg_list,
976981
"BIT_AND": exp.BitwiseAndAgg.from_arg_list,
@@ -995,12 +1000,14 @@ class Parser(parser.Parser):
9951000
"JSON": exp.ParseJSON.from_arg_list,
9961001
"JSON_EXTRACT_PATH": parser.build_extract_json_with_path(exp.JSONExtract),
9971002
"JSON_EXTRACT_STRING": parser.build_extract_json_with_path(exp.JSONExtractScalar),
1003+
"LIST_APPEND": exp.ArrayAppend.from_arg_list,
9981004
"LIST_CONTAINS": exp.ArrayContains.from_arg_list,
9991005
"LIST_COSINE_DISTANCE": exp.CosineDistance.from_arg_list,
10001006
"LIST_DISTANCE": exp.EuclideanDistance.from_arg_list,
10011007
"LIST_FILTER": exp.ArrayFilter.from_arg_list,
10021008
"LIST_HAS": exp.ArrayContains.from_arg_list,
10031009
"LIST_HAS_ANY": exp.ArrayOverlaps.from_arg_list,
1010+
"LIST_PREPEND": _build_array_prepend,
10041011
"LIST_REVERSE_SORT": _build_sort_array_desc,
10051012
"LIST_SORT": exp.SortArray.from_arg_list,
10061013
"LIST_TRANSFORM": exp.Transform.from_arg_list,
@@ -1279,9 +1286,11 @@ class Generator(generator.Generator):
12791286
[transforms.inherit_struct_field_names],
12801287
generator=inline_array_unless_query,
12811288
),
1289+
exp.ArrayAppend: rename_func("LIST_APPEND"),
12821290
exp.ArrayFilter: rename_func("LIST_FILTER"),
12831291
exp.ArrayRemove: remove_from_array_using_filter,
12841292
exp.ArraySort: _array_sort_sql,
1293+
exp.ArrayPrepend: lambda self, e: self.func("LIST_PREPEND", e.expression, e.this),
12851294
exp.ArraySum: rename_func("LIST_SUM"),
12861295
exp.ArrayUniqueAgg: lambda self, e: self.func(
12871296
"LIST", exp.Distinct(expressions=[e.this])

sqlglot/dialects/postgres.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -422,6 +422,9 @@ class Parser(parser.Parser):
422422

423423
FUNCTIONS = {
424424
**parser.Parser.FUNCTIONS,
425+
"ARRAY_PREPEND": lambda args: exp.ArrayPrepend(
426+
this=seq_get(args, 1), expression=seq_get(args, 0)
427+
),
425428
"BIT_AND": exp.BitwiseAndAgg.from_arg_list,
426429
"BIT_OR": exp.BitwiseOrAgg.from_arg_list,
427430
"BIT_XOR": exp.BitwiseXorAgg.from_arg_list,
@@ -613,6 +616,7 @@ class Generator(generator.Generator):
613616
exp.AnyValue: _versioned_anyvalue_sql,
614617
exp.ArrayConcat: lambda self, e: self.arrayconcat_sql(e, name="ARRAY_CAT"),
615618
exp.ArrayFilter: filter_array_using_unnest,
619+
exp.ArrayPrepend: lambda self, e: self.func("ARRAY_PREPEND", e.expression, e.this),
616620
exp.BitwiseAndAgg: rename_func("BIT_AND"),
617621
exp.BitwiseOrAgg: rename_func("BIT_OR"),
618622
exp.BitwiseXor: lambda self, e: self.binary(e, "#"),

sqlglot/expressions.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6047,6 +6047,14 @@ class ArrayAny(Func):
60476047
arg_types = {"this": True, "expression": True}
60486048

60496049

6050+
class ArrayAppend(Func):
6051+
arg_types = {"this": True, "expression": True}
6052+
6053+
6054+
class ArrayPrepend(Func):
6055+
arg_types = {"this": True, "expression": True}
6056+
6057+
60506058
class ArrayConcat(Func):
60516059
_sql_names = ["ARRAY_CONCAT", "ARRAY_CAT"]
60526060
arg_types = {"this": True, "expressions": False}

sqlglot/typing/snowflake.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -241,7 +241,9 @@ def _annotate_str_to_time(self: TypeAnnotator, expression: exp.StrToTime) -> exp
241241
exp.ApproxTopKEstimate,
242242
exp.Array,
243243
exp.ArrayAgg,
244+
exp.ArrayAppend,
244245
exp.ArrayConstructCompact,
246+
exp.ArrayPrepend,
245247
exp.ArrayUniqueAgg,
246248
exp.ArrayUnionAgg,
247249
exp.MapKeys,

tests/dialects/test_dialect.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1473,6 +1473,30 @@ def test_array(self):
14731473
},
14741474
)
14751475

1476+
self.validate_all(
1477+
"ARRAY_PREPEND(arr, x)",
1478+
read={
1479+
"duckdb": "LIST_PREPEND(x, arr)",
1480+
"postgres": "ARRAY_PREPEND(x, arr)",
1481+
},
1482+
write={
1483+
"duckdb": "LIST_PREPEND(x, arr)",
1484+
"postgres": "ARRAY_PREPEND(x, arr)",
1485+
"spark": "ARRAY_PREPEND(arr, x)",
1486+
"snowflake": "ARRAY_PREPEND(arr, x)",
1487+
},
1488+
)
1489+
1490+
self.validate_all(
1491+
"ARRAY_APPEND(arr, x)",
1492+
write={
1493+
"duckdb": "LIST_APPEND(arr, x)",
1494+
"postgres": "ARRAY_APPEND(arr, x)",
1495+
"spark": "ARRAY_APPEND(arr, x)",
1496+
"snowflake": "ARRAY_APPEND(arr, x)",
1497+
},
1498+
)
1499+
14761500
def test_order_by(self):
14771501
self.validate_identity(
14781502
"SELECT c FROM t ORDER BY a, b,",

tests/dialects/test_snowflake.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -441,6 +441,8 @@ def test_snowflake(self):
441441
self.validate_identity("SELECT TO_ARRAY(CAST(x AS ARRAY))")
442442
self.validate_identity("SELECT TO_ARRAY(CAST(['test'] AS VARIANT))")
443443
self.validate_identity("SELECT ARRAY_UNIQUE_AGG(x)")
444+
self.validate_identity("SELECT ARRAY_APPEND([1, 2, 3], 4)")
445+
self.validate_identity("SELECT ARRAY_PREPEND([2, 3, 4], 1)")
444446
self.validate_identity("SELECT AI_AGG(review, 'Summarize the reviews')")
445447
self.validate_identity("SELECT AI_SUMMARIZE_AGG(review)")
446448
self.validate_identity("SELECT AI_CLASSIFY('text', ['travel', 'cooking'])")

tests/fixtures/optimizer/annotate_functions.sql

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1744,6 +1744,14 @@ ARRAY;
17441744
ARRAY_CONSTRUCT_COMPACT(1, null, 2);
17451745
ARRAY;
17461746

1747+
# dialect: snowflake
1748+
ARRAY_APPEND([1, 2, 3], 4);
1749+
ARRAY;
1750+
1751+
# dialect: snowflake
1752+
ARRAY_PREPEND([2, 3, 4], 1);
1753+
ARRAY;
1754+
17471755
# dialect: snowflake
17481756
ASIN(tbl.double_col);
17491757
DOUBLE;

0 commit comments

Comments
 (0)