Skip to content

Commit b8bdce5

Browse files
SNOW-2643972: Add support for groupby properties (groupby.groups/indices) in faster pandas
1 parent db1b20f commit b8bdce5

File tree

3 files changed

+77
-0
lines changed

3 files changed

+77
-0
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,8 @@
123123
- `cumsum`
124124
- `cummin`
125125
- `cummax`
126+
- `groupby.groups`
127+
- `groupby.indices`
126128
- Make faster pandas disabled by default (opt-in instead of opt-out).
127129
- Improve performance of `drop_duplicates` by avoiding joins when `keep!=False` in faster pandas.
128130

src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6513,6 +6513,27 @@ def groupby_groups(
65136513
by: Any,
65146514
axis: int,
65156515
groupby_kwargs: dict[str, Any],
6516+
) -> PrettyDict[Hashable, "pd.Index"]:
6517+
"""
6518+
Wrapper around _groupby_groups_internal to be supported in faster pandas.
6519+
"""
6520+
if self._relaxed_query_compiler is not None:
6521+
return self._relaxed_query_compiler._groupby_groups_internal(
6522+
by=by,
6523+
axis=axis,
6524+
groupby_kwargs=groupby_kwargs,
6525+
)
6526+
return self._groupby_groups_internal(
6527+
by=by,
6528+
axis=axis,
6529+
groupby_kwargs=groupby_kwargs,
6530+
)
6531+
6532+
def _groupby_groups_internal(
6533+
self,
6534+
by: Any,
6535+
axis: int,
6536+
groupby_kwargs: dict[str, Any],
65166537
) -> PrettyDict[Hashable, "pd.Index"]:
65176538
"""
65186539
Get a PrettyDict mapping group keys to row labels.
@@ -6653,6 +6674,30 @@ def groupby_indices(
66536674
axis: int,
66546675
groupby_kwargs: dict[str, Any],
66556676
values_as_np_array: bool = True,
6677+
) -> dict[Hashable, np.ndarray]:
6678+
"""
6679+
Wrapper around _groupby_indices_internal to be supported in faster pandas.
6680+
"""
6681+
if self._relaxed_query_compiler is not None:
6682+
return self._relaxed_query_compiler._groupby_indices_internal(
6683+
by=by,
6684+
axis=axis,
6685+
groupby_kwargs=groupby_kwargs,
6686+
values_as_np_array=values_as_np_array,
6687+
)
6688+
return self._groupby_indices_internal(
6689+
by=by,
6690+
axis=axis,
6691+
groupby_kwargs=groupby_kwargs,
6692+
values_as_np_array=values_as_np_array,
6693+
)
6694+
6695+
def _groupby_indices_internal(
6696+
self,
6697+
by: Any,
6698+
axis: int,
6699+
groupby_kwargs: dict[str, Any],
6700+
values_as_np_array: bool = True,
66566701
) -> dict[Hashable, np.ndarray]:
66576702
"""
66586703
Get a dict mapping group keys to row labels.

tests/integ/modin/test_faster_pandas.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -619,6 +619,36 @@ def test_groupby_apply(session):
619619
)
620620

621621

622+
@pytest.mark.parametrize("property_name", ["groups", "indices"])
@sql_count_checker(query_count=3)
def test_groupby_properties(session, property_name):
    """Compare a groupby property (``groups`` or ``indices``) against native pandas.

    The test runs with the ``dummy_row_pos_optimization_enabled`` session
    parameter switched on, reads a small temp table into a Snowpark pandas
    DataFrame, and checks that the property evaluated on ``df.groupby("A")``
    yields the same key -> values mapping as native pandas does on the
    materialized frame.
    """
    with session_parameter_override(
        session, "dummy_row_pos_optimization_enabled", True
    ):
        # Stage three rows in a temporary table to read back from Snowflake.
        table_name = Utils.random_name_for_temp_object(TempObjectType.TABLE)
        source_rows = native_pd.DataFrame(
            [[2, 12], [2, 11], [3, 13]], columns=["A", "B"]
        )
        session.create_dataframe(source_rows).write.save_as_table(
            table_name, table_type="temp"
        )

        # Build the Snowpark pandas frame and evaluate the property under test.
        df = pd.read_snowflake(table_name).sort_values("B", ignore_index=True)
        snow_result = getattr(df.groupby("A"), property_name)

        # The input frame must carry a relaxed query compiler in dummy row
        # position mode, otherwise the faster pandas path was not exercised.
        relaxed_compiler = df._query_compiler._relaxed_query_compiler
        assert relaxed_compiler is not None
        assert relaxed_compiler._dummy_row_pos_mode is True

        # Evaluate the same property with native pandas as the reference.
        native_df = df.to_pandas()
        native_result = getattr(native_df.groupby("A"), property_name)

        # Normalize both mappings to plain lists so that Index / ndarray
        # values compare by content.
        snow_as_lists = {key: list(values) for key, values in snow_result.items()}
        native_as_lists = {
            key: list(values) for key, values in native_result.items()
        }
        assert snow_as_lists == native_as_lists
650+
651+
622652
@sql_count_checker(query_count=5)
623653
def test_iloc_head(session):
624654
with session_parameter_override(

0 commit comments

Comments
 (0)