Skip to content

Commit e0d75a5

Browse files
add more tests
1 parent 5922c4c commit e0d75a5

File tree

2 files changed

+12
-2
lines changed

2 files changed

+12
-2
lines changed

src/snowflake/snowpark/modin/plugin/_internal/aggregation_utils.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -899,7 +899,9 @@ def _is_supported_snowflake_agg_func(
899899
if isinstance(agg_func, tuple) and len(agg_func) == 2:
900900
# For named aggregations, like `df.agg(new_col=("old_col", "sum"))`,
901901
# take the second part of the named aggregation.
902-
agg_func = agg_func[1]
902+
agg_func = (
903+
agg_func.func if isinstance(agg_func, AggFuncWithLabel) else agg_func[1]
904+
)
903905

904906
if get_snowflake_agg_func(agg_func, agg_kwargs, axis, _is_df_agg) is None:
905907
return AggregationSupportResult(

tests/integ/modin/groupby/test_groupby_basic_agg.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1323,7 +1323,7 @@ def test_groupby_timedelta_var(self):
13231323
)
13241324

13251325
@sql_count_checker(query_count=1)
1326-
def test_groupby_agg_named_aggregation_codepath(self):
1326+
def test_groupby_agg_named_aggregation_implicit(self):
13271327
eval_snowpark_pandas_result(
13281328
*create_test_dfs(
13291329
{
@@ -1333,3 +1333,11 @@ def test_groupby_agg_named_aggregation_codepath(self):
13331333
),
13341334
lambda df: df.groupby("team").agg({"score": [("total_score", "sum")]}),
13351335
)
1336+
1337+
@sql_count_checker(query_count=1)
1338+
def test_groupby_agg_named_aggregation_explicit(self):
1339+
agg1 = pd.NamedAgg(column="score", aggfunc="sum")
1340+
eval_snowpark_pandas_result(
1341+
*create_test_dfs({"team": ["A", "B"], "score": [10, 15]}),
1342+
lambda df: df.groupby("team").agg(total_score=agg1),
1343+
)

0 commit comments

Comments
 (0)