Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
- `try_to_binary`

- Added `Catalog` class to manage snowflake objects. It can be accessed via `Session.catalog`.
- Added support for querying a JSON element of a VARIANT column in `functions.col` and `functions.column` with an optional keyword argument `json_element`.

#### Improvements

Expand Down
30 changes: 28 additions & 2 deletions src/snowflake/snowpark/column.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,11 +252,27 @@ def __init__(
self,
expr1: Union[str, Expression],
expr2: Optional[str] = None,
json_element: bool = False,
_ast: Optional[proto.Expr] = None,
_emit_ast: bool = True,
) -> None:
self._ast = _ast

def derive_json_element_expr(
expr: str, df_alias: Optional[str] = None
) -> UnresolvedAttribute:
parts = expr.split(".")
if len(parts) == 1:
return UnresolvedAttribute(quote_name(parts[0]), df_alias=df_alias)
else:
                # According to https://docs.snowflake.com/en/user-guide/querying-semistructured#dot-notation,
                # JSON path elements after the column name are case-sensitive, so their case must be preserved
return UnresolvedAttribute(
f"{quote_name(parts[0])}:{'.'.join(quote_name(part, keep_case=True) for part in parts[1:])}",
is_sql_text=True,
df_alias=df_alias,
)

if expr2 is not None:
if not (isinstance(expr1, str) and isinstance(expr2, str)):
raise ValueError(
Expand All @@ -265,6 +281,8 @@ def __init__(

if expr2 == "*":
self._expression = Star([], df_alias=expr1)
elif json_element:
self._expression = derive_json_element_expr(expr2, expr1)
else:
self._expression = UnresolvedAttribute(
quote_name(expr2), df_alias=expr1
Expand All @@ -279,6 +297,8 @@ def __init__(
elif isinstance(expr1, str):
if expr1 == "*":
self._expression = Star([])
elif json_element:
self._expression = derive_json_element_expr(expr1)
else:
self._expression = UnresolvedAttribute(quote_name(expr1))

Expand Down Expand Up @@ -1446,9 +1466,15 @@ class CaseExpr(Column):
"""

def __init__(
self, expr: CaseWhen, _ast: Optional[proto.Expr] = None, _emit_ast: bool = True
self,
expr: CaseWhen,
json_element: bool = False,
_ast: Optional[proto.Expr] = None,
_emit_ast: bool = True,
) -> None:
super().__init__(expr, _ast=_ast, _emit_ast=_emit_ast)
super().__init__(
expr, json_element=json_element, _ast=_ast, _emit_ast=_emit_ast
)
self._branches = expr.branches

@publicapi
Expand Down
58 changes: 45 additions & 13 deletions src/snowflake/snowpark/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,20 +261,31 @@ def _check_column_parameters(name1: str, name2: Optional[str]) -> None:

@overload
@publicapi
def col(col_name: str, _emit_ast: bool = True) -> Column:
def col(col_name: str, json_element: bool = False, _emit_ast: bool = True) -> Column:
"""Returns the :class:`~snowflake.snowpark.Column` with the specified name.

Args:
col_name: The name of the column.
        json_element: Whether the column is a JSON element. If the column is a VARIANT column in Snowflake,
            you can use dot notation `.` to query a nested JSON element, e.g., "name.firstname" and "name.lastname".

Example::
>>> df = session.sql("select 1 as a")
>>> df.select(col("a")).collect()
[Row(A=1)]

>>> df = session.sql("select parse_json('{\"firstname\": \"John\", \"lastname\": \"Doe\"}') as name")
>>> df.select(col("name.firstname", json_element=True)).collect()
[Row(FIRSTNAME='John')]
"""
... # pragma: no cover


@overload
@publicapi
def col(df_alias: str, col_name: str, _emit_ast: bool = True) -> Column:
def col(
df_alias: str, col_name: str, json_element: bool = False, _emit_ast: bool = True
) -> Column:
"""Returns the :class:`~snowflake.snowpark.Column` with the specified dataframe alias and column name.

Example::
Expand All @@ -286,7 +297,12 @@ def col(df_alias: str, col_name: str, _emit_ast: bool = True) -> Column:


@publicapi
def col(name1: str, name2: Optional[str] = None, _emit_ast: bool = True) -> Column:
def col(
name1: str,
name2: Optional[str] = None,
json_element: bool = False,
_emit_ast: bool = True,
) -> Column:

_check_column_parameters(name1, name2)

Expand All @@ -295,27 +311,38 @@ def col(name1: str, name2: Optional[str] = None, _emit_ast: bool = True) -> Colu
ast = create_ast_for_column(name1, name2, "col")

if name2 is None:
return Column(name1, _ast=ast)
return Column(name1, json_element=json_element, _ast=ast)
else:
return Column(name1, name2, _ast=ast)
return Column(name1, name2, json_element=json_element, _ast=ast)


@overload
@publicapi
def column(col_name: str, _emit_ast: bool = True) -> Column:
def column(col_name: str, json_element: bool = False, _emit_ast: bool = True) -> Column:
"""Returns a :class:`~snowflake.snowpark.Column` with the specified name. Alias for col.

Args:
col_name: The name of the column.
        json_element: Whether the column is a JSON element. If the column is a VARIANT column in Snowflake,
            you can use dot notation `.` to query a nested JSON element, e.g., "name.firstname" and "name.lastname".

Example::
>>> df = session.sql("select 1 as a")
>>> df.select(column("a")).collect()
[Row(A=1)]
>>> df = session.sql("select 1 as a")
>>> df.select(column("a")).collect()
[Row(A=1)]

>>> df = session.sql("select parse_json('{\"firstname\": \"John\", \"lastname\": \"Doe\"}') as name")
>>> df.select(column("name.firstname", json_element=True)).collect()
[Row(FIRSTNAME='John')]
"""
... # pragma: no cover


@overload
@publicapi
def column(df_alias: str, col_name: str, _emit_ast: bool = True) -> Column:
def column(
df_alias: str, col_name: str, json_element: bool = False, _emit_ast: bool = True
) -> Column:
"""Returns a :class:`~snowflake.snowpark.Column` with the specified name and dataframe alias name. Alias for col.

Example::
Expand All @@ -327,15 +354,20 @@ def column(df_alias: str, col_name: str, _emit_ast: bool = True) -> Column:


@publicapi
def column(name1: str, name2: Optional[str] = None, _emit_ast: bool = True) -> Column:
def column(
name1: str,
name2: Optional[str] = None,
json_element: bool = False,
_emit_ast: bool = True,
) -> Column:
_check_column_parameters(name1, name2)

ast = create_ast_for_column(name1, name2, "column") if _emit_ast else None

if name2 is None:
return Column(name1, _ast=ast)
return Column(name1, json_element=json_element, _ast=ast)
else:
return Column(name1, name2, _ast=ast)
return Column(name1, name2, json_element=json_element, _ast=ast)


@publicapi
Expand Down
47 changes: 47 additions & 0 deletions tests/integ/test_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,53 @@
)


@pytest.mark.skipif(
"config.getoption('local_testing_mode', default=False)",
reason="querying json element is not supported in local testing",
)
def test_col_json_element(session):
# 2-level deep
df = session.sql(
'select parse_json(\'{"firstname": "John", "lastname": "Doe"}\') as name'
)
Utils.check_answer(
df.select(
col("name.firstname", json_element=True),
col("name.lastname", json_element=True),
),
[Row('"John"', '"Doe"')],
)
Utils.check_answer(
df.select(
col('name."firstname"', json_element=True),
col('NAME."lastname"', json_element=True),
),
[Row('"John"', '"Doe"')],
)
Utils.check_answer(df.select(col("name.FIRSTNAME", json_element=True)), [Row(None)])

# 3-level deep
with pytest.raises(SnowparkSQLException, match="invalid identifier"):
df.select(col("name:firstname", json_element=True)).collect()

with pytest.raises(SnowparkSQLException, match="invalid identifier"):
df.select(col("name.firstname")).collect()

df = session.sql('select parse_json(\'{"l1": {"l2": "xyz"}}\') as value')
Utils.check_answer(df.select(col("value.l1.l2", json_element=True)), Row('"xyz"'))
Utils.check_answer(
df.select(col('value."l1"."l2"', json_element=True)), Row('"xyz"')
)
Utils.check_answer(df.select(col("value.L1.l2", json_element=True)), Row(None))
Utils.check_answer(df.select(col("value.l1.L2", json_element=True)), Row(None))

with pytest.raises(SnowparkSQLException, match="invalid identifier"):
df.select(col("value:l1.l2", json_element=True)).collect()

with pytest.raises(SnowparkSQLException, match="invalid identifier"):
df.select(col("value.l1.l2")).collect()


def test_order(session):
null_data1 = TestData.null_data1(session)
assert null_data1.sort(asc(null_data1["A"])).collect() == [
Expand Down
Loading