Skip to content

Commit 226b9d9

Browse files
SNOW-2444072: Add support for groupby.apply in faster pandas (#3933)
1 parent e5f2d99 commit 226b9d9

File tree

3 files changed

+78
-0
lines changed

3 files changed

+78
-0
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,7 @@
155155
- `groupby.var`
156156
- `groupby.nunique`
157157
- `groupby.size`
158+
- `groupby.apply`
158159
- `drop_duplicates`
159160
- Reuse row count from the relaxed query compiler in `get_axis_len`.
160161

src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4886,6 +4886,52 @@ def convert_func_to_agg_func_info(
48864886
return query_compiler if as_index else query_compiler.reset_index(drop=drop)
48874887

48884888
def groupby_apply(
4889+
self,
4890+
by: Any,
4891+
agg_func: AggFuncType,
4892+
axis: int,
4893+
groupby_kwargs: dict[str, Any],
4894+
agg_args: Any,
4895+
agg_kwargs: dict[str, Any],
4896+
series_groupby: bool,
4897+
include_groups: bool,
4898+
force_single_group: bool = False,
4899+
force_list_like_to_series: bool = False,
4900+
) -> "SnowflakeQueryCompiler":
4901+
"""
4902+
Wrapper around _groupby_apply_internal to be supported in faster pandas.
4903+
"""
4904+
relaxed_query_compiler = None
4905+
if self._relaxed_query_compiler is not None:
4906+
relaxed_query_compiler = (
4907+
self._relaxed_query_compiler._groupby_apply_internal(
4908+
by=by,
4909+
agg_func=agg_func,
4910+
axis=axis,
4911+
groupby_kwargs=groupby_kwargs,
4912+
agg_args=agg_args,
4913+
agg_kwargs=agg_kwargs,
4914+
series_groupby=series_groupby,
4915+
include_groups=include_groups,
4916+
force_single_group=force_single_group,
4917+
force_list_like_to_series=force_list_like_to_series,
4918+
)
4919+
)
4920+
qc = self._groupby_apply_internal(
4921+
by=by,
4922+
agg_func=agg_func,
4923+
axis=axis,
4924+
groupby_kwargs=groupby_kwargs,
4925+
agg_args=agg_args,
4926+
agg_kwargs=agg_kwargs,
4927+
series_groupby=series_groupby,
4928+
include_groups=include_groups,
4929+
force_single_group=force_single_group,
4930+
force_list_like_to_series=force_list_like_to_series,
4931+
)
4932+
return self._maybe_set_relaxed_qc(qc, relaxed_query_compiler)
4933+
4934+
def _groupby_apply_internal(
48894935
self,
48904936
by: Any,
48914937
agg_func: Callable,

tests/integ/modin/test_faster_pandas.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -397,6 +397,37 @@ def test_groupby_agg(session, func):
397397
assert_frame_equal(snow_result4, native_result4, check_dtype=False)
398398

399399

400+
@sql_count_checker(query_count=9, join_count=1, udtf_count=1)
401+
def test_groupby_apply(session):
402+
# create tables
403+
table_name = Utils.random_name_for_temp_object(TempObjectType.TABLE)
404+
session.create_dataframe(
405+
native_pd.DataFrame([[2, 12], [2, 11], [3, 13]], columns=["A", "B"])
406+
).write.save_as_table(table_name, table_type="temp")
407+
408+
# create snow dataframes
409+
df = pd.read_snowflake(table_name).sort_values("B", ignore_index=True)
410+
snow_result = df.groupby("A").apply(lambda x: x + 1)
411+
412+
# verify that the input dataframe has a populated relaxed query compiler
413+
assert df._query_compiler._relaxed_query_compiler is not None
414+
assert df._query_compiler._relaxed_query_compiler._dummy_row_pos_mode is True
415+
# verify that the output dataframe also has a populated relaxed query compiler
416+
assert snow_result._query_compiler._relaxed_query_compiler is not None
417+
assert (
418+
snow_result._query_compiler._relaxed_query_compiler._dummy_row_pos_mode is True
419+
)
420+
421+
# create pandas dataframes
422+
native_df = df.to_pandas()
423+
native_result = native_df.groupby("A").apply(lambda x: x + 1)
424+
425+
# compare results
426+
assert_frame_equal(
427+
snow_result, native_result, check_dtype=False, check_index_type=False
428+
)
429+
430+
400431
@sql_count_checker(query_count=5)
401432
def test_iloc_head(session):
402433
# create tables

0 commit comments

Comments
 (0)