Skip to content

Commit 103e55d

Browse files
SNOW-2441444: Add support for groupby.nunique/size in faster pandas (#3923)
1 parent c09a805 commit 103e55d

File tree

3 files changed

+82
-1
lines changed

3 files changed

+82
-1
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -152,6 +152,8 @@
152152
- `groupby.median`
153153
- `groupby.std`
154154
- `groupby.var`
155+
- `groupby.nunique`
156+
- `groupby.size`
155157
- `drop_duplicates`
156158
- Reuse row count from the relaxed query compiler in `get_axis_len`.
157159

src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py

Lines changed: 74 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -6078,6 +6078,43 @@ def groupby_size(
60786078
agg_kwargs: dict[str, Any],
60796079
drop: bool = False,
60806080
**kwargs: dict[str, Any],
6081+
) -> "SnowflakeQueryCompiler":
6082+
"""
6083+
Wrapper around _groupby_size_internal to be supported in faster pandas.
6084+
"""
6085+
relaxed_query_compiler = None
6086+
if self._relaxed_query_compiler is not None:
6087+
relaxed_query_compiler = (
6088+
self._relaxed_query_compiler._groupby_size_internal(
6089+
by=by,
6090+
axis=axis,
6091+
groupby_kwargs=groupby_kwargs,
6092+
agg_args=agg_args,
6093+
agg_kwargs=agg_kwargs,
6094+
drop=drop,
6095+
**kwargs,
6096+
)
6097+
)
6098+
qc = self._groupby_size_internal(
6099+
by=by,
6100+
axis=axis,
6101+
groupby_kwargs=groupby_kwargs,
6102+
agg_args=agg_args,
6103+
agg_kwargs=agg_kwargs,
6104+
drop=drop,
6105+
**kwargs,
6106+
)
6107+
return self._maybe_set_relaxed_qc(qc, relaxed_query_compiler)
6108+
6109+
def _groupby_size_internal(
6110+
self,
6111+
by: Any,
6112+
axis: int,
6113+
groupby_kwargs: dict[str, Any],
6114+
agg_args: tuple[Any],
6115+
agg_kwargs: dict[str, Any],
6116+
drop: bool = False,
6117+
**kwargs: dict[str, Any],
60816118
) -> "SnowflakeQueryCompiler":
60826119
"""
60836120
compute groupby with size.
@@ -6495,6 +6532,43 @@ def groupby_nunique(
64956532
agg_kwargs: dict[str, Any],
64966533
drop: bool = False,
64976534
**kwargs: Any,
6535+
) -> "SnowflakeQueryCompiler":
6536+
"""
6537+
Wrapper around _groupby_nunique_internal to be supported in faster pandas.
6538+
"""
6539+
relaxed_query_compiler = None
6540+
if self._relaxed_query_compiler is not None:
6541+
relaxed_query_compiler = (
6542+
self._relaxed_query_compiler._groupby_nunique_internal(
6543+
by=by,
6544+
axis=axis,
6545+
groupby_kwargs=groupby_kwargs,
6546+
agg_args=agg_args,
6547+
agg_kwargs=agg_kwargs,
6548+
drop=drop,
6549+
**kwargs,
6550+
)
6551+
)
6552+
qc = self._groupby_nunique_internal(
6553+
by=by,
6554+
axis=axis,
6555+
groupby_kwargs=groupby_kwargs,
6556+
agg_args=agg_args,
6557+
agg_kwargs=agg_kwargs,
6558+
drop=drop,
6559+
**kwargs,
6560+
)
6561+
return self._maybe_set_relaxed_qc(qc, relaxed_query_compiler)
6562+
6563+
def _groupby_nunique_internal(
6564+
self,
6565+
by: Any,
6566+
axis: int,
6567+
groupby_kwargs: dict[str, Any],
6568+
agg_args: Any,
6569+
agg_kwargs: dict[str, Any],
6570+
drop: bool = False,
6571+
**kwargs: Any,
64986572
) -> "SnowflakeQueryCompiler":
64996573
# We have to override the Modin version of this function because our groupby frontend passes the
65006574
# ignored numeric_only argument to this query compiler method, and BaseQueryCompiler

tests/integ/modin/test_faster_pandas.py

Lines changed: 6 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -348,6 +348,8 @@ def test_duplicated(session):
348348
"median",
349349
"std",
350350
"var",
351+
"nunique",
352+
"size",
351353
],
352354
)
353355
@sql_count_checker(query_count=6)
@@ -386,7 +388,10 @@ def test_groupby_agg(session, func):
386388
native_result4 = native_df.groupby("A")["B"].agg([func])
387389

388390
# compare results
389-
assert_frame_equal(snow_result1, native_result1, check_dtype=False)
391+
if func == "size":
392+
assert_series_equal(snow_result1, native_result1, check_dtype=False)
393+
else:
394+
assert_frame_equal(snow_result1, native_result1, check_dtype=False)
390395
assert_frame_equal(snow_result2, native_result2, check_dtype=False)
391396
assert_series_equal(snow_result3, native_result3, check_dtype=False)
392397
assert_frame_equal(snow_result4, native_result4, check_dtype=False)

0 commit comments

Comments
 (0)