Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
#### Experimental Features

- Added `Catalog` class to manage snowflake objects. It can be accessed via `Session.catalog`.
- Added support for querying a JSON element of a VARIANT column in `functions.col` and `functions.column` via an optional keyword argument `json_element`.
- Allow user input schema when reading JSON file on stage.
- Added support for specifying a schema string (including implicit struct syntax) when calling `DataFrame.create_dataframe`.
- `snowflake.core` is a dependency required for this feature.
Expand Down
30 changes: 28 additions & 2 deletions src/snowflake/snowpark/column.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,11 +252,27 @@ def __init__(
self,
expr1: Union[str, Expression],
expr2: Optional[str] = None,
json_element: bool = False,
_ast: Optional[proto.Expr] = None,
_emit_ast: bool = True,
) -> None:
self._ast = _ast

def derive_json_element_expr(
    expr: str, df_alias: Optional[str] = None
) -> UnresolvedAttribute:
    """Turn a dotted name such as ``"name.firstname"`` into an attribute
    expression that queries a JSON element of a VARIANT column.

    The text before the first ``.`` is the column itself; everything after
    it is treated as a path into the column's semi-structured value.
    """
    if "." not in expr:
        # Plain column reference -- no JSON path to traverse.
        return UnresolvedAttribute(quote_name(expr), df_alias=df_alias)
    column_part, _, json_path = expr.partition(".")
    # Per https://docs.snowflake.com/en/user-guide/querying-semistructured#dot-notation,
    # names on the JSON path are case-sensitive, so quote them preserving case.
    quoted_path = ".".join(
        quote_name(segment, keep_case=True) for segment in json_path.split(".")
    )
    return UnresolvedAttribute(
        f"{quote_name(column_part)}:{quoted_path}",
        is_sql_text=True,
        df_alias=df_alias,
    )

