Skip to content

Commit aad530a

Browse files
SNOW-2435290: Add support for groupby.agg/min/max/count/sum/mean/median/std/var in faster pandas (#3908)
1 parent 385d069 commit aad530a

File tree

4 files changed

+117
-9
lines changed

4 files changed

+117
-9
lines changed

CHANGELOG.md

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,15 @@
134134
- `median`
135135
- `std`
136136
- `var`
137+
- `groupby.agg`
138+
- `groupby.min`
139+
- `groupby.max`
140+
- `groupby.count`
141+
- `groupby.sum`
142+
- `groupby.mean`
143+
- `groupby.median`
144+
- `groupby.std`
145+
- `groupby.var`
137146
- Reuse row count from the relaxed query compiler in `get_axis_len`.
138147

139148
#### Bug Fixes

src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4515,6 +4515,50 @@ def groupby_agg(
45154515
numeric_only: bool = False,
45164516
is_series_groupby: bool = False,
45174517
drop: bool = False,
4518+
) -> "SnowflakeQueryCompiler":
4519+
"""
4520+
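Wrapper around _groupby_agg_internal that also runs the same aggregation on the relaxed query compiler, when one is present, and attaches that result to the returned query compiler via _maybe_set_relaxed_qc so the operation is supported in faster pandas.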
4521+
"""
4522+
relaxed_query_compiler = None
4523+
if self._relaxed_query_compiler is not None:
4524+
relaxed_query_compiler = self._relaxed_query_compiler._groupby_agg_internal(
4525+
by=by,
4526+
agg_func=agg_func,
4527+
axis=axis,
4528+
groupby_kwargs=groupby_kwargs,
4529+
agg_args=agg_args,
4530+
agg_kwargs=agg_kwargs,
4531+
how=how,
4532+
numeric_only=numeric_only,
4533+
is_series_groupby=is_series_groupby,
4534+
drop=drop,
4535+
)
4536+
qc = self._groupby_agg_internal(
4537+
by=by,
4538+
agg_func=agg_func,
4539+
axis=axis,
4540+
groupby_kwargs=groupby_kwargs,
4541+
agg_args=agg_args,
4542+
agg_kwargs=agg_kwargs,
4543+
how=how,
4544+
numeric_only=numeric_only,
4545+
is_series_groupby=is_series_groupby,
4546+
drop=drop,
4547+
)
4548+
return self._maybe_set_relaxed_qc(qc, relaxed_query_compiler)
4549+
4550+
def _groupby_agg_internal(
4551+
self,
4552+
by: Any,
4553+
agg_func: AggFuncType,
4554+
axis: int,
4555+
groupby_kwargs: dict[str, Any],
4556+
agg_args: Any,
4557+
agg_kwargs: dict[str, Any],
4558+
how: str = "axis_wise",
4559+
numeric_only: bool = False,
4560+
is_series_groupby: bool = False,
4561+
drop: bool = False,
45184562
) -> "SnowflakeQueryCompiler":
45194563
"""
45204564
compute groupby with aggregation functions.

src/snowflake/snowpark/modin/plugin/docstrings/series_utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2642,7 +2642,7 @@ def seconds():
26422642
0 1
26432643
1 2
26442644
2 3
2645-
dtype: int8
2645+
dtype: int64
26462646
26472647
For TimedeltaIndex:
26482648
@@ -2702,7 +2702,7 @@ def microseconds():
27022702
0 1
27032703
1 2
27042704
2 3
2705-
dtype: int8
2705+
dtype: int64
27062706
27072707
For TimedeltaIndex:
27082708
@@ -2734,7 +2734,7 @@ def nanoseconds():
27342734
0 1
27352735
1 2
27362736
2 3
2737-
dtype: int8
2737+
dtype: int64
27382738
27392739
For TimedeltaIndex:
27402740

tests/integ/modin/test_faster_pandas.py

Lines changed: 61 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -124,9 +124,9 @@ def test_read_filter_join_on_index(session):
124124
)
125125

126126

127-
@sql_count_checker(query_count=3)
128-
def test_read_filter_groupby_agg(session):
129-
# test a chain of operations that are not fully supported in faster pandas
127+
@sql_count_checker(query_count=3, join_count=2)
128+
def test_read_filter_iloc_index(session):
129+
# test a chain of operations that are not yet fully supported in faster pandas
130130

131131
# create tables
132132
table_name = Utils.random_name_for_temp_object(TempObjectType.TABLE)
@@ -136,19 +136,19 @@ def test_read_filter_groupby_agg(session):
136136

137137
# create snow dataframes
138138
df = pd.read_snowflake(table_name)
139-
snow_result = df[df["B"] > 11].groupby("A").min()
139+
snow_result = df.iloc[[1], :]
140140

141141
# verify that the input dataframe has a populated relaxed query compiler
142142
assert df._query_compiler._relaxed_query_compiler is not None
143143
assert df._query_compiler._relaxed_query_compiler._dummy_row_pos_mode is True
144144
# verify that the output dataframe has an empty relaxed query compiler
145-
# because groupby() and min() are not supported in faster pandas yet
145+
# because iloc for index is not supported in faster pandas yet
146146
assert snow_result._query_compiler._relaxed_query_compiler is None
147147
assert snow_result._query_compiler._dummy_row_pos_mode is False
148148

149149
# create pandas dataframes
150150
native_df = df.to_pandas()
151-
native_result = native_df[native_df["B"] > 11].groupby("A").min()
151+
native_result = native_df.iloc[[1], :]
152152

153153
# compare results
154154
assert_frame_equal(snow_result, native_result)
@@ -308,6 +308,61 @@ def test_duplicated(session):
308308
assert_series_equal(snow_result, native_result)
309309

310310

311+
@pytest.mark.parametrize(
312+
"func",
313+
[
314+
"min",
315+
"max",
316+
"count",
317+
"sum",
318+
"mean",
319+
"median",
320+
"std",
321+
"var",
322+
],
323+
)
324+
@sql_count_checker(query_count=6)
325+
def test_groupby_agg(session, func):
326+
# create tables
327+
table_name = Utils.random_name_for_temp_object(TempObjectType.TABLE)
328+
session.create_dataframe(
329+
native_pd.DataFrame([[2, 12], [2, 11], [3, 13]], columns=["A", "B"])
330+
).write.save_as_table(table_name, table_type="temp")
331+
332+
# create snow dataframes
333+
df = pd.read_snowflake(table_name)
334+
snow_result1 = getattr(df.groupby("A"), func)()
335+
snow_result2 = df.groupby("A").agg([func])
336+
snow_result3 = getattr(df.groupby("A")["B"], func)()
337+
snow_result4 = df.groupby("A")["B"].agg([func])
338+
339+
# verify that the input dataframe has a populated relaxed query compiler
340+
assert df._query_compiler._relaxed_query_compiler is not None
341+
assert df._query_compiler._relaxed_query_compiler._dummy_row_pos_mode is True
342+
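# verify that the DataFrame groupby outputs (snow_result1 and snow_result2) also have a populated relaxed query compiler; snow_result3 and snow_result4 are only compared against native pandas below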
343+
assert snow_result1._query_compiler._relaxed_query_compiler is not None
344+
assert (
345+
snow_result1._query_compiler._relaxed_query_compiler._dummy_row_pos_mode is True
346+
)
347+
assert snow_result2._query_compiler._relaxed_query_compiler is not None
348+
assert (
349+
snow_result2._query_compiler._relaxed_query_compiler._dummy_row_pos_mode is True
350+
)
351+
352+
# create pandas dataframes
353+
native_df = df.to_pandas()
354+
native_result1 = getattr(native_df.groupby("A"), func)()
355+
native_result2 = native_df.groupby("A").agg([func])
356+
native_result3 = getattr(native_df.groupby("A")["B"], func)()
357+
native_result4 = native_df.groupby("A")["B"].agg([func])
358+
359+
# compare results
360+
assert_frame_equal(snow_result1, native_result1, check_dtype=False)
361+
assert_frame_equal(snow_result2, native_result2, check_dtype=False)
362+
assert_series_equal(snow_result3, native_result3, check_dtype=False)
363+
assert_frame_equal(snow_result4, native_result4, check_dtype=False)
364+
365+
311366
@sql_count_checker(query_count=5)
312367
def test_iloc_head(session):
313368
# create tables

0 commit comments

Comments
 (0)