Skip to content

Commit b61f672

Browse files
SNOW-2444072: Add support for groupby.apply in faster pandas (2nd attempt) (#3962)
1 parent c5fdf7a commit b61f672

File tree

3 files changed

+82
-0
lines changed

3 files changed

+82
-0
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
#### Improvements
3434

3535
- Add support for the following in faster pandas:
36+
- `groupby.apply`
3637
- `groupby.nunique`
3738
- `groupby.size`
3839
- `concat`

src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4898,6 +4898,52 @@ def convert_func_to_agg_func_info(
48984898
return query_compiler if as_index else query_compiler.reset_index(drop=drop)
48994899

49004900
def groupby_apply(
    self,
    by: Any,
    agg_func: "AggFuncType",
    axis: int,
    groupby_kwargs: dict[str, Any],
    agg_args: Any,
    agg_kwargs: dict[str, Any],
    series_groupby: bool,
    include_groups: bool,
    force_single_group: bool = False,
    force_list_like_to_series: bool = False,
) -> "SnowflakeQueryCompiler":
    """
    Wrapper around _groupby_apply_internal to be supported in faster pandas.

    The same arguments are forwarded, unchanged, to ``_groupby_apply_internal``
    on the relaxed query compiler (only when ``self._relaxed_query_compiler``
    is not None) and then on this query compiler. The two results are handed
    to ``self._maybe_set_relaxed_qc``, whose return value is returned as-is.

    Args:
        by: Forwarded to ``_groupby_apply_internal`` as ``by``.
        agg_func: Forwarded as ``agg_func``.
        axis: Forwarded as ``axis``.
        groupby_kwargs: Forwarded as ``groupby_kwargs``.
        agg_args: Forwarded as ``agg_args``.
        agg_kwargs: Forwarded as ``agg_kwargs``.
        series_groupby: Forwarded as ``series_groupby``.
        include_groups: Forwarded as ``include_groups``.
        force_single_group: Forwarded as ``force_single_group``.
        force_list_like_to_series: Forwarded as ``force_list_like_to_series``.

    Returns:
        Whatever ``self._maybe_set_relaxed_qc(qc, relaxed_query_compiler)``
        returns, where ``qc`` is this compiler's result and
        ``relaxed_query_compiler`` is the relaxed compiler's result, or None
        when no relaxed compiler is attached.
    """
    # Build the forwarded arguments once so the relaxed and the regular call
    # cannot drift apart when a parameter is added or renamed.
    forwarded_kwargs: dict[str, Any] = {
        "by": by,
        "agg_func": agg_func,
        "axis": axis,
        "groupby_kwargs": groupby_kwargs,
        "agg_args": agg_args,
        "agg_kwargs": agg_kwargs,
        "series_groupby": series_groupby,
        "include_groups": include_groups,
        "force_single_group": force_single_group,
        "force_list_like_to_series": force_list_like_to_series,
    }

    # The relaxed compiler is evaluated first, matching the original order.
    relaxed_query_compiler = None
    if self._relaxed_query_compiler is not None:
        relaxed_query_compiler = (
            self._relaxed_query_compiler._groupby_apply_internal(
                **forwarded_kwargs
            )
        )

    qc = self._groupby_apply_internal(**forwarded_kwargs)
    return self._maybe_set_relaxed_qc(qc, relaxed_query_compiler)
4945+
4946+
def _groupby_apply_internal(
49014947
self,
49024948
by: Any,
49034949
agg_func: Callable,

tests/integ/modin/test_faster_pandas.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -504,6 +504,41 @@ def test_groupby_agg(session, func):
504504
assert_frame_equal(snow_result4, native_result4, check_dtype=False)
505505

506506

507+
@sql_count_checker(query_count=9, join_count=1, udtf_count=1)
def test_groupby_apply(session):
    """groupby.apply keeps a relaxed query compiler attached and matches native pandas."""
    with session_parameter_override(
        session, "dummy_row_pos_optimization_enabled", True
    ):
        # Stage the input rows in a temporary table.
        temp_table = Utils.random_name_for_temp_object(TempObjectType.TABLE)
        session.create_dataframe(
            native_pd.DataFrame([[2, 12], [2, 11], [3, 13]], columns=["A", "B"])
        ).write.save_as_table(temp_table, table_type="temp")

        # Build the Snowpark pandas input and run the operation under test.
        snow_df = pd.read_snowflake(temp_table).sort_values("B", ignore_index=True)
        snow_result = snow_df.groupby("A").apply(lambda x: x + 1)

        # The input frame must carry a relaxed query compiler in dummy row-pos mode.
        input_relaxed_qc = snow_df._query_compiler._relaxed_query_compiler
        assert input_relaxed_qc is not None
        assert input_relaxed_qc._dummy_row_pos_mode is True

        # The result of groupby.apply must carry one as well.
        output_relaxed_qc = snow_result._query_compiler._relaxed_query_compiler
        assert output_relaxed_qc is not None
        assert output_relaxed_qc._dummy_row_pos_mode is True

        # Compute the reference result with native pandas on the same data.
        native_df = snow_df.to_pandas()
        native_result = native_df.groupby("A").apply(lambda x: x + 1)

        # Both engines must agree (dtype and index type are not compared).
        assert_frame_equal(
            snow_result, native_result, check_dtype=False, check_index_type=False
        )
540+
541+
507542
@sql_count_checker(query_count=5)
508543
def test_iloc_head(session):
509544
with session_parameter_override(

0 commit comments

Comments
 (0)