Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@
- Fixed a bug that `DataFrameReader.dbapi` (PuPr) is not compatible with oracledb 3.4.0.
- Fixed a bug where `modin` would unintentionally be imported during session initialization in some scenarios.
- Fixed a bug where `session.udf|udtf|udaf|sproc.register` failed when an extra session argument was passed. These methods do not expect a session argument; please remove it if provided.
- Fixed a bug in `DataFrameGroupBy.agg` when `func` is a list of tuples used to set the names of the output columns.

#### Improvements

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -898,8 +898,10 @@ def _is_supported_snowflake_agg_func(
"""
if isinstance(agg_func, tuple) and len(agg_func) == 2:
# For named aggregations, like `df.agg(new_col=("old_col", "sum"))`,
# take the second part of the named aggregation.
agg_func = agg_func[0]
# take the aggregation part of the named aggregation.
agg_func = (
agg_func.func if isinstance(agg_func, AggFuncWithLabel) else agg_func[1]
)

if get_snowflake_agg_func(agg_func, agg_kwargs, axis, _is_df_agg) is None:
return AggregationSupportResult(
Expand Down Expand Up @@ -1381,10 +1383,15 @@ def get_agg_func_to_col_map(
def get_pandas_aggr_func_name(aggfunc: AggFuncTypeBase) -> str:
"""
Returns the friendly name for the aggr function. For example, if it is a callable, it will return __name__
otherwise the same string name value.
otherwise the same string name value. If aggfunc is a tuple, treat it as a named aggregation and return
its first element (the output label) as a string.
"""
return (
getattr(aggfunc, "__name__", str(aggfunc))
getattr(
aggfunc,
"__name__",
str(aggfunc[0]) if isinstance(aggfunc, tuple) else str(aggfunc),
)
if not isinstance(aggfunc, str)
else aggfunc
)
Expand Down Expand Up @@ -1536,7 +1543,12 @@ def generate_column_agg_info(
for func_info, label, identifier in zip(
agg_func_list, agg_col_labels, agg_col_identifiers
):
func = func_info.func
# If func_info.func is a tuple, treat it as a named aggregation and use its second element (the aggregation function)
func = (
func_info.func[1]
if isinstance(func_info.func, tuple)
else func_info.func
)
is_dummy_agg = func_info.is_dummy_agg
agg_func_col = pandas_lit(None) if is_dummy_agg else quoted_identifier
snowflake_agg_func = get_snowflake_agg_func(func, agg_kwargs, axis=0)
Expand Down
20 changes: 20 additions & 0 deletions tests/integ/modin/groupby/test_groupby_basic_agg.py
Original file line number Diff line number Diff line change
Expand Up @@ -1321,3 +1321,23 @@ def test_groupby_timedelta_var(self):
),
lambda df: df.groupby("A").var(),
)

@sql_count_checker(query_count=1)
def test_groupby_agg_named_aggregation_implicit(self):
    """Group by ``team`` and aggregate ``score`` with a list of ``(label, func)`` tuples.

    The tuple ``("total_score", "sum")`` is the implicit form of a named
    aggregation. The operation is handed to ``eval_snowpark_pandas_result``
    together with the frames built by ``create_test_dfs``.
    """
    # NOTE(review): presumably a Snowpark pandas frame plus its native pandas
    # twin -- confirm against create_test_dfs.
    frames = create_test_dfs({"team": ["A", "B"], "score": [10, 15]})

    def aggregate(df):
        return df.groupby("team").agg({"score": [("total_score", "sum")]})

    eval_snowpark_pandas_result(*frames, aggregate)

@sql_count_checker(query_count=1)
def test_groupby_agg_named_aggregation_explicit(self):
    """Group by ``team`` and aggregate via an explicit ``pd.NamedAgg`` keyword argument.

    ``total_score=NamedAgg(column="score", aggfunc="sum")`` is the explicit
    form of a named aggregation. The operation is handed to
    ``eval_snowpark_pandas_result`` together with the frames built by
    ``create_test_dfs``.
    """
    named_agg = pd.NamedAgg(column="score", aggfunc="sum")
    # NOTE(review): presumably a Snowpark pandas frame plus its native pandas
    # twin -- confirm against create_test_dfs.
    frames = create_test_dfs({"team": ["A", "B"], "score": [10, 15]})

    def aggregate(df):
        return df.groupby("team").agg(total_score=named_agg)

    eval_snowpark_pandas_result(*frames, aggregate)