diff --git a/sqlglot/dialects/singlestore.py b/sqlglot/dialects/singlestore.py index 4c8a51eb4b..756a1ea235 100644 --- a/sqlglot/dialects/singlestore.py +++ b/sqlglot/dialects/singlestore.py @@ -177,6 +177,10 @@ class Parser(MySQL.Parser): expression=seq_get(args, 0), json_type="JSON", ), + "JSON_KEYS": lambda args: exp.JSONKeys( + this=seq_get(args, 0), + expressions=args[1:], + ), "JSON_PRETTY": exp.JSONFormat.from_arg_list, "JSON_BUILD_ARRAY": lambda args: exp.JSONArray(expressions=args), "JSON_BUILD_OBJECT": lambda args: exp.JSONObject(expressions=args), diff --git a/sqlglot/dialects/snowflake.py b/sqlglot/dialects/snowflake.py index 730722eb60..5454443a31 100644 --- a/sqlglot/dialects/snowflake.py +++ b/sqlglot/dialects/snowflake.py @@ -828,6 +828,7 @@ class Parser(parser.Parser): "LOCALTIMESTAMP": exp.CurrentTimestamp.from_arg_list, "NULLIFZERO": _build_if_from_nullifzero, "OBJECT_CONSTRUCT": _build_object_construct, + "OBJECT_KEYS": exp.JSONKeys.from_arg_list, "OCTET_LENGTH": exp.ByteLength.from_arg_list, "PARSE_URL": lambda args: exp.ParseUrl( this=seq_get(args, 0), permissive=seq_get(args, 1) @@ -1563,6 +1564,7 @@ class Generator(generator.Generator): exp.JSONExtractScalar: lambda self, e: self.func( "JSON_EXTRACT_PATH_TEXT", e.this, e.expression ), + exp.JSONKeys: rename_func("OBJECT_KEYS"), exp.JSONObject: lambda self, e: self.func("OBJECT_CONSTRUCT_KEEP_NULL", *e.expressions), exp.JSONPathRoot: lambda *_: "", exp.JSONValueArray: _json_extract_value_array_sql, diff --git a/sqlglot/dialects/spark.py b/sqlglot/dialects/spark.py index ee0ff268cb..f1e07da527 100644 --- a/sqlglot/dialects/spark.py +++ b/sqlglot/dialects/spark.py @@ -141,6 +141,7 @@ class Parser(Spark2.Parser): "TRY_SUBTRACT": exp.SafeSubtract.from_arg_list, "DATEDIFF": _build_datediff, "DATE_DIFF": _build_datediff, + "JSON_OBJECT_KEYS": exp.JSONKeys.from_arg_list, "LISTAGG": exp.GroupConcat.from_arg_list, "TIMESTAMP_LTZ": _build_as_cast("TIMESTAMP_LTZ"), "TIMESTAMP_NTZ": _build_as_cast("TIMESTAMP_NTZ"), @@ -222,6 +223,7 @@ class Generator(Spark2.Generator): exp.DatetimeSub: date_delta_to_binary_interval_op(cast=False), exp.GroupConcat: _groupconcat_sql, exp.EndsWith: rename_func("ENDSWITH"), + exp.JSONKeys: rename_func("JSON_OBJECT_KEYS"), exp.PartitionedByProperty: lambda self, e: f"PARTITIONED BY {self.wrap(self.expressions(sqls=[_normalize_partition(e) for e in e.this.expressions], skip_first=True))}", exp.SafeAdd: rename_func("TRY_ADD"), diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py index 0ba933812d..b7cc66c382 100644 --- a/sqlglot/expressions.py +++ b/sqlglot/expressions.py @@ -7117,6 +7117,12 @@ class Format(Func): is_var_len_args = True +class JSONKeys(Func): + arg_types = {"this": True, "expression": False, "expressions": False} + is_var_len_args = True + _sql_names = ["JSON_KEYS"] + + class JSONKeyValue(Expression): arg_types = {"this": True, "expression": True} diff --git a/sqlglot/parser.py b/sqlglot/parser.py index 7dbda7dfa3..4f1e245895 100644 --- a/sqlglot/parser.py +++ b/sqlglot/parser.py @@ -252,6 +252,9 @@ class Parser(metaclass=_Parser): "JSON_EXTRACT": build_extract_json_with_path(exp.JSONExtract), "JSON_EXTRACT_SCALAR": build_extract_json_with_path(exp.JSONExtractScalar), "JSON_EXTRACT_PATH_TEXT": build_extract_json_with_path(exp.JSONExtractScalar), + "JSON_KEYS": lambda args, dialect: exp.JSONKeys( + this=seq_get(args, 0), expression=dialect.to_json_path(seq_get(args, 1)) + ), "LIKE": build_like, "LOG": build_logarithm, "LOG2": lambda args: exp.Log(this=exp.Literal.number(2), expression=seq_get(args, 0)), diff --git a/tests/dialects/test_dialect.py b/tests/dialects/test_dialect.py index 4a887088a1..cda6c21eec 100644 --- a/tests/dialects/test_dialect.py +++ b/tests/dialects/test_dialect.py @@ -4964,3 +4964,46 @@ def test_operator(self): self.validate_identity("SELECT 1 OPERATOR(+) 2") self.validate_identity("SELECT 1 OPERATOR(+) /* foo */ 2") self.validate_identity("SELECT 1 OPERATOR(pg_catalog.+) 2") + + def test_json_keys(self): + self.validate_all( + "JSON_KEYS(foo)", + read={ + "": "JSON_KEYS(foo)", + "spark": "JSON_OBJECT_KEYS(foo)", + "databricks": "JSON_OBJECT_KEYS(foo)", + "mysql": "JSON_KEYS(foo)", + "starrocks": "JSON_KEYS(foo)", + "duckdb": "JSON_KEYS(foo)", + "snowflake": "OBJECT_KEYS(foo)", + "doris": "JSON_KEYS(foo)", + "singlestore": "JSON_KEYS(foo)", + }, + write={ + "spark": "JSON_OBJECT_KEYS(foo)", + "databricks": "JSON_OBJECT_KEYS(foo)", + "mysql": "JSON_KEYS(foo)", + "starrocks": "JSON_KEYS(foo)", + "duckdb": "JSON_KEYS(foo)", + "snowflake": "OBJECT_KEYS(foo)", + "doris": "JSON_KEYS(foo)", + "singlestore": "JSON_KEYS(foo)", + }, + ) + + self.validate_all( + "JSON_KEYS(foo, '$.a')", + read={ + "": "JSON_KEYS(foo, '$.a')", + "mysql": "JSON_KEYS(foo, '$.a')", + "starrocks": "JSON_KEYS(foo, '$.a')", + "duckdb": "JSON_KEYS(foo, '$.a')", + "doris": "JSON_KEYS(foo, '$.a')", + }, + write={ + "mysql": "JSON_KEYS(foo, '$.a')", + "starrocks": "JSON_KEYS(foo, '$.a')", + "duckdb": "JSON_KEYS(foo, '$.a')", + "doris": "JSON_KEYS(foo, '$.a')", + }, + ) diff --git a/tests/dialects/test_singlestore.py b/tests/dialects/test_singlestore.py index 3d71c05297..7ec6ba00fc 100644 --- a/tests/dialects/test_singlestore.py +++ b/tests/dialects/test_singlestore.py @@ -23,6 +23,8 @@ def test_singlestore(self): self.validate_identity("SELECT CHARSET(CHAR(100 USING utf8))") self.validate_identity("SELECT TO_JSON(ROW(1, 2) :> RECORD(a INT, b INT))") + self.validate_identity("JSON_KEYS(json_doc, 'a', 'b', 'c', 2)") + def test_byte_strings(self): self.validate_identity("SELECT e'text'") self.validate_identity("SELECT E'text'", "SELECT e'text'")