Skip to content

Commit 9d6bb37

Browse files
SNOW-2362373: Fix bug with groupby.agg named aggregation path (#3930)
1 parent 28d30b6 commit 9d6bb37

File tree

3 files changed

+38
-5
lines changed

3 files changed

+38
-5
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@
6767
- Fixed a bug that `DataFrameReader.dbapi` (PuPr) is not compatible with oracledb 3.4.0.
6868
- Fixed a bug where `modin` would unintentionally be imported during session initialization in some scenarios.
6969
- Fixed a bug where `session.udf|udtf|udaf|sproc.register` failed when an extra session argument was passed. These methods do not expect a session argument; please remove it if provided.
70+
- Fixed a bug in `DataFrameGroupBuy.agg` where func is a list of tuples used to set the names of the output columns.
7071

7172
#### Improvements
7273

src/snowflake/snowpark/modin/plugin/_internal/aggregation_utils.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -898,8 +898,10 @@ def _is_supported_snowflake_agg_func(
898898
"""
899899
if isinstance(agg_func, tuple) and len(agg_func) == 2:
900900
# For named aggregations, like `df.agg(new_col=("old_col", "sum"))`,
901-
# take the second part of the named aggregation.
902-
agg_func = agg_func[0]
901+
# take the aggregation part of the named aggregation.
902+
agg_func = (
903+
agg_func.func if isinstance(agg_func, AggFuncWithLabel) else agg_func[1]
904+
)
903905

904906
if get_snowflake_agg_func(agg_func, agg_kwargs, axis, _is_df_agg) is None:
905907
return AggregationSupportResult(
@@ -1381,10 +1383,15 @@ def get_agg_func_to_col_map(
13811383
def get_pandas_aggr_func_name(aggfunc: AggFuncTypeBase) -> str:
13821384
"""
13831385
Returns the friendly name for the aggr function. For example, if it is a callable, it will return __name__
1384-
otherwise the same string name value.
1386+
otherwise the same string name value. If aggfunc is a tuple, treat as named aggregation and return
1387+
the first part of the name.
13851388
"""
13861389
return (
1387-
getattr(aggfunc, "__name__", str(aggfunc))
1390+
getattr(
1391+
aggfunc,
1392+
"__name__",
1393+
str(aggfunc[0]) if isinstance(aggfunc, tuple) else str(aggfunc),
1394+
)
13881395
if not isinstance(aggfunc, str)
13891396
else aggfunc
13901397
)
@@ -1536,7 +1543,12 @@ def generate_column_agg_info(
15361543
for func_info, label, identifier in zip(
15371544
agg_func_list, agg_col_labels, agg_col_identifiers
15381545
):
1539-
func = func_info.func
1546+
# If func_info.func is a tuple, treat as named aggregation and return the aggregate function
1547+
func = (
1548+
func_info.func[1]
1549+
if isinstance(func_info.func, tuple)
1550+
else func_info.func
1551+
)
15401552
is_dummy_agg = func_info.is_dummy_agg
15411553
agg_func_col = pandas_lit(None) if is_dummy_agg else quoted_identifier
15421554
snowflake_agg_func = get_snowflake_agg_func(func, agg_kwargs, axis=0)

tests/integ/modin/groupby/test_groupby_basic_agg.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1321,3 +1321,23 @@ def test_groupby_timedelta_var(self):
13211321
),
13221322
lambda df: df.groupby("A").var(),
13231323
)
1324+
1325+
@sql_count_checker(query_count=1)
1326+
def test_groupby_agg_named_aggregation_implicit(self):
1327+
eval_snowpark_pandas_result(
1328+
*create_test_dfs(
1329+
{
1330+
"team": ["A", "B"],
1331+
"score": [10, 15],
1332+
}
1333+
),
1334+
lambda df: df.groupby("team").agg({"score": [("total_score", "sum")]}),
1335+
)
1336+
1337+
@sql_count_checker(query_count=1)
1338+
def test_groupby_agg_named_aggregation_explicit(self):
1339+
agg1 = pd.NamedAgg(column="score", aggfunc="sum")
1340+
eval_snowpark_pandas_result(
1341+
*create_test_dfs({"team": ["A", "B"], "score": [10, 15]}),
1342+
lambda df: df.groupby("team").agg(total_score=agg1),
1343+
)

0 commit comments

Comments
 (0)