Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@
- Fixed a bug where writing Snowpark pandas dataframes on the pandas backend with a column multiindex to Snowflake with `to_snowflake` would raise `KeyError`.
- Fixed a bug where `DataFrameReader.dbapi` (PuPr) was not compatible with oracledb 3.4.0.
- Fixed a bug where `modin` would unintentionally be imported during session initialization in some scenarios.
- Fixed a bug with `DataFrameGroupBy.agg` where tuples are treated as multiindex level using the named aggregation code path.

#### Improvements

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -899,7 +899,9 @@ def _is_supported_snowflake_agg_func(
if isinstance(agg_func, tuple) and len(agg_func) == 2:
# For named aggregations, like `df.agg(new_col=("old_col", "sum"))`,
# take the second part of the named aggregation.
agg_func = agg_func[0]
agg_func = (
agg_func.func if isinstance(agg_func, AggFuncWithLabel) else agg_func[1]
)

if get_snowflake_agg_func(agg_func, agg_kwargs, axis, _is_df_agg) is None:
return AggregationSupportResult(
Expand Down Expand Up @@ -1378,13 +1380,18 @@ def get_agg_func_to_col_map(
return agg_func_to_col_map


def get_pandas_aggr_func_name(aggfunc: AggFuncTypeBase) -> str:
def get_pandas_aggr_func_name(aggfunc: AggFuncTypeBase) -> Union[str, Any, None]:
"""
Returns the friendly name for the aggr function. For example, if it is a callable, it will return __name__
otherwise the same string name value.
otherwise the same string name value. If aggfunc is a tuple, it is treated as a named aggregation and
the first element of the tuple is returned.
"""
return (
getattr(aggfunc, "__name__", str(aggfunc))
getattr(
aggfunc,
"__name__",
aggfunc[0] if isinstance(aggfunc, tuple) else str(aggfunc),
)
if not isinstance(aggfunc, str)
else aggfunc
)
Expand Down Expand Up @@ -1536,7 +1543,12 @@ def generate_column_agg_info(
for func_info, label, identifier in zip(
agg_func_list, agg_col_labels, agg_col_identifiers
):
func = func_info.func
# If func_info.func is a tuple, treat it as a named aggregation and use its second element as the aggregate function
func = (
func_info.func[1]
if isinstance(func_info.func, tuple)
else func_info.func
)
is_dummy_agg = func_info.is_dummy_agg
agg_func_col = pandas_lit(None) if is_dummy_agg else quoted_identifier
snowflake_agg_func = get_snowflake_agg_func(func, agg_kwargs, axis=0)
Expand Down
20 changes: 20 additions & 0 deletions tests/integ/modin/groupby/test_groupby_basic_agg.py
Original file line number Diff line number Diff line change
Expand Up @@ -1321,3 +1321,23 @@ def test_groupby_timedelta_var(self):
),
lambda df: df.groupby("A").var(),
)

@sql_count_checker(query_count=1)
def test_groupby_agg_named_aggregation_implicit(self):
    """Group by ``team`` and aggregate ``score`` with a ``(label, func)`` tuple inside a dict.

    The frames produced by ``create_test_dfs`` are handed to
    ``eval_snowpark_pandas_result`` together with the groupby/agg operation.
    NOTE(review): presumably the decorator asserts exactly one SQL query is
    issued -- confirm against ``sql_count_checker``.
    """
    source_data = {"team": ["A", "B"], "score": [10, 15]}
    frame_pair = create_test_dfs(source_data)
    eval_snowpark_pandas_result(
        *frame_pair,
        # The aggregation spec is built inside the lambda so each call gets a fresh dict.
        lambda df: df.groupby("team").agg({"score": [("total_score", "sum")]}),
    )

@sql_count_checker(query_count=1)
def test_groupby_agg_named_aggregation_explicit(self):
    """Group by ``team`` and aggregate via an explicit ``pd.NamedAgg`` keyword argument.

    The ``NamedAgg`` sums the ``score`` column and is passed to ``agg`` as
    ``total_score=...``; the frames from ``create_test_dfs`` and the operation
    are handed to ``eval_snowpark_pandas_result``.
    NOTE(review): presumably the decorator asserts exactly one SQL query is
    issued -- confirm against ``sql_count_checker``.
    """
    named_sum = pd.NamedAgg(column="score", aggfunc="sum")
    frame_pair = create_test_dfs({"team": ["A", "B"], "score": [10, 15]})
    eval_snowpark_pandas_result(
        *frame_pair,
        lambda df: df.groupby("team").agg(total_score=named_sum),
    )