Skip to content

Commit 869014f

Browse files
committed
metrics
1 parent d2a78c4 commit 869014f

File tree

1 file changed

+39
-3
lines changed

1 file changed

+39
-3
lines changed

onnx_diagnostic/helpers/log_helper.py

Lines changed: 39 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,7 @@ class CubeViewDef:
152152
creating the view
153153
:param agg_args: see :meth:`pandas.core.groupby.DataFrameGroupBy.agg`
154154
:param agg_kwargs: see :meth:`pandas.core.groupby.DataFrameGroupBy.agg`
155+
:param agg_multi: aggregation over multiple columns
155156
:param ignore_columns: ignore the following columns if known to overload the view
156157
:param keep_columns_in_index: keeps the columns even if there is only one unique value
157158
:param dropna: drops rows with nan if not relevant
@@ -174,6 +175,9 @@ def __init__(
174175
key_agg: Optional[Sequence[str]] = None,
175176
agg_args: Sequence[Any] = ("sum",),
176177
agg_kwargs: Optional[Dict[str, Any]] = None,
178+
agg_multi: Optional[
179+
Dict[str, Callable[[pandas.core.groupby.DataFrameGroupBy], pandas.Series]]
180+
] = None,
177181
ignore_columns: Optional[Sequence[str]] = None,
178182
keep_columns_in_index: Optional[Sequence[str]] = None,
179183
dropna: bool = True,
@@ -188,6 +192,7 @@ def __init__(
188192
self.key_agg = key_agg
189193
self.agg_args = agg_args
190194
self.agg_kwargs = agg_kwargs
195+
self.agg_multi = agg_multi
191196
self.dropna = dropna
192197
self.ignore_columns = ignore_columns
193198
self.keep_columns_in_index = keep_columns_in_index
@@ -468,6 +473,7 @@ def view(
468473
)
469474

470475
if key_agg:
476+
final_stack = True
471477
key_index = [
472478
c
473479
for c in self._filter_column(view_def.key_index, self.keys_time)
@@ -483,14 +489,22 @@ def view(
483489
f"selected={pprint.pformat(sorted(data_red.columns))},\n--\n"
484490
f"keys={pprint.pformat(sorted(self.keys_time))}"
485491
)
486-
data = data_red.groupby(keys_no_agg, as_index=False, dropna=False).agg(
487-
*view_def.agg_args, **(view_def.agg_kwargs or {})
488-
)
492+
grouped_data = data_red.groupby(keys_no_agg, as_index=True, dropna=False)
493+
data = grouped_data.agg(*view_def.agg_args, **(view_def.agg_kwargs or {}))
494+
if view_def.agg_multi:
495+
append = []
496+
for k, f in view_def.agg_multi.items():
497+
cv = grouped_data.apply(f, include_groups=False)
498+
append.append(cv.to_frame(k))
499+
data = pandas.concat([data, *append], axis=1)
489500
set_all_keys = set(keys_no_agg)
501+
values = list(data.columns)
502+
data = data.reset_index(drop=False)
490503
else:
491504
key_index = self._filter_column(view_def.key_index, self.keys_time)
492505
data = self.data[[*self.keys_time, *values]]
493506
set_all_keys = set(self.keys_time)
507+
final_stack = False
494508

495509
assert set(key_index) <= set_all_keys, (
496510
f"view_def.name={view_def.name!r}, "
@@ -580,8 +594,17 @@ def view(
580594
piv = data.pivot(index=key_index[::-1], columns=key_columns, values=values)
581595
if isinstance(piv, pandas.Series):
582596
piv = piv.to_frame(name="series")
597+
names = list(piv.columns.names)
598+
assert (
599+
"METRICS" not in names
600+
), f"Not implemented when a level METRICS already exists {names!r}"
601+
names[0] = "METRICS"
602+
piv.columns = piv.columns.set_names(names)
603+
if final_stack:
604+
piv = piv.stack("METRICS")
583605
if view_def.transpose:
584606
piv = piv.T
607+
585608
return (piv, view_def) if return_view_def else piv
586609

587610
def _dropna(
@@ -1015,6 +1038,18 @@ def make_view_def(self, name: str) -> CubeViewDef:
10151038
)
10161039
)
10171040

1041+
def mean_weight(gr):
1042+
weight = gr["time_latency_eager"]
1043+
x = gr["speedup"]
1044+
if x.shape[0] == 0:
1045+
return np.nan
1046+
div = weight.sum()
1047+
return (x * weight).sum() / div
1048+
1049+
def mean_geo(gr):
1050+
x = gr["speedup"]
1051+
return np.exp(np.log(x.dropna()).mean())
1052+
10181053
implemented_views = {
10191054
"agg-suite": lambda: CubeViewDef(
10201055
key_index=index_cols,
@@ -1024,6 +1059,7 @@ def make_view_def(self, name: str) -> CubeViewDef:
10241059
ignore_unique=True,
10251060
key_agg=["model_name", "task", "model_task"],
10261061
agg_args=["mean"],
1062+
agg_multi={"speedup_weighted": mean_weight, "speedup_geo": mean_geo},
10271063
keep_columns_in_index=["suite"],
10281064
name="agg-suite",
10291065
),

0 commit comments

Comments
 (0)