@@ -152,6 +152,7 @@ class CubeViewDef:
152152 creating the view
153153 :param agg_args: see :meth:`pandas.core.groupby.DataFrameGroupBy.agg`
154154 :param agg_kwargs: see :meth:`pandas.core.groupby.DataFrameGroupBy.agg`
155+ :param agg_multi: extra aggregations computed over multiple columns, maps the name of the new column to a function applied to every group
155156 :param ignore_columns: ignore the following columns if known to overload the view
156157 :param keep_columns_in_index: keeps the columns even if there is only one unique value
157158 :param dropna: drops rows with nan if not relevant
@@ -174,6 +175,9 @@ def __init__(
174175 key_agg : Optional [Sequence [str ]] = None ,
175176 agg_args : Sequence [Any ] = ("sum" ,),
176177 agg_kwargs : Optional [Dict [str , Any ]] = None ,
178+ agg_multi : Optional [
179+ Dict [str , Callable [[pandas .core .groupby .DataFrameGroupBy ], pandas .Series ]]
180+ ] = None ,
177181 ignore_columns : Optional [Sequence [str ]] = None ,
178182 keep_columns_in_index : Optional [Sequence [str ]] = None ,
179183 dropna : bool = True ,
@@ -188,6 +192,7 @@ def __init__(
188192 self .key_agg = key_agg
189193 self .agg_args = agg_args
190194 self .agg_kwargs = agg_kwargs
195+ self .agg_multi = agg_multi
191196 self .dropna = dropna
192197 self .ignore_columns = ignore_columns
193198 self .keep_columns_in_index = keep_columns_in_index
@@ -468,6 +473,7 @@ def view(
468473 )
469474
470475 if key_agg :
476+ final_stack = True
471477 key_index = [
472478 c
473479 for c in self ._filter_column (view_def .key_index , self .keys_time )
@@ -483,14 +489,22 @@ def view(
483489 f"selected={ pprint .pformat (sorted (data_red .columns ))} ,\n --\n "
484490 f"keys={ pprint .pformat (sorted (self .keys_time ))} "
485491 )
486- data = data_red .groupby (keys_no_agg , as_index = False , dropna = False ).agg (
487- * view_def .agg_args , ** (view_def .agg_kwargs or {})
488- )
492+ grouped_data = data_red .groupby (keys_no_agg , as_index = True , dropna = False )
493+ data = grouped_data .agg (* view_def .agg_args , ** (view_def .agg_kwargs or {}))
494+ if view_def .agg_multi :
495+ append = []
496+ for k , f in view_def .agg_multi .items ():
497+ cv = grouped_data .apply (f , include_groups = False )
498+ append .append (cv .to_frame (k ))
499+ data = pandas .concat ([data , * append ], axis = 1 )
489500 set_all_keys = set (keys_no_agg )
501+ values = list (data .columns )
502+ data = data .reset_index (drop = False )
490503 else :
491504 key_index = self ._filter_column (view_def .key_index , self .keys_time )
492505 data = self .data [[* self .keys_time , * values ]]
493506 set_all_keys = set (self .keys_time )
507+ final_stack = False
494508
495509 assert set (key_index ) <= set_all_keys , (
496510 f"view_def.name={ view_def .name !r} , "
@@ -580,8 +594,17 @@ def view(
580594 piv = data .pivot (index = key_index [::- 1 ], columns = key_columns , values = values )
581595 if isinstance (piv , pandas .Series ):
582596 piv = piv .to_frame (name = "series" )
597+ names = list (piv .columns .names )
598+ assert (
599+ "METRICS" not in names
600+ ), f"Not implemented when a level METRICS already exists { names !r} "
601+ names [0 ] = "METRICS"
602+ piv .columns = piv .columns .set_names (names )
603+ if final_stack :
604+ piv = piv .stack ("METRICS" )
583605 if view_def .transpose :
584606 piv = piv .T
607+
585608 return (piv , view_def ) if return_view_def else piv
586609
587610 def _dropna (
@@ -1015,6 +1038,18 @@ def make_view_def(self, name: str) -> CubeViewDef:
10151038 )
10161039 )
10171040
def mean_weight(gr):
    """
    Computes the mean of column ``speedup`` weighted by column
    ``time_latency_eager`` over the rows of one group.

    :param gr: rows of one group, it must expose the columns
        ``speedup`` and ``time_latency_eager``
    :return: the weighted mean, ``nan`` if no row has both values
        or if the weights sum to zero
    """
    weight = gr["time_latency_eager"]
    x = gr["speedup"]
    # Only rows where both values are known can contribute. Otherwise the
    # sum of the products silently skips the missing speedups while the sum
    # of the weights still counts them, which biases the mean towards zero.
    valid = x.notna() & weight.notna()
    if not valid.any():
        return np.nan
    weight = weight[valid]
    div = weight.sum()
    if div == 0:
        # A null total weight leaves the weighted mean undefined.
        return np.nan
    return (x[valid] * weight).sum() / div
1048+
def mean_geo(gr):
    """
    Computes the geometric mean of column ``speedup`` over the rows
    of one group, missing values are ignored.

    :param gr: rows of one group, it must expose the column ``speedup``
    :return: exponential of the average of the logarithms
    """
    known_speedups = gr["speedup"].dropna()
    average_log = np.log(known_speedups).mean()
    return np.exp(average_log)
1052+
10181053 implemented_views = {
10191054 "agg-suite" : lambda : CubeViewDef (
10201055 key_index = index_cols ,
@@ -1024,6 +1059,7 @@ def make_view_def(self, name: str) -> CubeViewDef:
10241059 ignore_unique = True ,
10251060 key_agg = ["model_name" , "task" , "model_task" ],
10261061 agg_args = ["mean" ],
1062+ agg_multi = {"speedup_weighted" : mean_weight , "speedup_geo" : mean_geo },
10271063 keep_columns_in_index = ["suite" ],
10281064 name = "agg-suite" ,
10291065 ),
0 commit comments