Commit d9801ba
SNOW-1819523: Add support for expand=True in Series.str.split (#2832)
1. Which Jira issue is this PR addressing? Make sure that there is an accompanying issue to your PR.

   Fixes SNOW-1819523

2. Fill out the following pre-review checklist:

   - [x] I am adding a new automated test(s) to verify correctness of my new code
   - [ ] If this test skips Local Testing mode, I'm requesting review from @snowflakedb/local-testing
   - [ ] I am adding new logging messages
   - [ ] I am adding a new telemetry message
   - [ ] I am adding new credentials
   - [ ] I am adding a new dependency
   - [ ] If this is a new feature/behavior, I'm adding the Local Testing parity changes.
   - [ ] I acknowledge that I have ensured my changes to be thread-safe. Follow the link for more information: [Thread-safe Developer Guidelines](https://github.com/snowflakedb/snowpark-python/blob/main/CONTRIBUTING.md#thread-safe-development)

3. Please describe how your code solves the related issue.

   Add support for expand=True in Series.str.split.
1 parent 4d29402 commit d9801ba
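
For reference, a minimal sketch of the pandas behavior this change brings to Snowpark pandas (shown here with plain pandas, whose `Series.str.split` semantics the Snowpark pandas API follows):

```python
import pandas as pd

ser = pd.Series(["a,b,c", "a,b", None])

# expand=False (previously the only supported mode) keeps a Series of lists.
ser.str.split(",", expand=False)

# expand=True returns a DataFrame with one column per split piece; rows that
# produce fewer pieces than the widest row are padded with missing values.
ser.str.split(",", expand=True)
```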

File tree: 4 files changed, +124 additions, -19 deletions


CHANGELOG.md

Lines changed: 1 addition & 0 deletions
```diff
@@ -80,6 +80,7 @@
 - %%: A literal '%' character.
 - Added support for `Series.between`.
 - Added support for `include_groups=False` in `DataFrameGroupBy.apply`.
+- Added support for `expand=True` in `Series.str.split`.

 #### Bug Fixes
```

docs/source/modin/supported/series_str_supported.rst

Lines changed: 1 addition & 1 deletion
```diff
@@ -119,7 +119,7 @@ the method in the left column.
 | ``slice_replace``           | N                               |                                                    |
 +-----------------------------+---------------------------------+----------------------------------------------------+
 | ``split``                   | P                               | ``N`` if `pat` is non-string, `n` is non-numeric,  |
-|                             |                                 | `expand` is set, or `regex` is set.                |
+|                             |                                 | or `regex` is set.                                 |
 +-----------------------------+---------------------------------+----------------------------------------------------+
 | ``startswith``              | P                               | ``N`` if the `na` parameter is set to a non-bool   |
 |                             |                                 | value.                                             |
```

src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py

Lines changed: 94 additions & 7 deletions
```diff
@@ -107,6 +107,7 @@
     dense_rank,
     first_value,
     floor,
+    get,
     greatest,
     hour,
     iff,
@@ -16813,10 +16814,6 @@ def str_split(
             ErrorMessage.not_implemented(
                 "Snowpark pandas doesn't support non-str 'pat' argument"
             )
-        if expand:
-            ErrorMessage.not_implemented(
-                "Snowpark pandas doesn't support 'expand' argument"
-            )
         if regex:
             ErrorMessage.not_implemented(
                 "Snowpark pandas doesn't support 'regex' argument"
@@ -16864,6 +16861,12 @@ def output_col(
             if np.isnan(n):
                 # Follow pandas behavior
                 return pandas_lit(np.nan)
+            elif n < -1 and not pandas.isnull(pat) and len(str(pat)) > 1:
+                # Follow pandas behavior, which based on our experiments, leaves the input column as is
+                # whenever the above condition is satisfied.
+                new_col = iff(
+                    column.is_null(), pandas_lit(None), array_construct(column)
+                )
             elif n <= 0:
                 # If all possible splits are requested, we just use SQL's split function.
                 new_col = builtin("split")(new_col, pandas_lit(new_pat))
@@ -16907,9 +16910,93 @@ def output_col(
             )
             return self._replace_non_str(column, new_col)

-        new_internal_frame = self._modin_frame.apply_snowpark_function_to_columns(
-            lambda col_name: output_col(col_name, pat, n)
-        )
+        def output_cols(
+            column: SnowparkColumn, pat: Optional[str], n: int, max_splits: int
+        ) -> list[SnowparkColumn]:
+            """
+            Returns the list of columns that the input column will be split into.
+            This is only used when expand=True.
+
+            Args:
+                column : SnowparkColumn
+                    Input column
+                pat : str
+                    String to split on
+                n : int
+                    Limit on the number of output splits
+                max_splits : int
+                    Maximum number of achievable splits across all values in the input column.
+                    This is needed to be able to pad rows with fewer splits than desired with nulls.
+            """
+            col = output_col(column, pat, n)
+            final_splits = 0
+
+            if np.isnan(n):
+                # Follow pandas behavior
+                final_splits = 1
+            elif n <= 0:
+                final_splits = max_splits
+            else:
+                final_splits = min(n + 1, max_splits)
+
+            if n < -1 and not pandas.isnull(pat) and len(str(pat)) > 1:
+                # Follow pandas behavior, which based on our experiments, leaves the input column as is
+                # whenever the above condition is satisfied.
+                final_splits = 1
+
+            return [
+                iff(
+                    array_size(col) > pandas_lit(i),
+                    get(col, pandas_lit(i)),
+                    pandas_lit(None),
+                )
+                for i in range(final_splits)
+            ]
+
+        def get_max_splits() -> int:
+            """
+            Returns the maximum number of splits achievable
+            across all values stored in the input column.
+            """
+            splits_as_list_frame = self.str_split(
+                pat=pat,
+                n=-1,
+                expand=False,
+                regex=regex,
+            )._modin_frame
+
+            split_counts_frame = splits_as_list_frame.append_column(
+                "split_counts",
+                array_size(
+                    col(
+                        splits_as_list_frame.data_column_snowflake_quoted_identifiers[0]
+                    )
+                ),
+            )
+
+            max_count_rows = split_counts_frame.ordered_dataframe.agg(
+                max_(
+                    col(split_counts_frame.data_column_snowflake_quoted_identifiers[-1])
+                ).as_("max_count")
+            ).collect()
+
+            return max_count_rows[0][0]
+
+        if expand:
+            cols = output_cols(
+                col(self._modin_frame.data_column_snowflake_quoted_identifiers[0]),
+                pat,
+                n,
+                get_max_splits(),
+            )
+            new_internal_frame = self._modin_frame.project_columns(
+                list(range(len(cols))),
+                cols,
+            )
+        else:
+            new_internal_frame = self._modin_frame.apply_snowpark_function_to_columns(
+                lambda col_name: output_col(col_name, pat, n)
+            )

         return SnowflakeQueryCompiler(new_internal_frame)

     def str_rsplit(
```
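
The new code implements `expand=True` in two steps: `get_max_splits()` issues an aggregation query to find the widest split in the column, and `output_cols()` then projects that many output columns, padding shorter rows with NULLs via `iff(array_size(col) > i, get(col, i), NULL)`. Below is a rough plain-pandas analogue of this algorithm, for illustration only (not the actual implementation, and ignoring the `n < -1` special case handled above):

```python
import pandas as pd

def split_expand(ser: pd.Series, pat: str) -> pd.DataFrame:
    # Step 1 (analogue of get_max_splits): split every value and find the
    # maximum number of pieces any row produces.
    lists = ser.str.split(pat)
    max_splits = int(lists.dropna().map(len).max())

    # Step 2 (analogue of output_cols): take element i where it exists and
    # pad with None otherwise -- one output column per possible piece.
    return pd.DataFrame(
        [
            row + [None] * (max_splits - len(row))
            if isinstance(row, list)
            else [None] * max_splits
            for row in lists
        ]
    )

split_expand(pd.Series(["a|b|c", "a|b", None]), "|")
```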

tests/integ/modin/series/test_str_accessor.py

Lines changed: 28 additions & 11 deletions
```diff
@@ -12,7 +12,10 @@

 from snowflake.snowpark._internal.utils import TempObjectType
 import snowflake.snowpark.modin.plugin  # noqa: F401
-from tests.integ.modin.utils import assert_series_equal, eval_snowpark_pandas_result
+from tests.integ.modin.utils import (
+    assert_series_equal,
+    eval_snowpark_pandas_result,
+)
 from tests.integ.utils.sql_counter import SqlCounter, sql_count_checker

 TEST_DATA = [
@@ -367,10 +370,12 @@ def test_str_replace_neg(pat, n, repl, error):
         snow_ser.str.replace(pat=pat, repl=repl, n=n)


-@pytest.mark.parametrize("pat", [None, "a", "|", "%"])
+@pytest.mark.parametrize(
+    "pat", [None, "a", "ab", "abc", "non_occurrence_pat", "|", "%"]
+)
 @pytest.mark.parametrize("n", [None, np.nan, 3, 2, 1, 0, -1, -2])
 @sql_count_checker(query_count=1)
-def test_str_split(pat, n):
+def test_str_split_expand_false(pat, n):
     native_ser = native_pd.Series(TEST_DATA)
     snow_ser = pd.Series(native_ser)
     eval_snowpark_pandas_result(
@@ -380,6 +385,19 @@ def test_str_split(pat, n):
     )


+@pytest.mark.parametrize("pat", [None, "a", "ab", "abc", "no_occurrence_pat", "|", "%"])
+@pytest.mark.parametrize("n", [None, np.nan, 3, 2, 1, 0, -1, -2])
+@sql_count_checker(query_count=2)
+def test_str_split_expand_true(pat, n):
+    native_ser = native_pd.Series(TEST_DATA)
+    snow_ser = pd.Series(native_ser)
+    eval_snowpark_pandas_result(
+        snow_ser,
+        native_ser,
+        lambda ser: ser.str.split(pat=pat, n=n, expand=True, regex=None),
+    )
+
+
 @pytest.mark.parametrize("regex", [None, True])
 @pytest.mark.xfail(
     reason="Snowflake SQL's split function does not support regex", strict=True
@@ -395,21 +413,20 @@ def test_str_split_regex(regex):


 @pytest.mark.parametrize(
-    "pat, n, expand, error",
+    "pat, n, error",
     [
-        ("", 1, False, ValueError),
-        (re.compile("a"), 1, False, NotImplementedError),
-        (-2.0, 1, False, NotImplementedError),
-        ("a", "a", False, NotImplementedError),
-        ("a", 1, True, NotImplementedError),
+        ("", 1, ValueError),
+        (re.compile("a"), 1, NotImplementedError),
+        (-2.0, 1, NotImplementedError),
+        ("a", "a", NotImplementedError),
     ],
 )
 @sql_count_checker(query_count=0)
-def test_str_split_neg(pat, n, expand, error):
+def test_str_split_neg(pat, n, error):
     native_ser = native_pd.Series(TEST_DATA)
     snow_ser = pd.Series(native_ser)
     with pytest.raises(error):
-        snow_ser.str.split(pat=pat, n=n, expand=expand, regex=False)
+        snow_ser.str.split(pat=pat, n=n, expand=False, regex=False)


 @pytest.mark.parametrize("func", ["isdigit", "islower", "isupper", "lower", "upper"])
```
