@@ -180,21 +180,25 @@ def _to_df(
180180 agg_exprs : List [Expression ],
181181 _ast_stmt : Optional [proto .Bind ] = None ,
182182 _emit_ast : bool = False ,
183+ ** kwargs ,
183184 ) -> DataFrame :
185+ exclude_grouping_columns = kwargs .get ("exclude_grouping_columns" , False )
184186 aliased_agg = []
185- for grouping_expr in self ._grouping_exprs :
186- if isinstance (grouping_expr , GroupingSetsExpression ):
187- # avoid doing list(set(grouping_expr.args)) because it will change the order
188- gr_used = set ()
189- gr_uniq = [
190- a
191- for arg in grouping_expr .args
192- for a in arg
193- if a not in gr_used and (gr_used .add (a ) or True )
194- ]
195- aliased_agg .extend (gr_uniq )
196- else :
197- aliased_agg .append (grouping_expr )
187+
188+ if not exclude_grouping_columns :
189+ for grouping_expr in self ._grouping_exprs :
190+ if isinstance (grouping_expr , GroupingSetsExpression ):
191+ # avoid doing list(set(grouping_expr.args)) because it will change the order
192+ gr_used = set ()
193+ gr_uniq = [
194+ a
195+ for arg in grouping_expr .args
196+ for a in arg
197+ if a not in gr_used and (gr_used .add (a ) or True )
198+ ]
199+ aliased_agg .extend (gr_uniq )
200+ else :
201+ aliased_agg .append (grouping_expr )
198202
199203 aliased_agg .extend (agg_exprs )
200204
@@ -263,6 +267,7 @@ def agg(
263267 * exprs : Union [Column , Tuple [ColumnOrName , str ], Dict [str , str ]],
264268 _ast_stmt : Optional [proto .Bind ] = None ,
265269 _emit_ast : bool = True ,
270+ ** kwargs ,
266271 ) -> DataFrame :
267272 """Returns a :class:`DataFrame` with computed aggregates. See examples in :meth:`DataFrame.group_by`.
268273
@@ -283,6 +288,7 @@ def agg(
283288 - :meth:`DataFrame.agg`
284289 - :meth:`DataFrame.group_by`
285290 """
291+ exclude_grouping_columns = kwargs .get ("exclude_grouping_columns" , False )
286292
287293 exprs , is_variadic = parse_positional_args_to_list_variadic (* exprs )
288294
@@ -323,7 +329,11 @@ def agg(
323329 )
324330 agg_exprs .append (_str_to_expr (e [1 ], _emit_ast )(col_expr ))
325331
326- df = self ._to_df (agg_exprs , _emit_ast = False )
332+ df = self ._to_df (
333+ agg_exprs ,
334+ exclude_grouping_columns = exclude_grouping_columns ,
335+ _emit_ast = False ,
336+ )
327337 df ._ops_after_agg = set ()
328338
329339 if _emit_ast :
@@ -649,40 +659,93 @@ def pivot(
649659
@relational_group_df_api_usage
@publicapi
def avg(self, *cols: ColumnOrName, _emit_ast: bool = True, **kwargs) -> DataFrame:
    """Return the average for the specified numeric columns.

    Args:
        cols: The columns to calculate average for.
        **kwargs: Only ``exclude_grouping_columns`` is read (default
            ``False``); it is forwarded unchanged to
            ``_non_empty_argument_function``.
    """
    # The shared helper raises ValueError when no columns are supplied.
    return self._non_empty_argument_function(
        "avg",
        *cols,
        exclude_grouping_columns=kwargs.get("exclude_grouping_columns", False),
        _emit_ast=_emit_ast,
    )
655675
656676 mean = avg
657677
@relational_group_df_api_usage
@publicapi
def sum(self, *cols: ColumnOrName, _emit_ast: bool = True, **kwargs) -> DataFrame:
    """Return the sum for the specified numeric columns.

    Args:
        cols: The columns to calculate sum for.
        **kwargs: Only ``exclude_grouping_columns`` is read (default
            ``False``); it is forwarded unchanged to
            ``_non_empty_argument_function``.
    """
    skip_grouping = kwargs.get("exclude_grouping_columns", False)
    # The shared helper raises ValueError when no columns are supplied.
    result = self._non_empty_argument_function(
        "sum",
        *cols,
        exclude_grouping_columns=skip_grouping,
        _emit_ast=_emit_ast,
    )
    return result
663693
@relational_group_df_api_usage
@publicapi
def median(self, *cols: ColumnOrName, _emit_ast: bool = True, **kwargs) -> DataFrame:
    """Return the median for the specified numeric columns.

    Args:
        cols: The columns to calculate median for.
        **kwargs: Only ``exclude_grouping_columns`` is read (default
            ``False``); it is forwarded unchanged to
            ``_non_empty_argument_function``.
    """
    # The shared helper raises ValueError when no columns are supplied.
    return self._non_empty_argument_function(
        "median",
        *cols,
        exclude_grouping_columns=kwargs.get("exclude_grouping_columns", False),
        _emit_ast=_emit_ast,
    )
669711
@relational_group_df_api_usage
@publicapi
def min(self, *cols: ColumnOrName, _emit_ast: bool = True, **kwargs) -> DataFrame:
    """Return the min for the specified numeric columns.

    Args:
        cols: The columns to calculate min for.
        **kwargs: Only ``exclude_grouping_columns`` is read (default
            ``False``); it is forwarded unchanged to
            ``_non_empty_argument_function``.
    """
    skip_grouping = kwargs.get("exclude_grouping_columns", False)
    # The shared helper raises ValueError when no columns are supplied.
    result = self._non_empty_argument_function(
        "min",
        *cols,
        exclude_grouping_columns=skip_grouping,
        _emit_ast=_emit_ast,
    )
    return result
675727
@relational_group_df_api_usage
@publicapi
def max(self, *cols: ColumnOrName, _emit_ast: bool = True, **kwargs) -> DataFrame:
    """Return the max for the specified numeric columns.

    Args:
        cols: The columns to calculate max for.
        **kwargs: Only ``exclude_grouping_columns`` is read (default
            ``False``); it is forwarded unchanged to
            ``_non_empty_argument_function``.
    """
    # The shared helper raises ValueError when no columns are supplied.
    return self._non_empty_argument_function(
        "max",
        *cols,
        exclude_grouping_columns=kwargs.get("exclude_grouping_columns", False),
        _emit_ast=_emit_ast,
    )
681743
682744 @relational_group_df_api_usage
683745 @publicapi
684- def count (self , _emit_ast : bool = True ) -> DataFrame :
746+ def count (self , _emit_ast : bool = True , ** kwargs ) -> DataFrame :
685747 """Return the number of rows for each group."""
748+ exclude_grouping_columns = kwargs .get ("exclude_grouping_columns" , False )
686749 df = self ._to_df (
687750 [
688751 Alias (
@@ -692,6 +755,7 @@ def count(self, _emit_ast: bool = True) -> DataFrame:
692755 "count" ,
693756 )
694757 ],
758+ exclude_grouping_columns = exclude_grouping_columns ,
695759 _emit_ast = False ,
696760 )
697761 df ._ops_after_agg = set ()
@@ -709,27 +773,38 @@ def count(self, _emit_ast: bool = True) -> DataFrame:
709773 return df
710774
@publicapi
def function(self, agg_name: str, _emit_ast: bool = True, **kwargs) -> Callable:
    """Computes the builtin aggregate ``agg_name`` over the specified columns. Use
    this function to invoke any aggregates not explicitly listed in this class.
    See examples in :meth:`DataFrame.group_by`.

    Args:
        agg_name: The name of the aggregate function.
        **kwargs: Only ``exclude_grouping_columns`` is read (default
            ``False``); the value is captured once here and handed to
            ``_function`` every time the returned callable is invoked.

    Returns:
        A callable taking the columns to aggregate; calling it delegates to
        ``_function`` with ``agg_name`` and the captured options.
    """
    skip_grouping = kwargs.get("exclude_grouping_columns", False)
    # Kept as a lambda so the returned callable is unchanged for callers.
    return lambda *cols: self._function(
        agg_name,
        *cols,
        exclude_grouping_columns=skip_grouping,
        _emit_ast=_emit_ast,
    )
718791
719792 builtin = function
720793
721794 @publicapi
722795 def _function (
723- self , agg_name : str , * cols : ColumnOrName , _emit_ast : bool = True
796+ self , agg_name : str , * cols : ColumnOrName , _emit_ast : bool = True , ** kwargs
724797 ) -> DataFrame :
798+ exclude_grouping_columns = kwargs .get ("exclude_grouping_columns" , False )
725799 agg_exprs = []
726800 for c in cols :
727801 c_expr = Column (c )._expression if isinstance (c , str ) else c ._expression
728802 expr = functions ._call_function (
729803 agg_name , c_expr , _emit_ast = False
730804 )._expression
731805 agg_exprs .append (expr )
732- df = self ._to_df (agg_exprs )
806+
807+ df = self ._to_df (agg_exprs , exclude_grouping_columns = exclude_grouping_columns )
733808 df ._ops_after_agg = set ()
734809
735810 if _emit_ast :
@@ -750,14 +825,19 @@ def _function(
750825
@publicapi
def _non_empty_argument_function(
    self, func_name: str, *cols: ColumnOrName, _emit_ast: bool = True, **kwargs
) -> DataFrame:
    """Apply the builtin aggregate ``func_name`` to ``cols``, requiring at least one column.

    Args:
        func_name: Name of the aggregate, passed to ``self.builtin``.
        cols: The columns to aggregate; must not be empty.
        **kwargs: Only ``exclude_grouping_columns`` is read (default
            ``False``); it is forwarded to ``self.builtin``.

    Raises:
        ValueError: If no columns are given.
    """
    skip_grouping = kwargs.get("exclude_grouping_columns", False)

    # Guard clause: an aggregate with no columns is rejected up front.
    if not cols:
        raise ValueError(
            f"You must pass a list of one or more Columns to function: {func_name}"
        )

    aggregate = self.builtin(
        func_name,
        exclude_grouping_columns=skip_grouping,
        _emit_ast=_emit_ast,
    )
    return aggregate(*cols)
761841
762842 def _set_ast_ref (self , expr_builder : proto .Expr ) -> None :
763843 """
0 commit comments