diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py index 4fc09597e9..26f27c0e8d 100644 --- a/sqlglot/expressions.py +++ b/sqlglot/expressions.py @@ -7441,6 +7441,36 @@ class MapFromEntries(Func): pass +class MapCat(Func): + arg_types = {"this": True, "expression": True} + + +class MapContainsKey(Func): + arg_types = {"this": True, "key": True} + + +class MapDelete(Func): + arg_types = {"this": True, "expressions": True} + is_var_len_args = True + + +class MapInsert(Func): + arg_types = {"this": True, "key": False, "value": True, "update_flag": False} + + +class MapKeys(Func): + pass + + +class MapPick(Func): + arg_types = {"this": True, "expressions": True} + is_var_len_args = True + + +class MapSize(Func): + pass + + # https://learn.microsoft.com/en-us/sql/t-sql/language-elements/scope-resolution-operator-transact-sql?view=sql-server-ver16 class ScopeResolution(Expression): arg_types = {"this": False, "expression": True} diff --git a/sqlglot/typing/snowflake.py b/sqlglot/typing/snowflake.py index 8345707977..c73a3f3cfb 100644 --- a/sqlglot/typing/snowflake.py +++ b/sqlglot/typing/snowflake.py @@ -243,6 +243,7 @@ def _annotate_str_to_time(self: TypeAnnotator, expression: exp.StrToTime) -> exp exp.ArrayConstructCompact, exp.ArrayUniqueAgg, exp.ArrayUnionAgg, + exp.MapKeys, exp.RegexpExtractAll, exp.Split, exp.StringToArray, @@ -289,6 +290,7 @@ def _annotate_str_to_time(self: TypeAnnotator, expression: exp.StrToTime) -> exp exp.BoolxorAgg, exp.EqualNull, exp.IsNullValue, + exp.MapContainsKey, exp.Search, exp.SearchIp, exp.ToBoolean, @@ -385,6 +387,7 @@ def _annotate_str_to_time(self: TypeAnnotator, expression: exp.StrToTime) -> exp exp.JarowinklerSimilarity, exp.Length, exp.Levenshtein, + exp.MapSize, exp.Minute, exp.RtrimmedLength, exp.Second, @@ -406,6 +409,15 @@ def _annotate_str_to_time(self: TypeAnnotator, expression: exp.StrToTime) -> exp exp.XMLGet, } }, + **{ + expr_type: {"returns": exp.DataType.Type.MAP} + for expr_type in { + exp.MapCat, + exp.MapDelete, + exp.MapInsert, + exp.MapPick, + } + }, **{ expr_type: {"returns": exp.DataType.Type.FILE} for expr_type in { diff --git a/tests/dialects/test_snowflake.py b/tests/dialects/test_snowflake.py index 729e5409e8..8166b5d221 100644 --- a/tests/dialects/test_snowflake.py +++ b/tests/dialects/test_snowflake.py @@ -309,6 +309,15 @@ def test_snowflake(self): self.validate_identity("SELECT rename, replace") self.validate_identity("SELECT TIMEADD(HOUR, 2, CAST('09:05:03' AS TIME))") self.validate_identity("SELECT CAST(OBJECT_CONSTRUCT('a', 1) AS MAP(VARCHAR, INT))") + self.validate_identity( + "SELECT MAP_CAT(CAST(col AS MAP(VARCHAR, VARCHAR)), CAST(col AS MAP(VARCHAR, VARCHAR)))" + ) + self.validate_identity("SELECT MAP_CONTAINS_KEY('k1', CAST(col AS MAP(VARCHAR, VARCHAR)))") + self.validate_identity("SELECT MAP_DELETE(CAST(col AS MAP(VARCHAR, VARCHAR)), 'k1')") + self.validate_identity("SELECT MAP_INSERT(CAST(col AS MAP(VARCHAR, VARCHAR)), 'b', '2')") + self.validate_identity("SELECT MAP_KEYS(CAST(col AS MAP(VARCHAR, VARCHAR)))") + self.validate_identity("SELECT MAP_PICK(CAST(col AS MAP(VARCHAR, VARCHAR)), 'a', 'c')") + self.validate_identity("SELECT MAP_SIZE(CAST(col AS MAP(VARCHAR, VARCHAR)))") self.validate_identity("SELECT CAST(OBJECT_CONSTRUCT('a', 1) AS OBJECT(a CHAR NOT NULL))") self.validate_identity("SELECT CAST([1, 2, 3] AS ARRAY(INT))") self.validate_identity("SELECT CAST(obj AS OBJECT(x CHAR) RENAME FIELDS)") diff --git a/tests/fixtures/optimizer/annotate_functions.sql b/tests/fixtures/optimizer/annotate_functions.sql index c1b6812733..2e4bf0552f 100644 --- a/tests/fixtures/optimizer/annotate_functions.sql +++ b/tests/fixtures/optimizer/annotate_functions.sql @@ -2952,6 +2952,34 @@ VARCHAR; LTRIM(NULL); VARCHAR; +# dialect: snowflake +MAP_CAT(CAST(col AS MAP(VARCHAR, VARCHAR)), CAST(col AS MAP(VARCHAR, VARCHAR))); +MAP; + +# dialect: snowflake +MAP_CONTAINS_KEY('k1', CAST(col AS MAP(VARCHAR, VARCHAR))); +BOOLEAN; + +# dialect: snowflake +MAP_DELETE(CAST(col AS MAP(VARCHAR, VARCHAR)), 'b'); +MAP; + +# dialect: snowflake +MAP_INSERT(CAST(col AS MAP(VARCHAR, VARCHAR)), 'b', '2'); +MAP; + +# dialect: snowflake +MAP_KEYS(CAST(col AS MAP(VARCHAR, VARCHAR))); +ARRAY; + +# dialect: snowflake +MAP_PICK(CAST(col AS MAP(VARCHAR, VARCHAR)), 'a', 'c'); +MAP; + +# dialect: snowflake +MAP_SIZE(CAST(col AS MAP(VARCHAR, VARCHAR))); +INT; + # dialect: snowflake MINUTE(CAST('08:50:57' AS TIME)); INT;