Skip to content

Commit 35a5c1e

Browse files
series groupby
1 parent 4ffe411 commit 35a5c1e

File tree

2 files changed

+291
-0
lines changed

2 files changed

+291
-0
lines changed

src/snowflake/snowpark/modin/plugin/extensions/series_overrides.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,12 @@
5959
)
6060
from snowflake.snowpark.modin.plugin.compiler.snowflake_query_compiler import (
6161
HYBRID_SWITCH_FOR_UNIMPLEMENTED_METHODS,
62+
UnsupportedArgsRule,
63+
_GROUPBY_UNSUPPORTED_GROUPING_MESSAGE,
64+
register_query_compiler_method_not_implemented,
65+
)
66+
from snowflake.snowpark.modin.plugin._internal.groupby_utils import (
67+
check_is_groupby_supported_by_snowflake,
6268
)
6369
from snowflake.snowpark.modin.plugin._typing import DropKeep, ListLike
6470
from snowflake.snowpark.modin.plugin.extensions.snow_partition_iterator import (
@@ -1549,6 +1555,22 @@ def fillna(
15491555

15501556
# Snowpark pandas defines a custom GroupBy object
15511557
@register_series_accessor("groupby")
1558+
@register_query_compiler_method_not_implemented(
1559+
"Series",
1560+
"groupby",
1561+
UnsupportedArgsRule(
1562+
unsupported_conditions=[
1563+
(
1564+
lambda args: not check_is_groupby_supported_by_snowflake(
1565+
args.get("by"),
1566+
args.get("level"),
1567+
args.get("axis", 0),
1568+
),
1569+
f"Groupby {_GROUPBY_UNSUPPORTED_GROUPING_MESSAGE}",
1570+
)
1571+
]
1572+
),
1573+
)
15521574
def groupby(
15531575
self,
15541576
by=None,

tests/integ/modin/hybrid/test_switch_operations.py

Lines changed: 269 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1031,6 +1031,158 @@ def test_auto_switch_unsupported_series(method, kwargs):
10311031
)
10321032

10331033

1034+
@pytest.mark.parametrize(
    "groupby_kwargs",
    [
        {"level": 0},
        {"by": pd.Grouper()},
    ],
)
def test_auto_switch_supported_series_groupby(groupby_kwargs):
    """Supported Series.groupby arguments keep the result on the Snowflake backend.

    For each parametrized set of groupby keyword arguments, the stay cost is
    expected to be COST_ZERO and the resulting groupby object is expected to
    report "Snowflake" as its backend, all without issuing any query.
    """
    values = [1, 2, 3, 4, 5, 6]

    with SqlCounter(query_count=0):
        snow_series = pd.Series(values).move_to("Snowflake")
        assert snow_series.get_backend() == "Snowflake"

        _test_stay_cost(
            data_obj=snow_series,
            api_cls_name="Series",
            method_name="groupby",
            args=groupby_kwargs,
            expected_cost=QCCoercionCost.COST_ZERO,
        )

        grouped = snow_series.groupby(**groupby_kwargs)
        assert grouped.get_backend() == "Snowflake"
1059+
1060+
1061+
@pytest.mark.parametrize(
    "groupby_kwargs",
    [
        {"level": 0, "axis": 1},
        {"by": [1, 1, 2, 2, 3, 3], "level": 0},
        {"by": lambda x: x % 2},
        {"by": np.array([1, 2, 1, 2, 1, 2])},
        {"by": pd.Grouper(axis=1)},
    ],
)
def test_auto_switch_unsupported_series_groupby(groupby_kwargs):
    """Unsupported Series.groupby arguments move the result to the Pandas backend.

    Both the stay cost on the Snowflake series and the move-to-me cost for a
    second series' query compiler are expected to be COST_IMPOSSIBLE, and the
    groupby object produced from the Snowflake series is expected to report
    "Pandas" as its backend. Exactly one query is expected overall.
    """
    values = [1, 2, 3, 4, 5, 6]

    with SqlCounter(query_count=1):
        snow_series = pd.Series(values).move_to("Snowflake")
        assert snow_series.get_backend() == "Snowflake"

        # A plain list given as `by` is wrapped in a `pd.Series`; every other
        # kind of `by` (callable, ndarray, Grouper, absent) is passed through.
        by_value = groupby_kwargs.get("by")
        if isinstance(by_value, list):
            call_kwargs = {**groupby_kwargs, "by": pd.Series(by_value)}
        else:
            call_kwargs = dict(groupby_kwargs)

        _test_stay_cost(
            data_obj=snow_series,
            api_cls_name="Series",
            method_name="groupby",
            args=call_kwargs,
            expected_cost=QCCoercionCost.COST_IMPOSSIBLE,
        )

        other_series = pd.Series(values)
        _test_move_to_me_cost(
            pandas_qc=other_series._query_compiler,
            api_cls_name="Series",
            method_name="groupby",
            args=call_kwargs,
            expected_cost=QCCoercionCost.COST_IMPOSSIBLE,
        )

        grouped = snow_series.groupby(**call_kwargs)
        assert grouped.get_backend() == "Pandas"
1103+
1104+
1105+
@pytest.mark.parametrize(
    "method,method_kwargs, groupby_kwargs, query_count",
    [
        ("agg", {"func": "sum"}, {"by": [1, 1, 2, 2, 3, 3], "level": 0}, 1),
        ("agg", {"func": "sum"}, {"by": lambda x: x % 2}, 1),
        (
            "apply",
            {"func": lambda x: x.sum()},
            {"by": [1, 1, 2, 2, 3, 3], "level": 0},
            1,
        ),
        ("apply", {"func": lambda x: x.sum()}, {"by": lambda x: x % 2}, 1),
        ("size", {}, {"by": [1, 1, 2, 2, 3, 3], "level": 0}, 1),
        ("size", {}, {"by": lambda x: x % 2}, 1),
        ("value_counts", {}, {"by": [1, 1, 2, 2, 3, 3], "level": 0}, 1),
        ("unique", {}, {"by": [1, 1, 2, 2, 3, 3], "level": 0}, 1),
        ("unique", {}, {"by": lambda x: x % 2}, 1),
        ("cummin", {}, {"by": lambda x: x % 2}, 1),
        ("cummin", {}, {"by": [1, 1, 2, 2, 3, 3], "level": 0}, 1),
        ("cumsum", {}, {"by": lambda x: x % 2}, 1),
        ("cumsum", {}, {"by": [1, 1, 2, 2, 3, 3], "level": 0}, 1),
        ("cummax", {}, {"by": [1, 1, 2, 2, 3, 3], "level": 0}, 1),
        ("cummax", {}, {"by": lambda x: x % 2}, 1),
        ("cumcount", {}, {"by": [1, 1, 2, 2, 3, 3], "level": 0}, 1),
        ("cumcount", {}, {"by": lambda x: x % 2}, 1),
        ("rank", {}, {"by": [1, 1, 2, 2, 3, 3], "level": 0}, 1),
        ("rank", {}, {"by": lambda x: x % 2}, 1),
        ("shift", {}, {"by": [1, 1, 2, 2, 3, 3], "level": 0}, 1),
        ("shift", {}, {"by": lambda x: x % 2}, 1),
    ],
)
def test_auto_switch_unsupported_series_groupby_with_supported_method(
    method, method_kwargs, groupby_kwargs, query_count
):
    """Supported methods on an unsupported Series grouping run on the Pandas backend.

    The grouping arguments are expected to make staying on Snowflake (and
    moving to it) COST_IMPOSSIBLE, so the groupby object should be on the
    "Pandas" backend; the method invoked on it is then expected to run on
    "Pandas" as well, and its result is compared against native pandas.
    """
    with SqlCounter(query_count=query_count):
        values = [1, 2, 3, 4, 5, 6]

        snow_series = pd.Series(values).move_to("Snowflake")
        assert snow_series.get_backend() == "Snowflake"

        # A plain list given as `by` is wrapped in a `pd.Series`; any other
        # `by` value is forwarded unchanged.
        by_value = groupby_kwargs.get("by")
        if isinstance(by_value, list):
            call_kwargs = {**groupby_kwargs, "by": pd.Series(by_value)}
        else:
            call_kwargs = dict(groupby_kwargs)

        _test_stay_cost(
            data_obj=snow_series,
            api_cls_name="Series",
            method_name="groupby",
            args=call_kwargs,
            expected_cost=QCCoercionCost.COST_IMPOSSIBLE,
        )

        other_series = pd.Series(values)
        _test_move_to_me_cost(
            pandas_qc=other_series._query_compiler,
            api_cls_name="Series",
            method_name="groupby",
            args=call_kwargs,
            expected_cost=QCCoercionCost.COST_IMPOSSIBLE,
        )

        grouped = snow_series.groupby(**call_kwargs)
        assert grouped.get_backend() == "Pandas"

        _test_expected_backend(
            data_obj=grouped,
            method_name=method,
            args=method_kwargs,
            expected_backend="Pandas",
            is_top_level=False,
        )

        # The native reference uses the original (unconverted) kwargs.
        native_grouped = native_pd.Series(values).groupby(**groupby_kwargs)
        eval_snowpark_pandas_result(
            grouped,
            native_grouped,
            lambda obj: getattr(obj, method)(**method_kwargs),
        )
1184+
1185+
10341186
@pytest.mark.parametrize(
10351187
"method,kwargs,expected_reason",
10361188
[
@@ -1221,3 +1373,120 @@ def test_method_none_condition(self):
12211373
)
12221374
def test_method_callable_non_string_reason(self):
12231375
pass
1376+
1377+
1378+
@pytest.mark.parametrize(
    "method,kwargs,expected_reason",
    [
        (
            "fillna",
            {"value": 0, "downcast": "infer"},
            "Snowpark pandas fillna does not yet support the parameter combination because 'downcast' argument is not supported yet in Snowpark pandas",
        ),
        (
            "first",
            {"min_count": 2},
            "does not yet support min_count",
        ),
        (
            "last",
            {"min_count": 2},
            "does not yet support min_count",
        ),
        (
            "shift",
            {"freq": "D"},
            "Snowpark pandas shift does not yet support the parameter combination because 'freq' argument is not supported yet in Snowpark pandas.",
        ),
    ],
)
@sql_count_checker(query_count=0)
def test_error_handling_unsupported_dataframe_groupby_method_when_auto_switch_disabled(
    method, kwargs, expected_reason
):
    """Unsupported DataFrameGroupBy method arguments raise when auto-switch is off.

    With ``AutoSwitchBackend=False``, calling ``method`` with ``kwargs`` on
    ``df.groupby("A")`` must raise ``NotImplementedError`` whose message
    contains ``expected_reason`` (matched literally via ``re.escape``).
    """
    with config_context(AutoSwitchBackend=False):
        df = pd.DataFrame({"A": [1, 2, 3, 1, 2], "B": [4, 5, 6, 7, 8]}).move_to(
            "Snowflake"
        )

        # Build the groupby outside ``pytest.raises`` so that only the method
        # call under test can satisfy the expectation; an error raised while
        # constructing the groupby now fails the test instead of passing it.
        groupby_obj = df.groupby("A")

        with pytest.raises(
            NotImplementedError,
            match=re.escape(expected_reason),
        ):
            getattr(groupby_obj, method)(**kwargs)
1419+
1420+
1421+
@pytest.mark.parametrize(
    "method, method_kwargs",
    [
        ("fillna", {"value": 0}),
        ("first", {}),
        ("last", {}),
        ("shift", {}),
        ("apply", {"func": lambda x: x.sum()}),
        ("size", {}),
        ("get_group", {"name": 1}),
        ("nunique", {}),
        ("any", {}),
        ("all", {}),
        ("cummin", {}),
        ("cumsum", {}),
        ("cummax", {}),
        ("cumcount", {}),
        ("rank", {}),
        ("value_counts", {}),
        ("pct_change", {}),
    ],
)
@sql_count_checker(query_count=0)
def test_error_handling_unsupported_dataframe_groupby_with_supported_method_when_auto_switch_disabled(
    method, method_kwargs
):
    """Unsupported DataFrame groupby arguments raise when auto-switch is off.

    With ``AutoSwitchBackend=False``, grouping by ``"A"`` together with
    ``level=0`` and then calling ``method`` is expected to raise
    ``NotImplementedError`` carrying the unsupported-grouping message.
    """
    unsupported_grouping_message = "does not yet support axis == 1, by != None and level != None, or by containing any non-pandas hashable labels."

    with config_context(AutoSwitchBackend=False):
        snow_df = pd.DataFrame({"A": [1, 2, 3], "B": [4, 5, 6]}).move_to("Snowflake")

        # Both the groupby construction and the method call stay inside the
        # raises block: either one may be the statement that raises.
        with pytest.raises(
            NotImplementedError,
            match=re.escape(unsupported_grouping_message),
        ):
            grouped = snow_df.groupby("A", level=0)
            bound_method = getattr(grouped, method)
            bound_method(**method_kwargs)
1459+
1460+
1461+
@pytest.mark.parametrize(
    "method, method_kwargs",
    [
        ("apply", {"func": lambda x: x.sum()}),
        ("size", {}),
        ("value_counts", {}),
        ("agg", {"func": "sum"}),
        ("cummin", {}),
        ("cumsum", {}),
        ("cummax", {}),
        ("cumcount", {}),
        ("rank", {}),
        ("shift", {}),
        ("unique", {}),
    ],
)
@sql_count_checker(query_count=0)
def test_error_handling_unsupported_series_groupby_with_supported_method_when_auto_switch_disabled(
    method, method_kwargs
):
    """Unsupported Series groupby arguments raise when auto-switch is off.

    With ``AutoSwitchBackend=False``, grouping a Snowflake-backed series by
    another series together with ``level=0`` and then calling ``method`` is
    expected to raise ``NotImplementedError`` carrying the
    unsupported-grouping message.
    """
    unsupported_grouping_message = "does not yet support axis == 1, by != None and level != None, or by containing any non-pandas hashable labels."

    with config_context(AutoSwitchBackend=False):
        snow_series = pd.Series([1, 2, 3, 4, 5, 6]).move_to("Snowflake")

        # Both the groupby construction and the method call stay inside the
        # raises block: either one may be the statement that raises.
        with pytest.raises(
            NotImplementedError,
            match=re.escape(unsupported_grouping_message),
        ):
            grouped = snow_series.groupby(by=pd.Series([1, 1, 2, 2, 3, 3]), level=0)
            bound_method = getattr(grouped, method)
            bound_method(**method_kwargs)

0 commit comments

Comments
 (0)