Skip to content

Commit 8bb6603

Browse files
committed
fix sbs
1 parent b1f3aeb commit 8bb6603

File tree

3 files changed

+111
-16
lines changed

3 files changed

+111
-16
lines changed

_unittests/ut_helpers/test_log_helper.py

Lines changed: 93 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -470,7 +470,7 @@ def test_historical_cube_time_mask(self):
470470
cube = CubeLogs(df, keys=["^m_*", "exporter"], time="date").load()
471471
cube.to_excel(output, views=["time_p"], time_mask=True, verbose=1)
472472

473-
def test_cube_sbs(self):
473+
def test_cube_sbs_no_time(self):
474474
df = pandas.DataFrame(
475475
[
476476
dict(
@@ -518,10 +518,100 @@ def test_cube_sbs(self):
518518
dict(CFA=dict(exporter="E1", opt="O"), CFB=dict(exporter="E2", opt="O"))
519519
)
520520
self.assertEqual(sbs.shape, (4, 9))
521-
self.assertEqual(sbs.index.names, ["METRICS", "m_name"])
521+
self.assertEqual(sbs.index.names, ["METRICS", "m_name", "date"])
522522
self.assertEqual(sorted(sbs.columns.names), ["CONF", "exporter"])
523523
self.assertEqual(sbs_agg.shape, (2, 9))
524-
self.assertEqual(sbs_agg.index.names, ["METRICS"])
524+
self.assertEqual(sbs_agg.index.names, ["date", "METRICS"])
525+
self.assertEqual(sorted(sbs_agg.columns.names), ["CONF", "exporter"])
526+
527+
def test_cube_sbs_with_time(self):
528+
df = pandas.DataFrame(
529+
[
530+
dict(
531+
date="2025/01/01",
532+
time_p=0.51,
533+
exporter="E1",
534+
opt="O",
535+
perf=3.7,
536+
m_name="A",
537+
m_cls="CA",
538+
),
539+
dict(
540+
date="2025/01/01",
541+
time_p=0.51,
542+
perf=3.4,
543+
exporter="E2",
544+
opt="O",
545+
m_name="A",
546+
m_cls="CA",
547+
),
548+
dict(
549+
date="2025/01/01",
550+
time_p=0.71,
551+
perf=3.5,
552+
exporter="E2",
553+
opt="O",
554+
m_name="B",
555+
m_cls="CA",
556+
),
557+
dict(
558+
date="2025/01/01",
559+
time_p=0.71,
560+
perf=3.6,
561+
exporter="E2",
562+
opt="K",
563+
m_name="B",
564+
m_cls="CA",
565+
),
566+
dict(
567+
date="2025/01/02",
568+
time_p=0.51,
569+
exporter="E1",
570+
opt="O",
571+
perf=3.7,
572+
m_name="A",
573+
m_cls="CA",
574+
),
575+
dict(
576+
date="2025/01/02",
577+
time_p=0.51,
578+
perf=3.4,
579+
exporter="E2",
580+
opt="O",
581+
m_name="A",
582+
m_cls="CA",
583+
),
584+
dict(
585+
date="2025/01/02",
586+
time_p=0.71,
587+
perf=3.5,
588+
exporter="E2",
589+
opt="O",
590+
m_name="B",
591+
m_cls="CA",
592+
),
593+
dict(
594+
date="2025/01/02",
595+
time_p=0.71,
596+
perf=3.6,
597+
exporter="E2",
598+
opt="K",
599+
m_name="B",
600+
m_cls="CA",
601+
),
602+
]
603+
)
604+
cube = CubeLogs(
605+
df, keys=["^m_*", "exporter", "opt"], values=["time_p", "perf"], time="date"
606+
).load()
607+
sbs, sbs_agg = cube.sbs(
608+
dict(CFA=dict(exporter="E1", opt="O"), CFB=dict(exporter="E2", opt="O"))
609+
)
610+
self.assertEqual(sbs.shape, (8, 9))
611+
self.assertEqual(sbs.index.names, ["METRICS", "m_name", "date"])
612+
self.assertEqual(sorted(sbs.columns.names), ["CONF", "exporter"])
613+
self.assertEqual(sbs_agg.shape, (4, 9))
614+
self.assertEqual(sbs_agg.index.names, ["date", "METRICS"])
525615
self.assertEqual(sorted(sbs_agg.columns.names), ["CONF", "exporter"])
526616

527617

onnx_diagnostic/_command_lines_parser.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -803,7 +803,7 @@ def get_parser_agg() -> ArgumentParser:
803803
Defines an exporter to compare to another, there must be at least
804804
two arguments defined with --sbs. Example:
805805
--sbs dynamo:exporter=onnx-dynamo,opt=ir,attn_impl=eager
806-
--sbs cusom:exporter=custom,opt=default,attn_impl=eager
806+
--sbs custom:exporter=custom,opt=default,attn_impl=eager
807807
"""
808808
),
809809
action=_ParseNamedDict,

onnx_diagnostic/helpers/log_helper.py

Lines changed: 17 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1196,14 +1196,14 @@ def to_excel(
11961196
if verbose:
11971197
for k, v in sbs.items():
11981198
print(f"[CubeLogs.to_excel] sbs {k}: {v}")
1199+
name = "∧".join(sbs)
11991200
sbs_raw, sbs_agg = self.sbs(sbs)
12001201
if verbose:
12011202
print(f"[CubeLogs.to_excel] add sheet {name!r} with shape {sbs_raw.shape}")
12021203
print(
12031204
f"[CubeLogs.to_excel] add sheet '{name}-AGG' "
12041205
f"with shape {sbs_agg.shape}"
12051206
)
1206-
name = "∧".join(sbs)
12071207
sbs_raw = sbs_raw.reset_index(drop=False)
12081208
sbs_raw.to_excel(
12091209
writer,
@@ -1253,8 +1253,8 @@ def to_excel(
12531253

12541254
if verbose:
12551255
print(f"[CubeLogs.to_excel] applies style to {output!r}")
1256-
apply_excel_style( # type: ignore[arg-type]
1257-
writer, f_highlights, time_mask_view=time_mask_view, verbose=verbose
1256+
apply_excel_style(
1257+
writer, f_highlights, time_mask_view=time_mask_view, verbose=verbose # type: ignore[arg-type]
12581258
)
12591259
if verbose:
12601260
print(f"[CubeLogs.to_excel] done with {len(views)} views")
@@ -1346,8 +1346,14 @@ def sbs(
13461346
new_data = pandas.concat(data_list, axis=0)
13471347
cube = self.clone(new_data, keys=[*self.keys_no_time, column_name])
13481348
key_index = set(self.keys_time) - {*columns_index, column_name} # type: ignore[misc]
1349-
view = CubeViewDef(key_index=set(key_index), name="sbs", values=cube.values) # type: ignore[arg-type]
1349+
view = CubeViewDef(
1350+
key_index=set(key_index), # type: ignore[arg-type]
1351+
name="sbs",
1352+
values=cube.values,
1353+
keep_columns_in_index=[self.time],
1354+
)
13501355
view_res = cube.view(view)
1356+
assert isinstance(view_res, pandas.DataFrame), "not needed but mypy complains"
13511357

13521358
# add metrics
13531359
index_column_name = list(view_res.columns.names).index(column_name)
@@ -1420,20 +1426,19 @@ def _mkc(m, s):
14201426
columns_to_add.append(nas)
14211427
sum_columns.extend(nas.columns)
14221428

1423-
# aggregated metrics
1424-
aggs = {
1425-
**{k: "mean" for k in mean_columns}, # noqa: C420
1426-
**{k: "sum" for k in sum_columns}, # noqa: C420
1427-
}
14281429
view_res = pandas.concat([view_res, *columns_to_add], axis=1)
14291430
res = view_res.stack("METRICS", future_stack=True) # type: ignore[union-attr]
14301431
res = res.reorder_levels(
14311432
[res.index.nlevels - 1, *list(range(res.index.nlevels - 1))]
14321433
).sort_index()
14331434

1434-
view_res["GROUPBY"] = "A"
1435-
flat = view_res.groupby("GROUPBY").agg(aggs).reset_index(drop=True)
1436-
flat = flat.stack("METRICS", future_stack=True).droplevel(None, axis=0)
1435+
# aggregated metrics
1436+
aggs = {
1437+
**{k: "mean" for k in mean_columns}, # noqa: C420
1438+
**{k: "sum" for k in sum_columns}, # noqa: C420
1439+
}
1440+
flat = view_res.groupby(self.time).agg(aggs)
1441+
flat = flat.stack("METRICS", future_stack=True)
14371442
return res, flat
14381443

14391444

0 commit comments

Comments (0)