Skip to content

Commit 0c077b7

Browse files
SNOW-2401303: Add support for str.contains/startswith/endswith/slice in faster pandas (#3868)
1 parent 88d3f0e commit 0c077b7

File tree

3 files changed

+179
-1
lines changed

3 files changed

+179
-1
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -45,6 +45,7 @@
4545
- Improved performance of `Series.to_snowflake` and `pd.to_snowflake(series)` for large data by uploading data via a parquet file. You can control the dataset size at which Snowpark pandas switches to parquet with the variable `modin.config.PandasToSnowflakeParquetThresholdBytes`.
4646
- Set `cte_optimization_enabled` to True for all Snowpark pandas sessions.
4747
- Add support for `isna`, `isnull`, `notna`, `notnull` in faster pandas.
48+
- Add support for `str.contains`, `str.startswith`, `str.endswith`, and `str.slice` in faster pandas.
4849

4950
## 1.40.0 (2025-10-02)
5051

src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py

Lines changed: 85 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -17116,6 +17116,21 @@ def str_encode(self, encoding: str, errors: str) -> None:
1711617116

1711717117
def str_startswith(
    self, pat: Union[str, tuple], na: object = None
) -> "SnowflakeQueryCompiler":
    """
    Faster-pandas entry point for ``str.startswith``.

    Delegates to ``_str_startswith_internal`` on this query compiler and,
    when one is attached, on its relaxed counterpart, then carries the
    relaxed result over onto the strict result.
    """
    relaxed = self._relaxed_query_compiler
    # Only compute the relaxed variant when a relaxed compiler exists.
    relaxed_result = (
        None if relaxed is None else relaxed._str_startswith_internal(pat=pat, na=na)
    )
    return self._maybe_set_relaxed_qc(
        self._str_startswith_internal(pat=pat, na=na), relaxed_result
    )
17131+
17132+
def _str_startswith_internal(
17133+
self, pat: Union[str, tuple], na: object = None
1711917134
) -> "SnowflakeQueryCompiler":
1712017135
"""
1712117136
Test if the start of each string element matches a pattern.
@@ -17135,6 +17150,21 @@ def str_startswith(
1713517150

1713617151
def str_endswith(
    self, pat: Union[str, tuple], na: object = None
) -> "SnowflakeQueryCompiler":
    """
    Faster-pandas entry point for ``str.endswith``.

    Delegates to ``_str_endswith_internal`` on this query compiler and,
    when one is attached, on its relaxed counterpart, then carries the
    relaxed result over onto the strict result.
    """
    relaxed = self._relaxed_query_compiler
    # Only compute the relaxed variant when a relaxed compiler exists.
    relaxed_result = (
        None if relaxed is None else relaxed._str_endswith_internal(pat=pat, na=na)
    )
    return self._maybe_set_relaxed_qc(
        self._str_endswith_internal(pat=pat, na=na), relaxed_result
    )
17165+
17166+
def _str_endswith_internal(
17167+
self, pat: Union[str, tuple], na: object = None
1713817168
) -> "SnowflakeQueryCompiler":
1713917169
"""
1714017170
Test if the end of each string element matches a pattern.
@@ -17490,6 +17520,38 @@ def str_contains(
1749017520
def str_contains(
    self,
    pat: str,
    case: bool = True,
    flags: int = 0,
    na: object = None,
    regex: bool = True,
) -> "SnowflakeQueryCompiler":
    """
    Faster-pandas entry point for ``str.contains``.

    Delegates to ``_str_contains_internal`` on this query compiler and,
    when one is attached, on its relaxed counterpart, then carries the
    relaxed result over onto the strict result.
    """
    relaxed = self._relaxed_query_compiler
    relaxed_result = None
    if relaxed is not None:
        # Mirror the exact argument set on the relaxed compiler.
        relaxed_result = relaxed._str_contains_internal(
            pat=pat, case=case, flags=flags, na=na, regex=regex
        )
    strict_result = self._str_contains_internal(
        pat=pat, case=case, flags=flags, na=na, regex=regex
    )
    return self._maybe_set_relaxed_qc(strict_result, relaxed_result)
17548+
def _str_contains_internal(
17549+
self,
17550+
pat: str,
17551+
case: bool = True,
17552+
flags: int = 0,
17553+
na: object = None,
17554+
regex: bool = True,
1749317555
) -> "SnowflakeQueryCompiler":
1749417556
"""
1749517557
Test if pattern or regex is contained within a string of a Series or Index.
@@ -17851,6 +17913,29 @@ def str_slice(
1785117913
def str_slice(
    self,
    start: Optional[int] = None,
    stop: Optional[int] = None,
    step: Optional[int] = None,
) -> "SnowflakeQueryCompiler":
    """
    Faster-pandas entry point for ``str.slice``.

    Delegates to ``_str_slice_internal`` on this query compiler and,
    when one is attached, on its relaxed counterpart, then carries the
    relaxed result over onto the strict result.
    """
    relaxed = self._relaxed_query_compiler
    relaxed_result = None
    if relaxed is not None:
        # Mirror the exact argument set on the relaxed compiler.
        relaxed_result = relaxed._str_slice_internal(
            start=start, stop=stop, step=step
        )
    strict_result = self._str_slice_internal(start=start, stop=stop, step=step)
    return self._maybe_set_relaxed_qc(strict_result, relaxed_result)
17934+
def _str_slice_internal(
17935+
self,
17936+
start: Optional[int] = None,
17937+
stop: Optional[int] = None,
17938+
step: Optional[int] = None,
1785417939
) -> "SnowflakeQueryCompiler":
1785517940
"""
1785617941
Slice substrings from each element in the Series or Index.

tests/integ/modin/test_faster_pandas.py

Lines changed: 93 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -14,7 +14,11 @@
1414
_SNOWPARK_PANDAS_DUMMY_ROW_POS_OPTIMIZATION_ENABLED,
1515
Session,
1616
)
17-
from tests.integ.modin.utils import assert_frame_equal, assert_index_equal
17+
from tests.integ.modin.utils import (
18+
assert_frame_equal,
19+
assert_index_equal,
20+
assert_series_equal,
21+
)
1822
from tests.integ.utils.sql_counter import sql_count_checker
1923
from tests.utils import Utils
2024

@@ -278,6 +282,94 @@ def test_isin_series(session):
278282
assert_frame_equal(snow_result, native_result, check_dtype=False)
279283

280284

285+
@sql_count_checker(query_count=3)
def test_str_contains(session):
    # Materialize a small temp table with a single string column "A".
    tmp_table = Utils.random_name_for_temp_object(TempObjectType.TABLE)
    session.create_dataframe(
        native_pd.DataFrame([["abc"], ["def"], ["ghi"]], columns=["A"])
    ).write.save_as_table(tmp_table, table_type="temp")

    # Read it back through Snowpark pandas and apply str.contains.
    snow_df = pd.read_snowflake(tmp_table)
    snow_result = snow_df["A"].str.contains("ab")

    # The input frame must carry a relaxed query compiler in dummy-row-pos mode...
    input_relaxed = snow_df._query_compiler._relaxed_query_compiler
    assert input_relaxed is not None
    assert input_relaxed._dummy_row_pos_mode is True
    # ...and str.contains must propagate it onto the output.
    output_relaxed = snow_result._query_compiler._relaxed_query_compiler
    assert output_relaxed is not None
    assert output_relaxed._dummy_row_pos_mode is True

    # Results must match native pandas on the same data.
    expected = snow_df.to_pandas()["A"].str.contains("ab")
    assert_series_equal(snow_result, expected)
312+
313+
314+
@pytest.mark.parametrize("func", ["startswith", "endswith"])
@sql_count_checker(query_count=3)
def test_str_startswith_endswith(session, func):
    # Materialize a small temp table with a single string column "A";
    # "abc" starts with "c" is False / "cba" ends with "c" is False, so both
    # parametrized methods produce a mixed True/False result.
    tmp_table = Utils.random_name_for_temp_object(TempObjectType.TABLE)
    session.create_dataframe(
        native_pd.DataFrame([["abc"], ["def"], ["cba"]], columns=["A"])
    ).write.save_as_table(tmp_table, table_type="temp")

    # Read it back through Snowpark pandas and apply the parametrized method.
    snow_df = pd.read_snowflake(tmp_table)
    snow_result = getattr(snow_df["A"].str, func)("c")

    # The input frame must carry a relaxed query compiler in dummy-row-pos mode...
    input_relaxed = snow_df._query_compiler._relaxed_query_compiler
    assert input_relaxed is not None
    assert input_relaxed._dummy_row_pos_mode is True
    # ...and the string method must propagate it onto the output.
    output_relaxed = snow_result._query_compiler._relaxed_query_compiler
    assert output_relaxed is not None
    assert output_relaxed._dummy_row_pos_mode is True

    # Results must match native pandas on the same data.
    expected = getattr(snow_df.to_pandas()["A"].str, func)("c")
    assert_series_equal(snow_result, expected)
342+
343+
344+
@sql_count_checker(query_count=3)
def test_str_slice(session):
    # Materialize a small temp table with a single string column "A".
    tmp_table = Utils.random_name_for_temp_object(TempObjectType.TABLE)
    session.create_dataframe(
        native_pd.DataFrame([["abc"], ["def"], ["ghi"]], columns=["A"])
    ).write.save_as_table(tmp_table, table_type="temp")

    # Read it back through Snowpark pandas and slice each string.
    snow_df = pd.read_snowflake(tmp_table)
    snow_result = snow_df["A"].str.slice(0, 2, 1)

    # The input frame must carry a relaxed query compiler in dummy-row-pos mode...
    input_relaxed = snow_df._query_compiler._relaxed_query_compiler
    assert input_relaxed is not None
    assert input_relaxed._dummy_row_pos_mode is True
    # ...and str.slice must propagate it onto the output.
    output_relaxed = snow_result._query_compiler._relaxed_query_compiler
    assert output_relaxed is not None
    assert output_relaxed._dummy_row_pos_mode is True

    # Results must match native pandas on the same data.
    expected = snow_df.to_pandas()["A"].str.slice(0, 2, 1)
    assert_series_equal(snow_result, expected)
371+
372+
281373
@sql_count_checker(query_count=0)
282374
def test_dummy_row_pos_optimization_enabled_on_session(db_parameters):
283375
with Session.builder.configs(db_parameters).create() as new_session:

0 commit comments

Comments
 (0)