Skip to content
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@

- Updated README.md to include instructions on how to verify package signatures using `cosign`.
- Added an option `keep_column_order` for keeping original column order in `DataFrame.with_column` and `DataFrame.with_columns`.
- Added options to column casts that allow renaming or adding fields in StructType columns.
- Added support for `contains_null` parameter to ArrayType.
- Added support for `value_contains_null` parameter to MapType.

Expand Down
2 changes: 2 additions & 0 deletions src/snowflake/snowpark/_internal/analyzer/analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -661,6 +661,8 @@ def unary_expression_extractor(
),
expr.to,
expr.try_,
expr.is_rename,
expr.is_add,
)
else:
return unary_expression(
Expand Down
12 changes: 11 additions & 1 deletion src/snowflake/snowpark/_internal/analyzer/analyzer_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,8 @@
DEFAULT_ON_NULL = " DEFAULT ON NULL "
ANY = " ANY "
ICEBERG = " ICEBERG "
RENAME_FIELDS = " RENAME FIELDS"
ADD_FIELDS = " ADD FIELDS"

TEMPORARY_STRING_SET = frozenset(["temporary", "temp"])

Expand Down Expand Up @@ -1117,13 +1119,21 @@ def rank_related_function_expression(
)


def cast_expression(child: str, datatype: DataType, try_: bool = False) -> str:
def cast_expression(
child: str,
datatype: DataType,
try_: bool = False,
is_rename: bool = False,
is_add: bool = False,
) -> str:
return (
(TRY_CAST if try_ else CAST)
+ LEFT_PARENTHESIS
+ child
+ AS
+ convert_sp_to_sf_type(datatype)
+ (RENAME_FIELDS if is_rename else "")
+ (ADD_FIELDS if is_add else "")
+ RIGHT_PARENTHESIS
)

Expand Down
11 changes: 10 additions & 1 deletion src/snowflake/snowpark/_internal/analyzer/unary_expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,10 +49,19 @@ class Cast(UnaryExpression):
sql_operator = "CAST"
operator_first = True

def __init__(self, child: Expression, to: DataType, try_: bool = False) -> None:
def __init__(
self,
child: Expression,
to: DataType,
try_: bool = False,
is_rename: bool = False,
is_add: bool = False,
) -> None:
super().__init__(child)
self.to = to
self.try_ = try_
self.is_rename = is_rename
self.is_add = is_add


class UnaryMinus(UnaryExpression):
Expand Down
41 changes: 35 additions & 6 deletions src/snowflake/snowpark/column.py
Original file line number Diff line number Diff line change
Expand Up @@ -915,8 +915,17 @@ def __invert__(self) -> "Column":
return Column(Not(self._expression), _ast=expr, _emit_ast=_emit_ast)

def _cast(
self, to: Union[str, DataType], try_: bool = False, _emit_ast: bool = True
self,
to: Union[str, DataType],
try_: bool = False,
_emit_ast: bool = True,
is_rename: bool = False,
is_add: bool = False,
) -> "Column":
if is_add and is_rename:
raise ValueError(
"is_add and is_rename cannot be set to True at the same time"
)
if isinstance(to, str):
to = type_string_to_type_object(to)

Expand All @@ -934,21 +943,41 @@ def _cast(
)
ast.col.CopyFrom(self._ast)
to._fill_ast(ast.to)
return Column(Cast(self._expression, to, try_), _ast=expr, _emit_ast=_emit_ast)
return Column(
Cast(self._expression, to, try_, is_rename, is_add),
_ast=expr,
_emit_ast=_emit_ast,
)

@publicapi
def cast(self, to: Union[str, DataType], _emit_ast: bool = True) -> "Column":
def cast(
self,
to: Union[str, DataType],
_emit_ast: bool = True,
is_rename: bool = False,
is_add: bool = False,
) -> "Column":
"""Casts the value of the Column to the specified data type.
It raises an error when the conversion can not be performed.
"""
return self._cast(to, False, _emit_ast=_emit_ast)
return self._cast(
to, False, _emit_ast=_emit_ast, is_rename=is_rename, is_add=is_add
)

@publicapi
def try_cast(self, to: Union[str, DataType], _emit_ast: bool = True) -> "Column":
def try_cast(
self,
to: Union[str, DataType],
_emit_ast: bool = True,
is_rename: bool = False,
is_add: bool = False,
) -> "Column":
"""Tries to cast the value of the Column to the specified data type.
It returns a NULL value instead of raising an error when the conversion can not be performed.
"""
return self._cast(to, True, _emit_ast=_emit_ast)
return self._cast(
to, True, _emit_ast=_emit_ast, is_rename=is_rename, is_add=is_add
)

@publicapi
def desc(self, _emit_ast: bool = True) -> "Column":
Expand Down
109 changes: 109 additions & 0 deletions tests/integ/scala/test_datatype_suite.py
Original file line number Diff line number Diff line change
Expand Up @@ -1449,3 +1449,112 @@ def test_sproc(_session: Session) -> DataFrame:
df = structured_type_session.call(sproc_name)
assert df.schema == expected_schema
assert df.dtypes == expected_dtypes


@pytest.mark.skipif(
"config.getoption('local_testing_mode', default=False)",
reason="Structured types are not supported in Local Testing",
)
def test_cast_structtype_rename(structured_type_session, structured_type_support):
if not structured_type_support:
pytest.skip("Test requires structured type support.")
data = [
({"firstname": "James", "middlename": "", "lastname": "Smith"}, "1991-04-01")
]
schema = StructType(
[
StructField(
"name",
StructType(
[
StructField("firstname", StringType(), True),
StructField("middlename", StringType(), True),
StructField("lastname", StringType(), True),
]
),
),
StructField("dob", StringType(), True),
]
)

schema2 = StructType(
[
StructField("fname", StringType()),
StructField("middlename", StringType()),
StructField("lname", StringType()),
]
)

df = structured_type_session.create_dataframe(data, schema)
Utils.check_answer(
df.select(
col("name").cast(schema2, is_rename=True).as_("new_name"), col("dob")
),
[
Row(
NEW_NAME=Row(fname="James", middlename="", lname="Smith"),
DOB="1991-04-01",
)
],
)
with pytest.raises(
ValueError, match="is_add and is_rename cannot be set to True at the same time"
):
df.select(
col("name").cast(schema2, is_rename=True, is_add=True).as_("new_name"),
col("dob"),
)


@pytest.mark.skipif(
"config.getoption('local_testing_mode', default=False)",
reason="Structured types are not supported in Local Testing",
)
def test_cast_structtype_add(structured_type_session, structured_type_support):
if not structured_type_support:
pytest.skip("Test requires structured type support.")
data = [
({"firstname": "James", "middlename": "", "lastname": "Smith"}, "1991-04-01")
]
schema = StructType(
[
StructField(
"name",
StructType(
[
StructField("firstname", StringType(), True),
StructField("middlename", StringType(), True),
StructField("lastname", StringType(), True),
]
),
),
StructField("dob", StringType(), True),
]
)

schema2 = StructType(
[
StructField("firstname", StringType()),
StructField("middlename", StringType()),
StructField("lastname", StringType()),
StructField("extra", StringType()),
]
)

df = structured_type_session.create_dataframe(data, schema)
Utils.check_answer(
df.select(col("name").cast(schema2, is_add=True).as_("new_name"), col("dob")),
[
Row(
NEW_NAME=Row(fname="James", middlename="", lname="Smith", extra=None),
DOB="1991-04-01",
)
],
)
with pytest.raises(
ValueError, match="is_add and is_rename cannot be set to True at the same time"
):
df.select(
col("name").cast(schema2, is_rename=True, is_add=True).as_("new_name"),
col("dob"),
)
Loading