Skip to content

Commit 634d0cf

Browse files
SNOW-2444072: Add support for groupby.apply in faster pandas
1 parent 33a8ab6 commit 634d0cf

File tree

3 files changed

+80
-0
lines changed

3 files changed

+80
-0
lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,9 @@
152152
- `groupby.median`
153153
- `groupby.std`
154154
- `groupby.var`
155+
- `groupby.nunique`
156+
- `groupby.size`
157+
- `groupby.apply`
155158
- `drop_duplicates`
156159
- Reuse row count from the relaxed query compiler in `get_axis_len`.
157160

src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4886,6 +4886,52 @@ def convert_func_to_agg_func_info(
48864886
return query_compiler if as_index else query_compiler.reset_index(drop=drop)
48874887

48884888
def groupby_apply(
4889+
self,
4890+
by: Any,
4891+
agg_func: AggFuncType,
4892+
axis: int,
4893+
groupby_kwargs: dict[str, Any],
4894+
agg_args: Any,
4895+
agg_kwargs: dict[str, Any],
4896+
series_groupby: bool,
4897+
include_groups: bool,
4898+
force_single_group: bool = False,
4899+
force_list_like_to_series: bool = False,
4900+
) -> "SnowflakeQueryCompiler":
4901+
"""
4902+
Wrapper around _groupby_apply_internal to be supported in faster pandas.
4903+
"""
4904+
relaxed_query_compiler = None
4905+
if self._relaxed_query_compiler is not None:
4906+
relaxed_query_compiler = (
4907+
self._relaxed_query_compiler._groupby_apply_internal(
4908+
by=by,
4909+
agg_func=agg_func,
4910+
axis=axis,
4911+
groupby_kwargs=groupby_kwargs,
4912+
agg_args=agg_args,
4913+
agg_kwargs=agg_kwargs,
4914+
series_groupby=series_groupby,
4915+
include_groups=include_groups,
4916+
force_single_group=force_single_group,
4917+
force_list_like_to_series=force_list_like_to_series,
4918+
)
4919+
)
4920+
qc = self._groupby_apply_internal(
4921+
by=by,
4922+
agg_func=agg_func,
4923+
axis=axis,
4924+
groupby_kwargs=groupby_kwargs,
4925+
agg_args=agg_args,
4926+
agg_kwargs=agg_kwargs,
4927+
series_groupby=series_groupby,
4928+
include_groups=include_groups,
4929+
force_single_group=force_single_group,
4930+
force_list_like_to_series=force_list_like_to_series,
4931+
)
4932+
return self._maybe_set_relaxed_qc(qc, relaxed_query_compiler)
4933+
4934+
def _groupby_apply_internal(
48894935
self,
48904936
by: Any,
48914937
agg_func: Callable,

tests/integ/modin/test_faster_pandas.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -392,6 +392,37 @@ def test_groupby_agg(session, func):
392392
assert_frame_equal(snow_result4, native_result4, check_dtype=False)
393393

394394

395+
@sql_count_checker(query_count=9, join_count=1, udtf_count=1)
396+
def test_groupby_apply(session):
397+
# create tables
398+
table_name = Utils.random_name_for_temp_object(TempObjectType.TABLE)
399+
session.create_dataframe(
400+
native_pd.DataFrame([[2, 12], [2, 11], [3, 13]], columns=["A", "B"])
401+
).write.save_as_table(table_name, table_type="temp")
402+
403+
# create snow dataframes
404+
df = pd.read_snowflake(table_name).sort_values("B", ignore_index=True)
405+
snow_result = df.groupby("A").apply(lambda x: x + 1)
406+
407+
# verify that the input dataframe has a populated relaxed query compiler
408+
assert df._query_compiler._relaxed_query_compiler is not None
409+
assert df._query_compiler._relaxed_query_compiler._dummy_row_pos_mode is True
410+
# verify that the output dataframe also has a populated relaxed query compiler
411+
assert snow_result._query_compiler._relaxed_query_compiler is not None
412+
assert (
413+
snow_result._query_compiler._relaxed_query_compiler._dummy_row_pos_mode is True
414+
)
415+
416+
# create pandas dataframes
417+
native_df = df.to_pandas()
418+
native_result = native_df.groupby("A").apply(lambda x: x + 1)
419+
420+
# compare results
421+
assert_frame_equal(
422+
snow_result, native_result, check_dtype=False, check_index_type=False
423+
)
424+
425+
395426
@sql_count_checker(query_count=5)
396427
def test_iloc_head(session):
397428
# create tables

0 commit comments

Comments
 (0)