Skip to content

Commit 62acef0

Browse files
SNOW-2643972: Add support for groupby properties (groupby.groups/indices) in faster pandas (#3984)
1 parent a66b95a commit 62acef0

File tree

3 files changed

+132
-3
lines changed

3 files changed

+132
-3
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -144,6 +144,8 @@
144144
- `cumsum`
145145
- `cummin`
146146
- `cummax`
147+
- `groupby.groups`
148+
- `groupby.indices`
147149
- `groupby.first`
148150
- `groupby.last`
149151
- `groupby.rank`

src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py

Lines changed: 94 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -823,14 +823,17 @@ def __init__(self, frame: InternalFrame) -> None:
823823
storage_format = property(lambda self: "Snowflake")
824824

825825
def _raise_not_implemented_error_for_timedelta(
826-
self, frame: InternalFrame = None
826+
self, frame: InternalFrame = None, stack_depth: int = 2
827827
) -> None:
828828
"""Raise NotImplementedError for SnowflakeQueryCompiler methods which does not support timedelta yet."""
829829
if frame is None:
830830
frame = self._modin_frame
831831
for val in frame.snowflake_quoted_identifier_to_snowpark_pandas_type.values():
832832
if isinstance(val, TimedeltaType):
833-
method = inspect.currentframe().f_back.f_back.f_code.co_name # type: ignore[union-attr]
833+
method_frame = inspect.currentframe()
834+
for _ in range(stack_depth):
835+
method_frame = method_frame.f_back # type: ignore[union-attr]
836+
method = method_frame.f_code.co_name # type: ignore[union-attr]
834837
ErrorMessage.not_implemented_for_timedelta(method)
835838

836839
def _warn_lost_snowpark_pandas_type(self) -> None:
@@ -5737,6 +5740,49 @@ def groupby_rank(
57375740
na_option: Literal["keep", "top", "bottom"] = "keep",
57385741
ascending: bool = True,
57395742
pct: bool = False,
5743+
) -> "SnowflakeQueryCompiler":
5744+
"""
5745+
Wrapper around _groupby_rank_internal to support faster pandas.
5746+
"""
5747+
relaxed_query_compiler = None
5748+
if self._relaxed_query_compiler is not None:
5749+
relaxed_query_compiler = (
5750+
self._relaxed_query_compiler._groupby_rank_internal(
5751+
by=by,
5752+
groupby_kwargs=groupby_kwargs,
5753+
agg_args=agg_args,
5754+
agg_kwargs=agg_kwargs,
5755+
axis=axis,
5756+
method=method,
5757+
na_option=na_option,
5758+
ascending=ascending,
5759+
pct=pct,
5760+
)
5761+
)
5762+
qc = self._groupby_rank_internal(
5763+
by=by,
5764+
groupby_kwargs=groupby_kwargs,
5765+
agg_args=agg_args,
5766+
agg_kwargs=agg_kwargs,
5767+
axis=axis,
5768+
method=method,
5769+
na_option=na_option,
5770+
ascending=ascending,
5771+
pct=pct,
5772+
)
5773+
return self._maybe_set_relaxed_qc(qc, relaxed_query_compiler)
5774+
5775+
def _groupby_rank_internal(
5776+
self,
5777+
by: Any,
5778+
groupby_kwargs: dict[str, Any],
5779+
agg_args: Any,
5780+
agg_kwargs: dict[str, Any],
5781+
axis: Axis = 0,
5782+
method: Literal["average", "min", "max", "first", "dense"] = "average",
5783+
na_option: Literal["keep", "top", "bottom"] = "keep",
5784+
ascending: bool = True,
5785+
pct: bool = False,
57405786
) -> "SnowflakeQueryCompiler":
57415787
"""
57425788
Compute groupby with rank.
@@ -6624,6 +6670,27 @@ def groupby_groups(
66246670
by: Any,
66256671
axis: int,
66266672
groupby_kwargs: dict[str, Any],
6673+
) -> PrettyDict[Hashable, "pd.Index"]:
6674+
"""
6675+
Wrapper around _groupby_groups_internal to support faster pandas.
6676+
"""
6677+
if self._relaxed_query_compiler is not None:
6678+
return self._relaxed_query_compiler._groupby_groups_internal(
6679+
by=by,
6680+
axis=axis,
6681+
groupby_kwargs=groupby_kwargs,
6682+
)
6683+
return self._groupby_groups_internal(
6684+
by=by,
6685+
axis=axis,
6686+
groupby_kwargs=groupby_kwargs,
6687+
)
6688+
6689+
def _groupby_groups_internal(
6690+
self,
6691+
by: Any,
6692+
axis: int,
6693+
groupby_kwargs: dict[str, Any],
66276694
) -> PrettyDict[Hashable, "pd.Index"]:
66286695
"""
66296696
Get a PrettyDict mapping group keys to row labels.
@@ -6667,7 +6734,7 @@ def groupby_groups(
66676734
4 5 2 4 5
66686735
0 8 9 0 8
66696736
"""
6670-
self._raise_not_implemented_error_for_timedelta()
6737+
self._raise_not_implemented_error_for_timedelta(stack_depth=4)
66716738

66726739
original_index_names = self.get_index_names()
66736740
frame = self._modin_frame
@@ -6764,6 +6831,30 @@ def groupby_indices(
67646831
axis: int,
67656832
groupby_kwargs: dict[str, Any],
67666833
values_as_np_array: bool = True,
6834+
) -> dict[Hashable, np.ndarray]:
6835+
"""
6836+
Wrapper around _groupby_indices_internal to support faster pandas.
6837+
"""
6838+
if self._relaxed_query_compiler is not None:
6839+
return self._relaxed_query_compiler._groupby_indices_internal(
6840+
by=by,
6841+
axis=axis,
6842+
groupby_kwargs=groupby_kwargs,
6843+
values_as_np_array=values_as_np_array,
6844+
)
6845+
return self._groupby_indices_internal(
6846+
by=by,
6847+
axis=axis,
6848+
groupby_kwargs=groupby_kwargs,
6849+
values_as_np_array=values_as_np_array,
6850+
)
6851+
6852+
def _groupby_indices_internal(
6853+
self,
6854+
by: Any,
6855+
axis: int,
6856+
groupby_kwargs: dict[str, Any],
6857+
values_as_np_array: bool = True,
67676858
) -> dict[Hashable, np.ndarray]:
67686859
"""
67696860
Get a dict mapping group keys to row labels.

tests/integ/modin/test_faster_pandas.py

Lines changed: 36 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -300,6 +300,12 @@ def test_groupby_no_param_functions(session, func):
300300
# verify that the input dataframe has a populated relaxed query compiler
301301
assert df._query_compiler._relaxed_query_compiler is not None
302302
assert df._query_compiler._relaxed_query_compiler._dummy_row_pos_mode is True
303+
# verify that the output dataframe also has a populated relaxed query compiler
304+
assert snow_result._query_compiler._relaxed_query_compiler is not None
305+
assert (
306+
snow_result._query_compiler._relaxed_query_compiler._dummy_row_pos_mode
307+
is True
308+
)
303309

304310
# create pandas dataframes
305311
native_df = df.to_pandas()
@@ -662,6 +668,36 @@ def test_groupby_apply(session):
662668
)
663669

664670

671+
@pytest.mark.parametrize("property_name", ["groups", "indices"])
672+
@sql_count_checker(query_count=3)
673+
def test_groupby_properties(session, property_name):
674+
with session_parameter_override(
675+
session, "dummy_row_pos_optimization_enabled", True
676+
):
677+
# create tables
678+
table_name = Utils.random_name_for_temp_object(TempObjectType.TABLE)
679+
session.create_dataframe(
680+
native_pd.DataFrame([[2, 12], [2, 11], [3, 13]], columns=["A", "B"])
681+
).write.save_as_table(table_name, table_type="temp")
682+
683+
# create snow dataframes
684+
df = pd.read_snowflake(table_name).sort_values("B", ignore_index=True)
685+
snow_result = getattr(df.groupby("A"), property_name)
686+
687+
# verify that the input dataframe has a populated relaxed query compiler
688+
assert df._query_compiler._relaxed_query_compiler is not None
689+
assert df._query_compiler._relaxed_query_compiler._dummy_row_pos_mode is True
690+
691+
# create pandas dataframes
692+
native_df = df.to_pandas()
693+
native_result = getattr(native_df.groupby("A"), property_name)
694+
695+
# compare results
696+
snow_result = {k: list(v) for k, v in snow_result.items()}
697+
native_result = {k: list(v) for k, v in native_result.items()}
698+
assert snow_result == native_result
699+
700+
665701
@sql_count_checker(query_count=5)
666702
def test_iloc_head(session):
667703
with session_parameter_override(

0 commit comments

Comments (0)