if expr2 is not None:
if not (isinstance(expr1, str) and isinstance(expr2, str)):
raise ValueError(
Expand All @@ -265,6 +281,8 @@ def __init__(

if expr2 == "*":
self._expression = Star([], df_alias=expr1)
elif json_element:
self._expression = derive_json_element_expr(expr2, expr1)
else:
self._expression = UnresolvedAttribute(
quote_name(expr2), df_alias=expr1
Expand All @@ -279,6 +297,8 @@ def __init__(
elif isinstance(expr1, str):
if expr1 == "*":
self._expression = Star([])
elif json_element:
self._expression = derive_json_element_expr(expr1)
else:
self._expression = UnresolvedAttribute(quote_name(expr1))

Expand Down Expand Up @@ -1483,9 +1503,15 @@ class CaseExpr(Column):
"""

def __init__(
self, expr: CaseWhen, _ast: Optional[proto.Expr] = None, _emit_ast: bool = True
self,
expr: CaseWhen,
json_element: bool = False,
_ast: Optional[proto.Expr] = None,
_emit_ast: bool = True,
) -> None:
super().__init__(expr, _ast=_ast, _emit_ast=_emit_ast)
super().__init__(
expr, json_element=json_element, _ast=_ast, _emit_ast=_emit_ast
)
self._branches = expr.branches

@publicapi
Expand Down
58 changes: 45 additions & 13 deletions src/snowflake/snowpark/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,20 +262,31 @@ def _check_column_parameters(name1: str, name2: Optional[str]) -> None:

@overload
@publicapi
def col(col_name: str, _emit_ast: bool = True) -> Column:
def col(col_name: str, json_element: bool = False, _emit_ast: bool = True) -> Column:
"""Returns the :class:`~snowflake.snowpark.Column` with the specified name.

Args:
col_name: The name of the column.
json_element: Whether the column is a JSON element. If a column is a VARIANT column in Snowflake,
you can use dot notation `.` to query the nested json element, e.g., "name.firstname" and "name.lastname".

Example::
>>> df = session.sql("select 1 as a")
>>> df.select(col("a")).collect()
[Row(A=1)]

>>> df = session.sql("select parse_json('{\"firstname\": \"John\", \"lastname\": \"Doe\"}') as name")
>>> df.select(col("name.firstname", json_element=True)).collect()
[Row(FIRSTNAME='John')]
"""
... # pragma: no cover


@overload
@publicapi
def col(df_alias: str, col_name: str, _emit_ast: bool = True) -> Column:
def col(
df_alias: str, col_name: str, json_element: bool = False, _emit_ast: bool = True
) -> Column:
"""Returns the :class:`~snowflake.snowpark.Column` with the specified dataframe alias and column name.

Example::
Expand All @@ -287,7 +298,12 @@ def col(df_alias: str, col_name: str, _emit_ast: bool = True) -> Column:


@publicapi
def col(name1: str, name2: Optional[str] = None, _emit_ast: bool = True) -> Column:
def col(
name1: str,
name2: Optional[str] = None,
json_element: bool = False,
_emit_ast: bool = True,
) -> Column:

_check_column_parameters(name1, name2)

Expand All @@ -296,27 +312,38 @@ def col(name1: str, name2: Optional[str] = None, _emit_ast: bool = True) -> Colu
ast = create_ast_for_column(name1, name2, "col")

if name2 is None:
return Column(name1, _ast=ast)
return Column(name1, json_element=json_element, _ast=ast)
else:
return Column(name1, name2, _ast=ast)
return Column(name1, name2, json_element=json_element, _ast=ast)


@overload
@publicapi
def column(col_name: str, _emit_ast: bool = True) -> Column:
def column(col_name: str, json_element: bool = False, _emit_ast: bool = True) -> Column:
"""Returns a :class:`~snowflake.snowpark.Column` with the specified name. Alias for col.

Args:
col_name: The name of the column.
json_element: Whether the column is a JSON element. If a column is a VARIANT column in Snowflake,
you can use dot notation `.` to query the nested json element, e.g., "name.firstname" and "name.lastname".

Example::
>>> df = session.sql("select 1 as a")
>>> df.select(column("a")).collect()
[Row(A=1)]
>>> df = session.sql("select 1 as a")
>>> df.select(column("a")).collect()
[Row(A=1)]

>>> df = session.sql("select parse_json('{\"firstname\": \"John\", \"lastname\": \"Doe\"}') as name")
>>> df.select(column("name.firstname", json_element=True)).collect()
[Row(FIRSTNAME='John')]
"""
... # pragma: no cover


@overload
@publicapi
def column(df_alias: str, col_name: str, _emit_ast: bool = True) -> Column:
def column(
df_alias: str, col_name: str, json_element: bool = False, _emit_ast: bool = True
) -> Column:
"""Returns a :class:`~snowflake.snowpark.Column` with the specified name and dataframe alias name. Alias for col.

Example::
Expand All @@ -328,15 +355,20 @@ def column(df_alias: str, col_name: str, _emit_ast: bool = True) -> Column:


@publicapi
def column(name1: str, name2: Optional[str] = None, _emit_ast: bool = True) -> Column:
def column(
name1: str,
name2: Optional[str] = None,
json_element: bool = False,
_emit_ast: bool = True,
) -> Column:
_check_column_parameters(name1, name2)

ast = create_ast_for_column(name1, name2, "column") if _emit_ast else None

if name2 is None:
return Column(name1, _ast=ast)
return Column(name1, json_element=json_element, _ast=ast)
else:
return Column(name1, name2, _ast=ast)
return Column(name1, name2, json_element=json_element, _ast=ast)


@publicapi
Expand Down
47 changes: 47 additions & 0 deletions tests/integ/test_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,53 @@
)


@pytest.mark.skipif(
    "config.getoption('local_testing_mode', default=False)",
    reason="querying json element is not supported in local testing",
)
def test_col_json_element(session):
    """Verify ``col(..., json_element=True)`` dot-notation access into VARIANT columns."""
    # 2-level deep: column plus one JSON path element.
    df = session.sql(
        'select parse_json(\'{"firstname": "John", "lastname": "Doe"}\') as name'
    )
    Utils.check_answer(
        df.select(
            col("name.firstname", json_element=True),
            col("name.lastname", json_element=True),
        ),
        [Row('"John"', '"Doe"')],
    )
    # Explicitly quoted path segments work too; the column name itself is
    # accepted in either case.
    Utils.check_answer(
        df.select(
            col('name."firstname"', json_element=True),
            col('NAME."lastname"', json_element=True),
        ),
        [Row('"John"', '"Doe"')],
    )
    # JSON path elements are case-sensitive, so an upper-cased key finds nothing.
    Utils.check_answer(df.select(col("name.FIRSTNAME", json_element=True)), [Row(None)])

    # Negative paths on the 2-level data: ':' is not accepted in the input
    # (only '.'), and plain col() treats the dotted name as one identifier.
    with pytest.raises(SnowparkSQLException, match="invalid identifier"):
        df.select(col("name:firstname", json_element=True)).collect()

    with pytest.raises(SnowparkSQLException, match="invalid identifier"):
        df.select(col("name.firstname")).collect()

    # 3-level deep: column plus two nested JSON path elements.
    df = session.sql('select parse_json(\'{"l1": {"l2": "xyz"}}\') as value')
    Utils.check_answer(df.select(col("value.l1.l2", json_element=True)), Row('"xyz"'))
    Utils.check_answer(
        df.select(col('value."l1"."l2"', json_element=True)), Row('"xyz"')
    )
    Utils.check_answer(df.select(col("value.L1.l2", json_element=True)), Row(None))
    Utils.check_answer(df.select(col("value.l1.L2", json_element=True)), Row(None))

    with pytest.raises(SnowparkSQLException, match="invalid identifier"):
        df.select(col("value:l1.l2", json_element=True)).collect()

    with pytest.raises(SnowparkSQLException, match="invalid identifier"):
        df.select(col("value.l1.l2")).collect()


def test_order(session):
null_data1 = TestData.null_data1(session)
assert null_data1.sort(asc(null_data1["A"])).collect() == [
Expand Down
Loading