Skip to content

Commit 8c47bb7

Browse files
authored
SNOW-2205628: Exclude grouping columns in aggregation (#3585)
1 parent ca7e49a commit 8c47bb7

File tree

2 files changed

+256
-36
lines changed

2 files changed

+256
-36
lines changed

src/snowflake/snowpark/relational_grouped_dataframe.py

Lines changed: 116 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -180,21 +180,25 @@ def _to_df(
180180
agg_exprs: List[Expression],
181181
_ast_stmt: Optional[proto.Bind] = None,
182182
_emit_ast: bool = False,
183+
**kwargs,
183184
) -> DataFrame:
185+
exclude_grouping_columns = kwargs.get("exclude_grouping_columns", False)
184186
aliased_agg = []
185-
for grouping_expr in self._grouping_exprs:
186-
if isinstance(grouping_expr, GroupingSetsExpression):
187-
# avoid doing list(set(grouping_expr.args)) because it will change the order
188-
gr_used = set()
189-
gr_uniq = [
190-
a
191-
for arg in grouping_expr.args
192-
for a in arg
193-
if a not in gr_used and (gr_used.add(a) or True)
194-
]
195-
aliased_agg.extend(gr_uniq)
196-
else:
197-
aliased_agg.append(grouping_expr)
187+
188+
if not exclude_grouping_columns:
189+
for grouping_expr in self._grouping_exprs:
190+
if isinstance(grouping_expr, GroupingSetsExpression):
191+
# avoid doing list(set(grouping_expr.args)) because it will change the order
192+
gr_used = set()
193+
gr_uniq = [
194+
a
195+
for arg in grouping_expr.args
196+
for a in arg
197+
if a not in gr_used and (gr_used.add(a) or True)
198+
]
199+
aliased_agg.extend(gr_uniq)
200+
else:
201+
aliased_agg.append(grouping_expr)
198202

199203
aliased_agg.extend(agg_exprs)
200204

@@ -263,6 +267,7 @@ def agg(
263267
*exprs: Union[Column, Tuple[ColumnOrName, str], Dict[str, str]],
264268
_ast_stmt: Optional[proto.Bind] = None,
265269
_emit_ast: bool = True,
270+
**kwargs,
266271
) -> DataFrame:
267272
"""Returns a :class:`DataFrame` with computed aggregates. See examples in :meth:`DataFrame.group_by`.
268273
@@ -283,6 +288,7 @@ def agg(
283288
- :meth:`DataFrame.agg`
284289
- :meth:`DataFrame.group_by`
285290
"""
291+
exclude_grouping_columns = kwargs.get("exclude_grouping_columns", False)
286292

287293
exprs, is_variadic = parse_positional_args_to_list_variadic(*exprs)
288294

@@ -323,7 +329,11 @@ def agg(
323329
)
324330
agg_exprs.append(_str_to_expr(e[1], _emit_ast)(col_expr))
325331

326-
df = self._to_df(agg_exprs, _emit_ast=False)
332+
df = self._to_df(
333+
agg_exprs,
334+
exclude_grouping_columns=exclude_grouping_columns,
335+
_emit_ast=False,
336+
)
327337
df._ops_after_agg = set()
328338

329339
if _emit_ast:
@@ -649,40 +659,93 @@ def pivot(
649659

650660
@relational_group_df_api_usage
651661
@publicapi
652-
def avg(self, *cols: ColumnOrName, _emit_ast: bool = True) -> DataFrame:
653-
"""Return the average for the specified numeric columns."""
654-
return self._non_empty_argument_function("avg", *cols, _emit_ast=_emit_ast)
662+
def avg(self, *cols: ColumnOrName, _emit_ast: bool = True, **kwargs) -> DataFrame:
663+
"""Return the average for the specified numeric columns.
664+
665+
Args:
666+
cols: The columns to calculate average for.
667+
"""
668+
exclude_grouping_columns = kwargs.get("exclude_grouping_columns", False)
669+
return self._non_empty_argument_function(
670+
"avg",
671+
*cols,
672+
exclude_grouping_columns=exclude_grouping_columns,
673+
_emit_ast=_emit_ast,
674+
)
655675

656676
mean = avg
657677

658678
@relational_group_df_api_usage
659679
@publicapi
660-
def sum(self, *cols: ColumnOrName, _emit_ast: bool = True) -> DataFrame:
661-
"""Return the sum for the specified numeric columns."""
662-
return self._non_empty_argument_function("sum", *cols, _emit_ast=_emit_ast)
680+
def sum(self, *cols: ColumnOrName, _emit_ast: bool = True, **kwargs) -> DataFrame:
681+
"""Return the sum for the specified numeric columns.
682+
683+
Args:
684+
cols: The columns to calculate sum for.
685+
"""
686+
exclude_grouping_columns = kwargs.get("exclude_grouping_columns", False)
687+
return self._non_empty_argument_function(
688+
"sum",
689+
*cols,
690+
exclude_grouping_columns=exclude_grouping_columns,
691+
_emit_ast=_emit_ast,
692+
)
663693

664694
@relational_group_df_api_usage
665695
@publicapi
666-
def median(self, *cols: ColumnOrName, _emit_ast: bool = True) -> DataFrame:
667-
"""Return the median for the specified numeric columns."""
668-
return self._non_empty_argument_function("median", *cols, _emit_ast=_emit_ast)
696+
def median(
697+
self, *cols: ColumnOrName, _emit_ast: bool = True, **kwargs
698+
) -> DataFrame:
699+
"""Return the median for the specified numeric columns.
700+
701+
Args:
702+
cols: The columns to calculate median for.
703+
"""
704+
exclude_grouping_columns = kwargs.get("exclude_grouping_columns", False)
705+
return self._non_empty_argument_function(
706+
"median",
707+
*cols,
708+
exclude_grouping_columns=exclude_grouping_columns,
709+
_emit_ast=_emit_ast,
710+
)
669711

670712
@relational_group_df_api_usage
671713
@publicapi
672-
def min(self, *cols: ColumnOrName, _emit_ast: bool = True) -> DataFrame:
673-
"""Return the min for the specified numeric columns."""
674-
return self._non_empty_argument_function("min", *cols, _emit_ast=_emit_ast)
714+
def min(self, *cols: ColumnOrName, _emit_ast: bool = True, **kwargs) -> DataFrame:
715+
"""Return the min for the specified numeric columns.
716+
717+
Args:
718+
cols: The columns to calculate min for.
719+
"""
720+
exclude_grouping_columns = kwargs.get("exclude_grouping_columns", False)
721+
return self._non_empty_argument_function(
722+
"min",
723+
*cols,
724+
exclude_grouping_columns=exclude_grouping_columns,
725+
_emit_ast=_emit_ast,
726+
)
675727

676728
@relational_group_df_api_usage
677729
@publicapi
678-
def max(self, *cols: ColumnOrName, _emit_ast: bool = True) -> DataFrame:
679-
"""Return the max for the specified numeric columns."""
680-
return self._non_empty_argument_function("max", *cols, _emit_ast=_emit_ast)
730+
def max(self, *cols: ColumnOrName, _emit_ast: bool = True, **kwargs) -> DataFrame:
731+
"""Return the max for the specified numeric columns.
732+
733+
Args:
734+
cols: The columns to calculate max for.
735+
"""
736+
exclude_grouping_columns = kwargs.get("exclude_grouping_columns", False)
737+
return self._non_empty_argument_function(
738+
"max",
739+
*cols,
740+
exclude_grouping_columns=exclude_grouping_columns,
741+
_emit_ast=_emit_ast,
742+
)
681743

682744
@relational_group_df_api_usage
683745
@publicapi
684-
def count(self, _emit_ast: bool = True) -> DataFrame:
746+
def count(self, _emit_ast: bool = True, **kwargs) -> DataFrame:
685747
"""Return the number of rows for each group."""
748+
exclude_grouping_columns = kwargs.get("exclude_grouping_columns", False)
686749
df = self._to_df(
687750
[
688751
Alias(
@@ -692,6 +755,7 @@ def count(self, _emit_ast: bool = True) -> DataFrame:
692755
"count",
693756
)
694757
],
758+
exclude_grouping_columns=exclude_grouping_columns,
695759
_emit_ast=False,
696760
)
697761
df._ops_after_agg = set()
@@ -709,27 +773,38 @@ def count(self, _emit_ast: bool = True) -> DataFrame:
709773
return df
710774

711775
@publicapi
712-
def function(self, agg_name: str, _emit_ast: bool = True) -> Callable:
776+
def function(self, agg_name: str, _emit_ast: bool = True, **kwargs) -> Callable:
713777
"""Computes the builtin aggregate ``agg_name`` over the specified columns. Use
714778
this function to invoke any aggregates not explicitly listed in this class.
715779
See examples in :meth:`DataFrame.group_by`.
780+
781+
Args:
782+
agg_name: The name of the aggregate function.
716783
"""
717-
return lambda *cols: self._function(agg_name, *cols, _emit_ast=_emit_ast)
784+
exclude_grouping_columns = kwargs.get("exclude_grouping_columns", False)
785+
return lambda *cols: self._function(
786+
agg_name,
787+
*cols,
788+
exclude_grouping_columns=exclude_grouping_columns,
789+
_emit_ast=_emit_ast,
790+
)
718791

719792
builtin = function
720793

721794
@publicapi
722795
def _function(
723-
self, agg_name: str, *cols: ColumnOrName, _emit_ast: bool = True
796+
self, agg_name: str, *cols: ColumnOrName, _emit_ast: bool = True, **kwargs
724797
) -> DataFrame:
798+
exclude_grouping_columns = kwargs.get("exclude_grouping_columns", False)
725799
agg_exprs = []
726800
for c in cols:
727801
c_expr = Column(c)._expression if isinstance(c, str) else c._expression
728802
expr = functions._call_function(
729803
agg_name, c_expr, _emit_ast=False
730804
)._expression
731805
agg_exprs.append(expr)
732-
df = self._to_df(agg_exprs)
806+
807+
df = self._to_df(agg_exprs, exclude_grouping_columns=exclude_grouping_columns)
733808
df._ops_after_agg = set()
734809

735810
if _emit_ast:
@@ -750,14 +825,19 @@ def _function(
750825

751826
@publicapi
752827
def _non_empty_argument_function(
753-
self, func_name: str, *cols: ColumnOrName, _emit_ast: bool = True
828+
self, func_name: str, *cols: ColumnOrName, _emit_ast: bool = True, **kwargs
754829
) -> DataFrame:
830+
exclude_grouping_columns = kwargs.get("exclude_grouping_columns", False)
755831
if not cols:
756832
raise ValueError(
757833
f"You must pass a list of one or more Columns to function: {func_name}"
758834
)
759835
else:
760-
return self.builtin(func_name, _emit_ast=_emit_ast)(*cols)
836+
return self.builtin(
837+
func_name,
838+
exclude_grouping_columns=exclude_grouping_columns,
839+
_emit_ast=_emit_ast,
840+
)(*cols)
761841

762842
def _set_ast_ref(self, expr_builder: proto.Expr) -> None:
763843
"""

tests/integ/test_df_aggregate.py

Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -928,3 +928,143 @@ def test_filter_sort_limit_snowpark_connect_compatible(session):
928928

929929
finally:
930930
context._is_snowpark_connect_compatible_mode = original_value
931+
932+
933+
@pytest.mark.skipif(
934+
"config.getoption('local_testing_mode', default=False)",
935+
reason="exclude_grouping_columns is not supported",
936+
)
937+
def test_group_by_exclude_grouping_columns(session):
938+
"""Test the exclude_grouping_columns parameter for all aggregate functions."""
939+
940+
# Create test data
941+
df = session.create_dataframe(
942+
[
943+
("A", "X", 1, 100),
944+
("A", "X", 2, 200),
945+
("A", "Y", 3, 300),
946+
("B", "X", 4, 400),
947+
("B", "Y", 5, 500),
948+
("B", "Y", 6, 600),
949+
],
950+
schema=["group1", "group2", "value1", "value2"],
951+
)
952+
953+
# Test agg() with exclude_grouping_columns
954+
# Default behavior (include grouping columns)
955+
result_default = df.group_by("group1").agg(sum_("value1").alias("sum_v1")).collect()
956+
assert len(result_default[0]) == 2 # group1 + sum_v1
957+
Utils.check_answer(result_default, [Row("A", 6), Row("B", 15)])
958+
959+
# Exclude grouping columns
960+
result_exclude = (
961+
df.group_by("group1")
962+
.agg(sum_("value1").alias("sum_v1"), exclude_grouping_columns=True)
963+
.collect()
964+
)
965+
assert len(result_exclude[0]) == 1 # only sum_v1
966+
print(result_exclude)
967+
Utils.check_answer(result_exclude, [Row(6), Row(15)])
968+
969+
# Test with multiple grouping columns
970+
result_multi_default = (
971+
df.group_by("group1", "group2").agg(sum_("value1").alias("sum_v1")).collect()
972+
)
973+
assert len(result_multi_default[0]) == 3 # group1 + group2 + sum_v1
974+
975+
result_multi_exclude = (
976+
df.group_by("group1", "group2")
977+
.agg(sum_("value1").alias("sum_v1"), exclude_grouping_columns=True)
978+
.collect()
979+
)
980+
assert len(result_multi_exclude[0]) == 1 # only sum_v1
981+
# Group by produces [('A', 'X', 3), ('A', 'Y', 3), ('B', 'X', 4), ('B', 'Y', 11)]
982+
Utils.check_answer(result_multi_exclude, [Row(3), Row(3), Row(4), Row(11)])
983+
984+
# Test with multiple aggregations
985+
result_multi_agg = (
986+
df.group_by("group1")
987+
.agg(
988+
sum_("value1").alias("sum_v1"),
989+
avg("value2").alias("avg_v2"),
990+
exclude_grouping_columns=True,
991+
)
992+
.collect()
993+
)
994+
assert len(result_multi_agg[0]) == 2 # sum_v1 + avg_v2
995+
Utils.check_answer(result_multi_agg, [Row(6, 200.0), Row(15, 500.0)])
996+
997+
# Test count() with exclude_grouping_columns
998+
result_count_default = df.group_by("group1").count().collect()
999+
assert len(result_count_default[0]) == 2 # group1 + count
1000+
Utils.check_answer(result_count_default, [Row("A", 3), Row("B", 3)])
1001+
1002+
result_count_exclude = (
1003+
df.group_by("group1").count(exclude_grouping_columns=True).collect()
1004+
)
1005+
assert len(result_count_exclude[0]) == 1 # only count
1006+
Utils.check_answer(result_count_exclude, [Row(3), Row(3)])
1007+
1008+
# Test avg() with exclude_grouping_columns
1009+
result_avg_default = df.group_by("group1").avg("value1").collect()
1010+
assert len(result_avg_default[0]) == 2 # group1 + avg
1011+
1012+
result_avg_exclude = (
1013+
df.group_by("group1").avg("value1", exclude_grouping_columns=True).collect()
1014+
)
1015+
assert len(result_avg_exclude[0]) == 1 # only avg
1016+
Utils.check_answer(result_avg_exclude, [Row(2.0), Row(5.0)])
1017+
1018+
# Test sum() with exclude_grouping_columns
1019+
result_sum_default = df.group_by("group1").sum("value1", "value2").collect()
1020+
assert len(result_sum_default[0]) == 3 # group1 + sum(value1) + sum(value2)
1021+
1022+
result_sum_exclude = (
1023+
df.group_by("group1")
1024+
.sum("value1", "value2", exclude_grouping_columns=True)
1025+
.collect()
1026+
)
1027+
assert len(result_sum_exclude[0]) == 2 # only sums
1028+
Utils.check_answer(result_sum_exclude, [Row(6, 600), Row(15, 1500)])
1029+
1030+
# Test min() with exclude_grouping_columns
1031+
result_min_default = df.group_by("group1").min("value1").collect()
1032+
assert len(result_min_default[0]) == 2 # group1 + min
1033+
1034+
result_min_exclude = (
1035+
df.group_by("group1").min("value1", exclude_grouping_columns=True).collect()
1036+
)
1037+
assert len(result_min_exclude[0]) == 1 # only min
1038+
Utils.check_answer(result_min_exclude, [Row(1), Row(4)])
1039+
1040+
# Test max() with exclude_grouping_columns
1041+
result_max_default = df.group_by("group1").max("value1").collect()
1042+
assert len(result_max_default[0]) == 2 # group1 + max
1043+
1044+
result_max_exclude = (
1045+
df.group_by("group1").max("value1", exclude_grouping_columns=True).collect()
1046+
)
1047+
assert len(result_max_exclude[0]) == 1 # only max
1048+
Utils.check_answer(result_max_exclude, [Row(3), Row(6)])
1049+
1050+
# Test median() with exclude_grouping_columns
1051+
result_median_default = df.group_by("group1").median("value1").collect()
1052+
assert len(result_median_default[0]) == 2 # group1 + median
1053+
1054+
result_median_exclude = (
1055+
df.group_by("group1").median("value1", exclude_grouping_columns=True).collect()
1056+
)
1057+
assert len(result_median_exclude[0]) == 1 # only median
1058+
Utils.check_answer(result_median_exclude, [Row(2.0), Row(5.0)])
1059+
1060+
# Test function() / builtin() with exclude_grouping_columns
1061+
result_builtin_default = df.group_by("group1").builtin("sum")("value1").collect()
1062+
assert len(result_builtin_default[0]) == 2 # group1 + sum
1063+
1064+
result_builtin_exclude = (
1065+
df.group_by("group1")
1066+
.builtin("sum", exclude_grouping_columns=True)("value1")
1067+
.collect()
1068+
)
1069+
assert len(result_builtin_exclude[0]) == 1 # only sum
1070+
Utils.check_answer(result_builtin_exclude, [Row(6), Row(15)])

0 commit comments

Comments
 (0)