Skip to content

Commit b44816c

Browse files
SNOW-2435290: Add support for groupby.agg/min/max/count/sum/mean/median/std/var in faster pandas
1 parent 85466f9 commit b44816c

File tree

3 files changed

+109
-0
lines changed

3 files changed

+109
-0
lines changed

CHANGELOG.md

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,16 @@
122122
- `drop`
123123
- `invert`
124124
- `duplicated`
125+
- `groupby.agg`
126+
- `groupby.min`
127+
- `groupby.max`
128+
- `groupby.count`
129+
- `groupby.sum`
130+
- `groupby.mean`
131+
- `groupby.median`
132+
- `groupby.std`
133+
- `groupby.var`
134+
125135
- Reuse row count from the relaxed query compiler in `get_axis_len`.
126136

127137
#### Bug Fixes

src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4501,6 +4501,50 @@ def groupby_agg(
45014501
numeric_only: bool = False,
45024502
is_series_groupby: bool = False,
45034503
drop: bool = False,
4504+
) -> "SnowflakeQueryCompiler":
4505+
"""
4506+
Wrapper around _groupby_agg_internal to be supported in faster pandas.
4507+
"""
4508+
relaxed_query_compiler = None
4509+
if self._relaxed_query_compiler is not None:
4510+
relaxed_query_compiler = self._relaxed_query_compiler._groupby_agg_internal(
4511+
by=by,
4512+
agg_func=agg_func,
4513+
axis=axis,
4514+
groupby_kwargs=groupby_kwargs,
4515+
agg_args=agg_args,
4516+
agg_kwargs=agg_kwargs,
4517+
how=how,
4518+
numeric_only=numeric_only,
4519+
is_series_groupby=is_series_groupby,
4520+
drop=drop,
4521+
)
4522+
qc = self._groupby_agg_internal(
4523+
by=by,
4524+
agg_func=agg_func,
4525+
axis=axis,
4526+
groupby_kwargs=groupby_kwargs,
4527+
agg_args=agg_args,
4528+
agg_kwargs=agg_kwargs,
4529+
how=how,
4530+
numeric_only=numeric_only,
4531+
is_series_groupby=is_series_groupby,
4532+
drop=drop,
4533+
)
4534+
return self._maybe_set_relaxed_qc(qc, relaxed_query_compiler)
4535+
4536+
def _groupby_agg_internal(
4537+
self,
4538+
by: Any,
4539+
agg_func: AggFuncType,
4540+
axis: int,
4541+
groupby_kwargs: dict[str, Any],
4542+
agg_args: Any,
4543+
agg_kwargs: dict[str, Any],
4544+
how: str = "axis_wise",
4545+
numeric_only: bool = False,
4546+
is_series_groupby: bool = False,
4547+
drop: bool = False,
45044548
) -> "SnowflakeQueryCompiler":
45054549
"""
45064550
compute groupby with aggregation functions.

tests/integ/modin/test_faster_pandas.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -252,6 +252,61 @@ def test_duplicated(session):
252252
assert_series_equal(snow_result, native_result)
253253

254254

255+
@pytest.mark.parametrize(
256+
"func",
257+
[
258+
"min",
259+
"max",
260+
"count",
261+
"sum",
262+
"mean",
263+
"median",
264+
"std",
265+
"var",
266+
],
267+
)
268+
@sql_count_checker(query_count=6)
269+
def test_groupby_agg(session, func):
270+
# create tables
271+
table_name = Utils.random_name_for_temp_object(TempObjectType.TABLE)
272+
session.create_dataframe(
273+
native_pd.DataFrame([[2, 12], [2, 11], [3, 13]], columns=["A", "B"])
274+
).write.save_as_table(table_name, table_type="temp")
275+
276+
# create snow dataframes
277+
df = pd.read_snowflake(table_name)
278+
snow_result1 = getattr(df.groupby("A"), func)()
279+
snow_result2 = df.groupby("A").agg([func])
280+
snow_result3 = getattr(df.groupby("A")["B"], func)()
281+
snow_result4 = df.groupby("A")["B"].agg([func])
282+
283+
# verify that the input dataframe has a populated relaxed query compiler
284+
assert df._query_compiler._relaxed_query_compiler is not None
285+
assert df._query_compiler._relaxed_query_compiler._dummy_row_pos_mode is True
286+
# verify that the output dataframe also has a populated relaxed query compiler
287+
assert snow_result1._query_compiler._relaxed_query_compiler is not None
288+
assert (
289+
snow_result1._query_compiler._relaxed_query_compiler._dummy_row_pos_mode is True
290+
)
291+
assert snow_result2._query_compiler._relaxed_query_compiler is not None
292+
assert (
293+
snow_result2._query_compiler._relaxed_query_compiler._dummy_row_pos_mode is True
294+
)
295+
296+
# create pandas dataframes
297+
native_df = df.to_pandas()
298+
native_result1 = getattr(native_df.groupby("A"), func)()
299+
native_result2 = native_df.groupby("A").agg([func])
300+
native_result3 = getattr(native_df.groupby("A")["B"], func)()
301+
native_result4 = native_df.groupby("A")["B"].agg([func])
302+
303+
# compare results
304+
assert_frame_equal(snow_result1, native_result1, check_dtype=False)
305+
assert_frame_equal(snow_result2, native_result2, check_dtype=False)
306+
assert_series_equal(snow_result3, native_result3, check_dtype=False)
307+
assert_frame_equal(snow_result4, native_result4, check_dtype=False)
308+
309+
255310
@sql_count_checker(query_count=3)
256311
def test_invert(session):
257312
# create tables

0 commit comments

Comments
 (0)