Skip to content

Commit 7412bf4

Browse files
SNOW-2359402 Enable autoswitching on DataFrameGroupBy (#3936)
1 parent df6da5d commit 7412bf4

File tree

12 files changed

+516
-51
lines changed

12 files changed

+516
-51
lines changed

CHANGELOG.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,11 @@
188188
- `skew()` with `axis=1` or `numeric_only=False` parameters
189189
- `round()` with `decimals` parameter as a Series
190190
- `corr()` with `method!=pearson` parameter
191+
- `df.groupby()` with `axis=1`, `by!=None and level!=None`, or by containing any non-pandas hashable labels
192+
- `groupby_fillna()` with `downcast` parameter
193+
- `groupby_first()` with `min_count>1`
194+
- `groupby_last()` with `min_count>1`
195+
- `shift()` with `freq` parameter
191196
- Set `cte_optimization_enabled` to True for all Snowpark pandas sessions.
192197
- Add support for the following in faster pandas:
193198
- `isin`

src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py

Lines changed: 85 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -450,7 +450,7 @@
450450
# For now, limit number of quantiles supported df.quantiles to avoid producing recursion limit failure in Snowpark.
451451
MAX_QUANTILES_SUPPORTED: int = 16
452452

453-
_GROUPBY_UNSUPPORTED_GROUPING_MESSAGE = "does not yet support axis == 1, by != None and level != None, or by containing any non-pandas hashable labels."
453+
_GROUPBY_UNSUPPORTED_GROUPING_MESSAGE = "does not yet support axis == 1, by != None and level != None, or by containing any non-pandas hashable labels"
454454

455455
QUARTER_START_MONTHS = [1, 4, 7, 10]
456456

@@ -1153,8 +1153,7 @@ def stay_cost(
11531153
return QCCoercionCost.COST_IMPOSSIBLE
11541154

11551155
if method_key in HYBRID_SWITCH_FOR_UNSUPPORTED_ARGS:
1156-
1157-
if arguments and SnowflakeQueryCompiler._has_unsupported_args(
1156+
if SnowflakeQueryCompiler._has_unsupported_args(
11581157
api_cls_name, operation, arguments
11591158
):
11601159
WarningMessage.single_warning(
@@ -4546,7 +4545,7 @@ def groupby_ngroups(
45464545
is_supported = check_is_groupby_supported_by_snowflake(by, level, axis)
45474546
if not is_supported:
45484547
ErrorMessage.not_implemented(
4549-
f"Snowpark pandas GroupBy.ngroups {_GROUPBY_UNSUPPORTED_GROUPING_MESSAGE}"
4548+
f"Snowpark pandas GroupBy.ngroups {_GROUPBY_UNSUPPORTED_GROUPING_MESSAGE}."
45504549
)
45514550

45524551
query_compiler = get_frame_with_groupby_columns_as_index(
@@ -4555,7 +4554,7 @@ def groupby_ngroups(
45554554

45564555
if query_compiler is None:
45574556
ErrorMessage.not_implemented(
4558-
f"Snowpark pandas GroupBy.ngroups {_GROUPBY_UNSUPPORTED_GROUPING_MESSAGE}"
4557+
f"Snowpark pandas GroupBy.ngroups {_GROUPBY_UNSUPPORTED_GROUPING_MESSAGE}."
45594558
)
45604559

45614560
internal_frame = query_compiler._modin_frame
@@ -4706,7 +4705,7 @@ def _groupby_agg_internal(
47064705
by, level, axis
47074706
):
47084707
ErrorMessage.not_implemented(
4709-
f"Snowpark pandas GroupBy.aggregate {_GROUPBY_UNSUPPORTED_GROUPING_MESSAGE}"
4708+
f"Snowpark pandas GroupBy.aggregate {_GROUPBY_UNSUPPORTED_GROUPING_MESSAGE}."
47104709
)
47114710

47124711
by_list = query_compiler._modin_frame.index_column_pandas_labels
@@ -4961,6 +4960,7 @@ def groupby_apply(
49614960
include_groups: bool,
49624961
force_single_group: bool = False,
49634962
force_list_like_to_series: bool = False,
4963+
is_transform: bool = False,
49644964
) -> "SnowflakeQueryCompiler":
49654965
"""
49664966
Wrapper around _groupby_apply_internal to be supported in faster pandas.
@@ -4979,6 +4979,7 @@ def groupby_apply(
49794979
include_groups=include_groups,
49804980
force_single_group=force_single_group,
49814981
force_list_like_to_series=force_list_like_to_series,
4982+
is_transform=is_transform,
49824983
)
49834984
)
49844985
qc = self._groupby_apply_internal(
@@ -4992,6 +4993,7 @@ def groupby_apply(
49924993
include_groups=include_groups,
49934994
force_single_group=force_single_group,
49944995
force_list_like_to_series=force_list_like_to_series,
4996+
is_transform=is_transform,
49954997
)
49964998
return self._maybe_set_relaxed_qc(qc, relaxed_query_compiler)
49974999

@@ -5007,6 +5009,7 @@ def _groupby_apply_internal(
50075009
include_groups: bool,
50085010
force_single_group: bool = False,
50095011
force_list_like_to_series: bool = False,
5012+
is_transform: bool = False,
50105013
) -> "SnowflakeQueryCompiler":
50115014
"""
50125015
Group according to `by` and `level`, apply a function to each group, and combine the results.
@@ -5114,7 +5117,7 @@ def _groupby_apply_internal(
51145117
data_columns_index = _modin_frame.data_columns_index[
51155118
input_data_column_positions
51165119
]
5117-
is_transform = groupby_kwargs.get("apply_op") == "transform"
5120+
51185121
output_schema, udtf = create_udtf_for_groupby_apply(
51195122
agg_func,
51205123
agg_args,
@@ -5499,7 +5502,7 @@ def _groupby_first_last(
54995502
is_supported = check_is_groupby_supported_by_snowflake(by, level, axis)
55005503
if not is_supported:
55015504
ErrorMessage.not_implemented(
5502-
f"Snowpark pandas GroupBy.{method} {_GROUPBY_UNSUPPORTED_GROUPING_MESSAGE}"
5505+
f"Snowpark pandas GroupBy.{method} {_GROUPBY_UNSUPPORTED_GROUPING_MESSAGE}."
55035506
)
55045507
# TODO: Support groupby first and last with min_count (SNOW-1482931)
55055508
if agg_kwargs.get("min_count", -1) > 1:
@@ -5544,6 +5547,19 @@ def _groupby_first_last(
55445547
result = result.reset_index(drop=False)
55455548
return result
55465549

5550+
@register_query_compiler_method_not_implemented(
5551+
"DataFrameGroupBy",
5552+
"first",
5553+
UnsupportedArgsRule(
5554+
unsupported_conditions=[
5555+
(
5556+
lambda args: args.get("min_count", -1) > 1
5557+
or args.get("agg_kwargs", {}).get("min_count", -1) > 1,
5558+
"GroupBy.first does not yet support min_count > 1",
5559+
),
5560+
],
5561+
),
5562+
)
55475563
def groupby_first(
55485564
self,
55495565
by: Any,
@@ -5577,6 +5593,19 @@ def groupby_first(
55775593
"first", by, axis, groupby_kwargs, agg_args, agg_kwargs, drop, **kwargs
55785594
)
55795595

5596+
@register_query_compiler_method_not_implemented(
5597+
"DataFrameGroupBy",
5598+
"last",
5599+
UnsupportedArgsRule(
5600+
unsupported_conditions=[
5601+
(
5602+
lambda args: args.get("agg_kwargs", {}).get("min_count", -1) > 1
5603+
or args.get("min_count", -1) > 1,
5604+
"GroupBy.last does not yet support min_count > 1",
5605+
),
5606+
],
5607+
),
5608+
)
55805609
def groupby_last(
55815610
self,
55825611
by: Any,
@@ -5610,6 +5639,19 @@ def groupby_last(
56105639
"last", by, axis, groupby_kwargs, agg_args, agg_kwargs, drop, **kwargs
56115640
)
56125641

5642+
@register_query_compiler_method_not_implemented(
5643+
"DataFrameGroupBy",
5644+
"rank",
5645+
UnsupportedArgsRule(
5646+
unsupported_conditions=[
5647+
(
5648+
lambda args: args.get("groupby_kwargs", {}).get("level") is not None
5649+
and args.get("groupby_kwargs", {}).get("level") != 0,
5650+
"GroupBy.rank with level != 0 is not supported yet in Snowpark pandas.",
5651+
),
5652+
],
5653+
),
5654+
)
56135655
def groupby_rank(
56145656
self,
56155657
by: Any,
@@ -6059,6 +6101,23 @@ def groupby_rolling(
60596101
result_qc = SnowflakeQueryCompiler(new_frame)
60606102
return result_qc
60616103

6104+
@register_query_compiler_method_not_implemented(
6105+
"DataFrameGroupBy",
6106+
"shift",
6107+
UnsupportedArgsRule(
6108+
unsupported_conditions=[
6109+
(
6110+
lambda args: args.get("freq") is not None,
6111+
"'freq' argument is not supported yet in Snowpark pandas",
6112+
),
6113+
(
6114+
lambda args: args.get("groupby_kwargs", {}).get("level") is not None
6115+
and args.get("groupby_kwargs", {}).get("level") != 0,
6116+
"GroupBy.shift with level != 0 is not supported yet in Snowpark pandas",
6117+
),
6118+
],
6119+
),
6120+
)
60626121
def groupby_shift(
60636122
self,
60646123
by: Any,
@@ -6314,7 +6373,7 @@ def groupby_get_group(
63146373
is_supported = check_is_groupby_supported_by_snowflake(by, level, axis)
63156374
if not is_supported: # pragma: no cover
63166375
ErrorMessage.not_implemented(
6317-
f"Snowpark pandas GroupBy.get_group {_GROUPBY_UNSUPPORTED_GROUPING_MESSAGE}"
6376+
f"Snowpark pandas GroupBy.get_group {_GROUPBY_UNSUPPORTED_GROUPING_MESSAGE}."
63186377
)
63196378
if is_list_like(by):
63206379
ErrorMessage.not_implemented(
@@ -6418,7 +6477,7 @@ def _groupby_size_internal(
64186477
is_supported = check_is_groupby_supported_by_snowflake(by, level, axis)
64196478
if not is_supported:
64206479
ErrorMessage.not_implemented(
6421-
f"Snowpark pandas GroupBy.size {_GROUPBY_UNSUPPORTED_GROUPING_MESSAGE}"
6480+
f"Snowpark pandas GroupBy.size {_GROUPBY_UNSUPPORTED_GROUPING_MESSAGE}."
64226481
)
64236482
if not is_list_like(by):
64246483
by = [by]
@@ -6717,7 +6776,7 @@ def groupby_cummin(
67176776
self,
67186777
by: Any,
67196778
axis: int,
6720-
numeric_only: int,
6779+
numeric_only: bool,
67216780
groupby_kwargs: dict[str, Any],
67226781
) -> "SnowflakeQueryCompiler":
67236782
"""
@@ -6910,7 +6969,7 @@ def groupby_value_counts(
69106969
is_supported = check_is_groupby_supported_by_snowflake(by, level, axis)
69116970
if not is_supported:
69126971
ErrorMessage.not_implemented(
6913-
f"Snowpark pandas GroupBy.value_counts {_GROUPBY_UNSUPPORTED_GROUPING_MESSAGE}"
6972+
f"Snowpark pandas GroupBy.value_counts {_GROUPBY_UNSUPPORTED_GROUPING_MESSAGE}."
69146973
)
69156974
if bins is not None:
69166975
raise ErrorMessage.not_implemented("bins argument is not yet supported")
@@ -7047,6 +7106,18 @@ def groupby_value_counts(
70477106
ignore_index=not as_index, # When as_index=False, take the default positional index
70487107
)
70497108

7109+
@register_query_compiler_method_not_implemented(
7110+
"DataFrameGroupBy",
7111+
"fillna",
7112+
UnsupportedArgsRule(
7113+
unsupported_conditions=[
7114+
(
7115+
lambda args: args.get("downcast") is not None,
7116+
"'downcast' argument is not supported yet in Snowpark pandas",
7117+
),
7118+
],
7119+
),
7120+
)
70507121
def groupby_fillna(
70517122
self,
70527123
by: Any,
@@ -7089,7 +7160,7 @@ def groupby_fillna(
70897160
)
70907161
if not is_supported:
70917162
ErrorMessage.not_implemented(
7092-
f"Snowpark pandas GroupBy.fillna {_GROUPBY_UNSUPPORTED_GROUPING_MESSAGE}"
7163+
f"Snowpark pandas GroupBy.fillna {_GROUPBY_UNSUPPORTED_GROUPING_MESSAGE}."
70937164
)
70947165

70957166
if by is not None and not is_list_like(by):
@@ -7353,7 +7424,7 @@ def groupby_pct_change(
73537424
# Remaining parameters are validated in pct_change method
73547425
if not is_supported:
73557426
ErrorMessage.not_implemented(
7356-
f"Snowpark pandas GroupBy.pct_change {_GROUPBY_UNSUPPORTED_GROUPING_MESSAGE}"
7427+
f"Snowpark pandas GroupBy.pct_change {_GROUPBY_UNSUPPORTED_GROUPING_MESSAGE}."
73577428
)
73587429

73597430
by_labels = by

src/snowflake/snowpark/modin/plugin/extensions/dataframe_groupby_overrides.py

Lines changed: 26 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -50,8 +50,14 @@
5050
from snowflake.snowpark.modin.plugin._internal.apply_utils import (
5151
create_groupby_transform_func,
5252
)
53+
from snowflake.snowpark.modin.plugin._internal.groupby_utils import (
54+
check_is_groupby_supported_by_snowflake,
55+
)
5356
from snowflake.snowpark.modin.plugin.compiler.snowflake_query_compiler import (
5457
SnowflakeQueryCompiler,
58+
UnsupportedArgsRule,
59+
_GROUPBY_UNSUPPORTED_GROUPING_MESSAGE,
60+
register_query_compiler_method_not_implemented,
5561
)
5662

5763
# the following import is used in doctests
@@ -78,6 +84,22 @@
7884

7985

8086
@register_df_groupby_override("__init__")
87+
@register_query_compiler_method_not_implemented(
88+
"DataFrameGroupBy",
89+
"__init__",
90+
UnsupportedArgsRule(
91+
unsupported_conditions=[
92+
(
93+
lambda args: not check_is_groupby_supported_by_snowflake(
94+
args.get("by"),
95+
args.get("level"),
96+
args.get("axis", 0),
97+
),
98+
f"Groupby {_GROUPBY_UNSUPPORTED_GROUPING_MESSAGE}",
99+
)
100+
]
101+
),
102+
)
81103
def __init__(
82104
self,
83105
df,
@@ -114,9 +136,6 @@ def __init__(
114136
"group_keys": group_keys,
115137
}
116138
self._kwargs.update(kwargs)
117-
if "apply_op" not in self._kwargs:
118-
# Can be "apply", "transform", "filter" or "aggregate"
119-
self._kwargs.update({"apply_op": "apply"})
120139

121140

122141
@register_df_groupby_override("ngroups")
@@ -172,7 +191,7 @@ def indices(self) -> dict[Hashable, npt.NDArray[np.intp]]:
172191

173192

174193
@register_df_groupby_override("apply")
175-
def apply(self, func, *args, include_groups=True, **kwargs):
194+
def apply(self, func, *args, include_groups=True, _is_transform=False, **kwargs):
176195
# TODO: SNOW-1063349: Modin upgrade - modin.pandas.groupby.DataFrameGroupBy functions
177196
# TODO: SNOW-1244717: Explore whether window function are performant and can be used
178197
# whenever `func` is an aggregation function.
@@ -188,6 +207,7 @@ def apply(self, func, *args, include_groups=True, **kwargs):
188207
agg_kwargs=kwargs,
189208
series_groupby=False,
190209
include_groups=include_groups,
210+
is_transform=_is_transform,
191211
)
192212
)
193213
if dataframe_result.columns.equals(pandas.Index([MODIN_UNNAMED_SERIES_LABEL])):
@@ -320,11 +340,10 @@ def transform(
320340
dropna=False,
321341
sort=self._sort,
322342
)
323-
groupby_obj._kwargs["apply_op"] = "transform"
324-
325343
# Apply the transform function to each group.
326344
res = groupby_obj.apply(
327-
create_groupby_transform_func(func, by, level, *args, **kwargs)
345+
create_groupby_transform_func(func, by, level, *args, **kwargs),
346+
_is_transform=True,
328347
)
329348

330349
dropna = self._kwargs.get("dropna", True)

src/snowflake/snowpark/modin/plugin/extensions/series_groupby_overrides.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ def get_group(self, name, obj=None):
101101

102102

103103
@register_ser_groupby_override("apply")
104-
def apply(self, func, *args, include_groups=True, **kwargs):
104+
def apply(self, func, *args, include_groups=True, _is_transform=False, **kwargs):
105105
# TODO: SNOW-1063349: Modin upgrade - modin.pandas.groupby.SeriesGroupBy functions
106106
if not callable(func):
107107
raise NotImplementedError("No support for non-callable `func`")
@@ -117,6 +117,7 @@ def apply(self, func, *args, include_groups=True, **kwargs):
117117
# TODO(https://github.com/modin-project/modin/issues/7096):
118118
# upstream the series_groupby param to Modin
119119
series_groupby=True,
120+
is_transform=_is_transform,
120121
)
121122
)
122123
if dataframe_result.columns.equals(pandas.Index([MODIN_UNNAMED_SERIES_LABEL])):

0 commit comments

Comments
 (0)