Skip to content

Commit 0ff01c4

Browse files
SNOW-1876115: Allow column to cast to StructType. (#2864)
1 parent 1cbebc1 commit 0ff01c4

File tree

6 files changed

+182
-8
lines changed

6 files changed

+182
-8
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@
5050

5151
- Updated README.md to include instructions on how to verify package signatures using `cosign`.
5252
- Added an option `keep_column_order` for keeping original column order in `DataFrame.with_column` and `DataFrame.with_columns`.
53+
- Added options to column casts that allow renaming or adding fields in StructType columns.
5354
- Added support for `contains_null` parameter to ArrayType.
5455
- Added support for creating a temporary view via `DataFrame.create_or_replace_temp_view` from a DataFrame created by reading a file from a stage.
5556
- Added support for `value_contains_null` parameter to MapType.

src/snowflake/snowpark/_internal/analyzer/analyzer.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -661,6 +661,8 @@ def unary_expression_extractor(
661661
),
662662
expr.to,
663663
expr.try_,
664+
expr.is_rename,
665+
expr.is_add,
664666
)
665667
else:
666668
return unary_expression(

src/snowflake/snowpark/_internal/analyzer/analyzer_utils.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,8 @@
185185
DEFAULT_ON_NULL = " DEFAULT ON NULL "
186186
ANY = " ANY "
187187
ICEBERG = " ICEBERG "
188+
RENAME_FIELDS = " RENAME FIELDS"
189+
ADD_FIELDS = " ADD FIELDS"
188190

189191
TEMPORARY_STRING_SET = frozenset(["temporary", "temp"])
190192

@@ -1117,13 +1119,21 @@ def rank_related_function_expression(
11171119
)
11181120

11191121

1120-
def cast_expression(child: str, datatype: DataType, try_: bool = False) -> str:
1122+
def cast_expression(
1123+
child: str,
1124+
datatype: DataType,
1125+
try_: bool = False,
1126+
is_rename: bool = False,
1127+
is_add: bool = False,
1128+
) -> str:
11211129
return (
11221130
(TRY_CAST if try_ else CAST)
11231131
+ LEFT_PARENTHESIS
11241132
+ child
11251133
+ AS
11261134
+ convert_sp_to_sf_type(datatype)
1135+
+ (RENAME_FIELDS if is_rename else "")
1136+
+ (ADD_FIELDS if is_add else "")
11271137
+ RIGHT_PARENTHESIS
11281138
)
11291139

src/snowflake/snowpark/_internal/analyzer/unary_expression.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,10 +49,19 @@ class Cast(UnaryExpression):
4949
sql_operator = "CAST"
5050
operator_first = True
5151

52-
def __init__(self, child: Expression, to: DataType, try_: bool = False) -> None:
52+
def __init__(
53+
self,
54+
child: Expression,
55+
to: DataType,
56+
try_: bool = False,
57+
is_rename: bool = False,
58+
is_add: bool = False,
59+
) -> None:
5360
super().__init__(child)
5461
self.to = to
5562
self.try_ = try_
63+
self.is_rename = is_rename
64+
self.is_add = is_add
5665

5766

5867
class UnaryMinus(UnaryExpression):

src/snowflake/snowpark/column.py

Lines changed: 43 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -915,8 +915,17 @@ def __invert__(self) -> "Column":
915915
return Column(Not(self._expression), _ast=expr, _emit_ast=_emit_ast)
916916

917917
def _cast(
918-
self, to: Union[str, DataType], try_: bool = False, _emit_ast: bool = True
918+
self,
919+
to: Union[str, DataType],
920+
try_: bool = False,
921+
rename_fields: bool = False,
922+
add_fields: bool = False,
923+
_emit_ast: bool = True,
919924
) -> "Column":
925+
if add_fields and rename_fields:
926+
raise ValueError(
927+
"is_add and is_rename cannot be set to True at the same time"
928+
)
920929
if isinstance(to, str):
921930
to = type_string_to_type_object(to)
922931

@@ -934,21 +943,49 @@ def _cast(
934943
)
935944
ast.col.CopyFrom(self._ast)
936945
to._fill_ast(ast.to)
937-
return Column(Cast(self._expression, to, try_), _ast=expr, _emit_ast=_emit_ast)
946+
return Column(
947+
Cast(self._expression, to, try_, rename_fields, add_fields),
948+
_ast=expr,
949+
_emit_ast=_emit_ast,
950+
)
938951

939952
@publicapi
940-
def cast(self, to: Union[str, DataType], _emit_ast: bool = True) -> "Column":
953+
def cast(
954+
self,
955+
to: Union[str, DataType],
956+
rename_fields: bool = False,
957+
add_fields: bool = False,
958+
_emit_ast: bool = True,
959+
) -> "Column":
941960
"""Casts the value of the Column to the specified data type.
942961
It raises an error when the conversion can not be performed.
943962
"""
944-
return self._cast(to, False, _emit_ast=_emit_ast)
963+
return self._cast(
964+
to,
965+
False,
966+
rename_fields=rename_fields,
967+
add_fields=add_fields,
968+
_emit_ast=_emit_ast,
969+
)
945970

946971
@publicapi
947-
def try_cast(self, to: Union[str, DataType], _emit_ast: bool = True) -> "Column":
972+
def try_cast(
973+
self,
974+
to: Union[str, DataType],
975+
rename_fields: bool = False,
976+
add_fields: bool = False,
977+
_emit_ast: bool = True,
978+
) -> "Column":
948979
"""Tries to cast the value of the Column to the specified data type.
949980
It returns a NULL value instead of raising an error when the conversion can not be performed.
950981
"""
951-
return self._cast(to, True, _emit_ast=_emit_ast)
982+
return self._cast(
983+
to,
984+
True,
985+
rename_fields=rename_fields,
986+
add_fields=add_fields,
987+
_emit_ast=_emit_ast,
988+
)
952989

953990
@publicapi
954991
def desc(self, _emit_ast: bool = True) -> "Column":

tests/integ/scala/test_datatype_suite.py

Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1449,3 +1449,118 @@ def test_sproc(_session: Session) -> DataFrame:
14491449
df = structured_type_session.call(sproc_name)
14501450
assert df.schema == expected_schema
14511451
assert df.dtypes == expected_dtypes
1452+
1453+
1454+
@pytest.mark.skipif(
1455+
"config.getoption('local_testing_mode', default=False)",
1456+
reason="Structured types are not supported in Local Testing",
1457+
)
1458+
def test_cast_structtype_rename(structured_type_session, structured_type_support):
1459+
if not structured_type_support:
1460+
pytest.skip("Test requires structured type support.")
1461+
data = [
1462+
({"firstname": "James", "middlename": "", "lastname": "Smith"}, "1991-04-01")
1463+
]
1464+
schema = StructType(
1465+
[
1466+
StructField(
1467+
"name",
1468+
StructType(
1469+
[
1470+
StructField("firstname", StringType(), True),
1471+
StructField("middlename", StringType(), True),
1472+
StructField("lastname", StringType(), True),
1473+
]
1474+
),
1475+
),
1476+
StructField("dob", StringType(), True),
1477+
]
1478+
)
1479+
1480+
schema2 = StructType(
1481+
[
1482+
StructField("fname", StringType()),
1483+
StructField("middlename", StringType()),
1484+
StructField("lname", StringType()),
1485+
]
1486+
)
1487+
1488+
df = structured_type_session.create_dataframe(data, schema)
1489+
Utils.check_answer(
1490+
df.select(
1491+
col("name").cast(schema2, rename_fields=True).as_("new_name"), col("dob")
1492+
),
1493+
[
1494+
Row(
1495+
NEW_NAME=Row(fname="James", middlename="", lname="Smith"),
1496+
DOB="1991-04-01",
1497+
)
1498+
],
1499+
)
1500+
with pytest.raises(
1501+
ValueError, match="is_add and is_rename cannot be set to True at the same time"
1502+
):
1503+
df.select(
1504+
col("name")
1505+
.cast(schema2, rename_fields=True, add_fields=True)
1506+
.as_("new_name"),
1507+
col("dob"),
1508+
)
1509+
1510+
1511+
@pytest.mark.skipif(
1512+
"config.getoption('local_testing_mode', default=False)",
1513+
reason="Structured types are not supported in Local Testing",
1514+
)
1515+
def test_cast_structtype_add(structured_type_session, structured_type_support):
1516+
if not structured_type_support:
1517+
pytest.skip("Test requires structured type support.")
1518+
data = [
1519+
({"firstname": "James", "middlename": "", "lastname": "Smith"}, "1991-04-01")
1520+
]
1521+
schema = StructType(
1522+
[
1523+
StructField(
1524+
"name",
1525+
StructType(
1526+
[
1527+
StructField("firstname", StringType(), True),
1528+
StructField("middlename", StringType(), True),
1529+
StructField("lastname", StringType(), True),
1530+
]
1531+
),
1532+
),
1533+
StructField("dob", StringType(), True),
1534+
]
1535+
)
1536+
1537+
schema2 = StructType(
1538+
[
1539+
StructField("firstname", StringType()),
1540+
StructField("middlename", StringType()),
1541+
StructField("lastname", StringType()),
1542+
StructField("extra", StringType()),
1543+
]
1544+
)
1545+
1546+
df = structured_type_session.create_dataframe(data, schema)
1547+
Utils.check_answer(
1548+
df.select(
1549+
col("name").cast(schema2, add_fields=True).as_("new_name"), col("dob")
1550+
),
1551+
[
1552+
Row(
1553+
NEW_NAME=Row(fname="James", middlename="", lname="Smith", extra=None),
1554+
DOB="1991-04-01",
1555+
)
1556+
],
1557+
)
1558+
with pytest.raises(
1559+
ValueError, match="is_add and is_rename cannot be set to True at the same time"
1560+
):
1561+
df.select(
1562+
col("name")
1563+
.cast(schema2, rename_fields=True, add_fields=True)
1564+
.as_("new_name"),
1565+
col("dob"),
1566+
)

0 commit comments

Comments
 (0)