
Commit 2b668f2

address comments
1 parent fe00d90 commit 2b668f2

2 files changed: +41 additions, −21 deletions

src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py

Lines changed: 34 additions & 9 deletions
@@ -16835,6 +16835,12 @@ def output_col(
             if np.isnan(n):
                 # Follow pandas behavior
                 return pandas_lit(np.nan)
+            elif n < -1 and not pandas.isnull(pat) and len(str(pat)) > 1:
+                # Follow pandas behavior, which seems to leave the input column as is
+                # whenever the above condition is satisfied.
+                new_col = iff(
+                    column.is_null(), pandas_lit(None), array_construct(column)
+                )
             elif n <= 0:
                 # If all possible splits are requested, we just use SQL's split function.
                 new_col = builtin("split")(new_col, pandas_lit(new_pat))
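
A note on the new n < -1 branch: Python's re.split performs no splits at all when maxsplit is negative, and pandas falls back to a compiled regex for patterns longer than one character, which appears to be the source of the behavior the comment describes (null inputs stay null, mirrored above by the column.is_null() check). A minimal sketch of the pandas side, with illustrative values; the exact n-to-maxsplit mapping is a pandas implementation detail and should be treated as an assumption:

    import re
    import pandas as native_pd

    # re.split with a negative maxsplit performs zero splits and returns
    # the whole string as a one-element list:
    re.split("ab", "1ab2ab3", maxsplit=-2)  # ['1ab2ab3']

    # pandas uses re.split for multi-character patterns, so with n < -1
    # each value comes back unsplit, wrapped in a one-element list:
    native_pd.Series(["1ab2ab3"]).str.split("ab", n=-2)
    # 0    [1ab2ab3]
    # dtype: object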
@@ -16879,28 +16885,47 @@ def output_col(
             return self._replace_non_str(column, new_col)
 
         def output_cols(
-            column: SnowparkColumn, pat: Optional[str], n: int, max_n_cols: int
+            column: SnowparkColumn, pat: Optional[str], n: int, max_splits: int
         ) -> list[SnowparkColumn]:
+            """
+            Returns the list of columns that the input column will be split into.
+            This is only used when expand=True.
+            Args:
+                column: input column
+                pat: string to split on
+                n: limit on the number of output splits
+                max_splits: maximum number of achievable splits across all values in the input column
+            """
             col = output_col(column, pat, n)
-            final_n_cols = 0
+            final_splits = 0
+
             if np.isnan(n):
                 # Follow pandas behavior
-                final_n_cols = 1
+                final_splits = 1
             elif n <= 0:
-                final_n_cols = max_n_cols
+                final_splits = max_splits
             else:
-                final_n_cols = min(n + 1, max_n_cols)
+                final_splits = min(n + 1, max_splits)
+
+            if n < -1 and not pandas.isnull(pat) and len(str(pat)) > 1:
+                # Follow pandas behavior, which seems to leave the input column as is
+                # whenever the above condition is satisfied.
+                final_splits = 1
 
             return [
                 iff(
                     array_size(col) > pandas_lit(i),
                     get(col, pandas_lit(i)),
                     pandas_lit(None),
                 )
-                for i in range(final_n_cols)
+                for i in range(final_splits)
             ]
 
-        def max_n_cols() -> int:
+        def get_max_splits() -> int:
+            """
+            Returns the maximum number of splits achievable
+            across all values stored in the input column.
+            """
             splits_as_list_frame = self.str_split(
                 pat=pat,
                 n=-1,
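
For the expand=True path, the iff(array_size(col) > i, get(col, i), NULL) projection reproduces pandas' padding: the widest value determines the column count (the role max_splits plays above), and shorter values are padded with missing entries. A small illustration in plain pandas, with hypothetical data:

    import pandas as native_pd

    ser = native_pd.Series(["a,b,c", "a,b"])
    ser.str.split(",", expand=True)
    # The row with the most pieces fixes the number of columns;
    # shorter rows are padded with missing values:
    #    0  1    2
    # 0  a  b    c
    # 1  a  b  NaN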
@@ -16930,10 +16955,10 @@ def max_n_cols() -> int:
                 col(self._modin_frame.data_column_snowflake_quoted_identifiers[0]),
                 pat,
                 n,
-                max_n_cols(),
+                get_max_splits(),
             )
             new_internal_frame = self._modin_frame.project_columns(
-                [f"{i}" for i in range(len(cols))],
+                list(range(len(cols))),
                 cols,
             )
         else:

tests/integ/modin/series/test_str_accessor.py

Lines changed: 7 additions & 12 deletions
@@ -13,7 +13,6 @@
 from snowflake.snowpark._internal.utils import TempObjectType
 import snowflake.snowpark.modin.plugin  # noqa: F401
 from tests.integ.modin.utils import (
-    assert_frame_equal,
     assert_series_equal,
     eval_snowpark_pandas_result,
 )
@@ -371,7 +370,7 @@ def test_str_replace_neg(pat, n, repl, error):
         snow_ser.str.replace(pat=pat, repl=repl, n=n)
 
 
-@pytest.mark.parametrize("pat", [None, "a", "|", "%"])
+@pytest.mark.parametrize("pat", [None, "a", "ab", "non_occurrence_pat", "|", "%"])
 @pytest.mark.parametrize("n", [None, np.nan, 3, 2, 1, 0, -1, -2])
 @sql_count_checker(query_count=1)
 def test_str_split_expand_false(pat, n):
@@ -384,21 +383,17 @@ def test_str_split_expand_false(pat, n):
     )
 
 
-@pytest.mark.parametrize("pat", [None, "a", "|", "%"])
+@pytest.mark.parametrize("pat", [None, "a", "ab", "no_occurrence_pat", "|", "%"])
 @pytest.mark.parametrize("n", [None, np.nan, 3, 2, 1, 0, -1, -2])
 @sql_count_checker(query_count=2)
 def test_str_split_expand_true(pat, n):
     native_ser = native_pd.Series(TEST_DATA)
     snow_ser = pd.Series(native_ser)
-    native_df = native_ser.str.split(pat=pat, n=n, expand=True, regex=None)
-    snow_df = snow_ser.str.split(pat=pat, n=n, expand=True, regex=None)
-    # Currently Snowpark pandas uses an Index object with string values for columns,
-    # while native pandas uses a RangeIndex.
-    # So we make sure that all corresponding values in the two columns objects are identical
-    # (after casting from string to int).
-    assert all(snow_df.columns.astype(int).values == native_df.columns.values)
-    snow_df.columns = native_df.columns
-    assert_frame_equal(snow_df, native_df, check_dtype=False)
+    eval_snowpark_pandas_result(
+        snow_ser,
+        native_ser,
+        lambda ser: ser.str.split(pat=pat, n=n, expand=True, regex=None),
+    )
 
 
 @pytest.mark.parametrize("regex", [None, True])
