Skip to content

Commit 00c6d07

Browse files
committed
fix aggregration
1 parent 8e86fdd commit 00c6d07

File tree

3 files changed

+159
-3
lines changed

3 files changed

+159
-3
lines changed

CHANGELOGS.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ Change Logs
44
0.8.0
55
+++++
66

7+
* :pr:`283`: fix historical aggregation when multiple input sets are used
78
* :pr:`282`: add tools to understand better which functions were patched
89
* :pr:`280`: fixes patches for sdpa_attention_forward for different version of transformers
910
* :pr:`278`: implements ``onnx_generate_with_genai``

_unittests/ut_helpers/test_log_helper.py

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -614,6 +614,95 @@ def test_cube_sbs_with_time(self):
614614
self.assertEqual(sbs_agg.index.names, ["date", "METRICS"])
615615
self.assertEqual(sorted(sbs_agg.columns.names), ["CONF", "exporter"])
616616

617+
def test_fix_non_consistent_historical_data_no_change(self):
618+
# no change
619+
df = pandas.DataFrame(
620+
[
621+
dict(date="2025/01/01", time_p=0.51, exporter="E1", model_s="O", model="M"),
622+
dict(date="2025/01/02", time_p=0.51, exporter="E1", model_s="O", model="M"),
623+
dict(date="2025/01/03", time_p=0.53, exporter="E1", model_s="O", model="M"),
624+
]
625+
)
626+
cube = CubeLogs(
627+
df, keys=["^model*", "exporter", "opt"], values=["time_p"], time="date"
628+
).load()
629+
view, _view_def = cube.view(
630+
CubeViewDef(["^model.*"], ["^time_.*"]), return_view_def=True
631+
)
632+
expected = {
633+
("time_p", pandas.Timestamp("2025-01-01 00:00:00")): {"ALL": 0.51},
634+
("time_p", pandas.Timestamp("2025-01-02 00:00:00")): {"ALL": 0.51},
635+
("time_p", pandas.Timestamp("2025-01-03 00:00:00")): {"ALL": 0.53},
636+
}
637+
self.assertEqual(expected, view.to_dict())
638+
639+
# no change
640+
df = pandas.DataFrame(
641+
[
642+
dict(date="2025/01/01", time_p=0.51, exporter="E1", model_s="O", model="M"),
643+
dict(date="2025/01/02", time_p=0.51, exporter="E1", model_s="O", model="M"),
644+
dict(date="2025/01/03", time_p=0.53, exporter="E1", model_s="O", model="M"),
645+
]
646+
)
647+
cube = CubeLogs(
648+
df, keys=["^model*", "exporter", "opt"], values=["time_p"], time="date"
649+
).load()
650+
view, _view_def = cube.view(
651+
CubeViewDef(["^model.*"], ["^time_.*"], fix_aggregation_change=["model_s"]),
652+
return_view_def=True,
653+
)
654+
self.assertEqual(expected, view.to_dict())
655+
656+
def test_fix_non_consistent_historical_data_mixed_values(self):
657+
df = pandas.DataFrame(
658+
[
659+
dict(date="2025/01/01", time_p=0.51, exporter="E1", model_s="O", model="M"),
660+
dict(date="2025/01/02", time_p=0.51, exporter="E1", model_s="O", model="M"),
661+
dict(date="2025/01/03", time_p=0.53, exporter="E1", model_s="A", model="M"),
662+
]
663+
)
664+
cube = CubeLogs(
665+
df, keys=["^model*", "exporter", "opt"], values=["time_p"], time="date"
666+
).load()
667+
view, _view_def = cube.view(
668+
CubeViewDef(["^model.*"], ["^time_.*"], fix_aggregation_change=["model_s"]),
669+
return_view_def=True,
670+
)
671+
raw = view.to_dict()
672+
self.assertEqual(
673+
{
674+
("time_p", pandas.Timestamp("2025-01-01 00:00:00")): {"A-O": 0.51},
675+
("time_p", pandas.Timestamp("2025-01-02 00:00:00")): {"A-O": 0.51},
676+
("time_p", pandas.Timestamp("2025-01-03 00:00:00")): {"A-O": 0.53},
677+
},
678+
raw,
679+
)
680+
681+
def test_fix_non_consistent_historical_data_mixed_nan(self):
682+
df = pandas.DataFrame(
683+
[
684+
dict(date="2025/01/01", time_p=0.51, exporter="E1", model_s="O", model="M"),
685+
dict(date="2025/01/02", time_p=0.51, exporter="E1", model_s="O", model="M"),
686+
dict(date="2025/01/03", time_p=0.53, exporter="E1", model="M"),
687+
]
688+
)
689+
cube = CubeLogs(
690+
df, keys=["^model*", "exporter", "opt"], values=["time_p"], time="date"
691+
).load()
692+
view, _view_def = cube.view(
693+
CubeViewDef(["^model.*"], ["^time_.*"], fix_aggregation_change=["model_s"]),
694+
return_view_def=True,
695+
)
696+
raw = view.to_dict()
697+
self.assertEqual(
698+
{
699+
("time_p", pandas.Timestamp("2025-01-01 00:00:00")): {"O": 0.51},
700+
("time_p", pandas.Timestamp("2025-01-02 00:00:00")): {"O": 0.51},
701+
("time_p", pandas.Timestamp("2025-01-03 00:00:00")): {"O": 0.53},
702+
},
703+
raw,
704+
)
705+
617706

618707
if __name__ == "__main__":
619708
unittest.main(verbosity=2)

onnx_diagnostic/helpers/log_helper.py

Lines changed: 69 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,8 @@ class CubeViewDef:
4242
:param name: name of the view, used mostly to debug
4343
:param plots: adds plot to the Excel sheet
4444
:param no_index: remove the index (but keeps the columns)
45+
:param fix_aggregation_change: a column among the keys which changes aggregation value
46+
for different dates
4547
4648
Some examples of views. First example is an aggregated view
4749
for many metrics.
@@ -106,6 +108,7 @@ def __init__(
106108
name: Optional[str] = None,
107109
no_index: bool = False,
108110
plots: bool = False,
111+
fix_aggregation_change: Optional[List["str"]] = None,
109112
):
110113
self.key_index = key_index
111114
self.values = values
@@ -123,6 +126,7 @@ def __init__(
123126
self.name = name
124127
self.no_index = no_index
125128
self.plots = plots
129+
self.fix_aggregation_change = fix_aggregation_change
126130

127131
def __repr__(self) -> str:
128132
"usual"
@@ -750,6 +754,18 @@ def view(
750754
f"values={sorted(self.values)}"
751755
)
752756

757+
if view_def.fix_aggregation_change:
758+
# before aggregation, let's fix some keys whose values changed over time
759+
assert set(view_def.fix_aggregation_change) <= set(self.keys_no_time), (
760+
f"view_def.fix_aggregation_change={view_def.fix_aggregation_change} is not "
761+
f"included in the keys {self.keys_no_time}"
762+
)
763+
data_to_process = self._fix_aggregation_change(
764+
self.data, view_def.fix_aggregation_change
765+
)
766+
else:
767+
data_to_process = self.data
768+
753769
# aggregation
754770
if key_agg:
755771
final_stack = True
@@ -763,7 +779,7 @@ def view(
763779
print(f"[CubeLogs.view] aggregation of {set_key_agg}")
764780
print(f"[CubeLogs.view] groupby {keys_no_agg}")
765781

766-
data_red = self.data[[*keys_no_agg, *values]]
782+
data_red = data_to_process[[*keys_no_agg, *values]]
767783
assert set(key_index) <= set(data_red.columns), (
768784
f"view_def.name={view_def.name!r}, "
769785
f"nnable to find {set(key_index) - set(data_red.columns)}, "
@@ -792,7 +808,7 @@ def view(
792808
key_index = self._filter_column(view_def.key_index, self.keys_time)
793809
if verbose:
794810
print(f"[CubeLogs.view] no aggregation, index={key_index}")
795-
data = self.data[[*self.keys_time, *values]]
811+
data = data_to_process[[*self.keys_time, *values]]
796812
set_all_keys = set(self.keys_time)
797813
final_stack = False
798814

@@ -829,7 +845,7 @@ def view(
829845
key_columns = sorted(set_key_columns)
830846
unique = set()
831847

832-
_md = lambda s: {k: v for k, v in self.values_for_key.items() if k in s} # noqa: E731
848+
# md = lambda s: {k: v for k, v in self.values_for_key.items() if k in s} # noqa: E731
833849
all_cols = set(key_columns) | set(key_index) | set(key_agg) | unique
834850
assert all_cols == set(self.keys_time), (
835851
f"view_def.name={view_def.name!r}, "
@@ -961,6 +977,43 @@ def view(
961977
print(f"[CubeLogs.view] -- done view {view_def.name!r}")
962978
return (piv, view_def) if return_view_def else piv
963979

980+
def _fix_aggregation_change(
981+
self, data: pandas.DataFrame, columns_to_fix: Union[str, List[str]]
982+
) -> pandas.DataFrame:
983+
"""
984+
Fixes columns used to aggregate values because their meaning changed over time.
985+
"""
986+
if not isinstance(columns_to_fix, str):
987+
for c in columns_to_fix:
988+
data = self._fix_aggregation_change(data, c)
989+
return data
990+
# Let's process one column.
991+
keys = set(self.keys_time) - {columns_to_fix}
992+
select = data[self.keys_time]
993+
select_agg = select.groupby(list(keys)).count()
994+
assert select_agg[columns_to_fix].max() <= 1, (
995+
f"Column {columns_to_fix!r} has two distinct values at least for one date\n"
996+
f"{select_agg[select_agg[columns_to_fix] > 1]}"
997+
)
998+
keys = set(self.keys_no_time) - {columns_to_fix}
999+
select = data[self.keys_no_time]
1000+
select_agg = select.groupby(list(keys), as_index=True).apply(
1001+
lambda x: "-".join(sorted(set(x[columns_to_fix].dropna()))), include_groups=False
1002+
)
1003+
select_agg = select_agg.to_frame(name=columns_to_fix)
1004+
res = pandas.merge(
1005+
data.drop([columns_to_fix], axis=1),
1006+
select_agg,
1007+
how="left",
1008+
left_on=list(keys),
1009+
right_index=True,
1010+
)
1011+
assert data.shape == res.shape, (
1012+
f"Shape should match, data.shape={data.shape}, res.shape={res.shape}, "
1013+
f"columns={data.columns} but it is now {res.columns}"
1014+
)
1015+
return res
1016+
9641017
def _dropna(
9651018
self,
9661019
data: pandas.DataFrame,
@@ -1886,6 +1939,7 @@ def make_view_def(self, name: str) -> Optional[CubeViewDef]:
18861939
* **cmd:** command lines
18871940
* **raw-short:** raw data without all the unused columns
18881941
"""
1942+
fix_aggregation_change = ["model_speedup_input_set", "model_test_with"]
18891943
fs = ["suite", "model_suite", "task", "model_name", "model_task"]
18901944
index_cols = self._filter_column(fs, self.keys_time)
18911945
assert index_cols, (
@@ -1984,6 +2038,7 @@ def mean_geo(gr):
19842038
keep_columns_in_index=["suite"],
19852039
name="agg-suite",
19862040
order=order,
2041+
fix_aggregation_change=fix_aggregation_change,
19872042
),
19882043
"agg-all": lambda: CubeViewDef(
19892044
key_index=index_cols,
@@ -2014,6 +2069,7 @@ def mean_geo(gr):
20142069
name="agg-all",
20152070
order=order,
20162071
plots=True,
2072+
fix_aggregation_change=fix_aggregation_change,
20172073
),
20182074
"disc": lambda: CubeViewDef(
20192075
key_index=index_cols,
@@ -2023,6 +2079,7 @@ def mean_geo(gr):
20232079
f_highlight=f_disc,
20242080
name="disc",
20252081
order=order,
2082+
fix_aggregation_change=fix_aggregation_change,
20262083
),
20272084
"speedup": lambda: CubeViewDef(
20282085
key_index=index_cols,
@@ -2032,6 +2089,7 @@ def mean_geo(gr):
20322089
f_highlight=f_speedup,
20332090
name="speedup",
20342091
order=order,
2092+
fix_aggregation_change=fix_aggregation_change,
20352093
),
20362094
"counts": lambda: CubeViewDef(
20372095
key_index=index_cols,
@@ -2048,6 +2106,7 @@ def mean_geo(gr):
20482106
keep_columns_in_index=["suite"],
20492107
name="peak-gpu",
20502108
order=order,
2109+
fix_aggregation_change=fix_aggregation_change,
20512110
),
20522111
"time": lambda: CubeViewDef(
20532112
key_index=index_cols,
@@ -2058,6 +2117,7 @@ def mean_geo(gr):
20582117
keep_columns_in_index=["suite"],
20592118
name="time",
20602119
order=order,
2120+
fix_aggregation_change=fix_aggregation_change,
20612121
),
20622122
"time_export": lambda: CubeViewDef(
20632123
key_index=index_cols,
@@ -2066,6 +2126,7 @@ def mean_geo(gr):
20662126
keep_columns_in_index=["suite"],
20672127
name="time_export",
20682128
order=order,
2129+
fix_aggregation_change=fix_aggregation_change,
20692130
),
20702131
"err": lambda: CubeViewDef(
20712132
key_index=index_cols,
@@ -2076,6 +2137,7 @@ def mean_geo(gr):
20762137
keep_columns_in_index=["suite"],
20772138
name="err",
20782139
order=order,
2140+
fix_aggregation_change=fix_aggregation_change,
20792141
),
20802142
"bucket-speedup": lambda: CubeViewDef(
20812143
key_index=index_cols,
@@ -2085,6 +2147,7 @@ def mean_geo(gr):
20852147
name="bucket-speedup",
20862148
f_highlight=f_bucket,
20872149
order=order,
2150+
fix_aggregation_change=fix_aggregation_change,
20882151
),
20892152
"onnx": lambda: CubeViewDef(
20902153
key_index=index_cols,
@@ -2103,6 +2166,7 @@ def mean_geo(gr):
21032166
keep_columns_in_index=["suite"],
21042167
name="onnx",
21052168
order=order,
2169+
fix_aggregation_change=fix_aggregation_change,
21062170
),
21072171
"raw-short": lambda: CubeViewDef(
21082172
key_index=self.keys_time,
@@ -2111,6 +2175,7 @@ def mean_geo(gr):
21112175
keep_columns_in_index=["suite"],
21122176
name="raw-short",
21132177
no_index=True,
2178+
fix_aggregation_change=fix_aggregation_change,
21142179
),
21152180
}
21162181

@@ -2123,6 +2188,7 @@ def mean_geo(gr):
21232188
keep_columns_in_index=["suite"],
21242189
name="cmd",
21252190
order=order,
2191+
fix_aggregation_change=fix_aggregation_change,
21262192
)
21272193

21282194
assert name in implemented_views or name in {"cmd"}, (

0 commit comments

Comments
 (0)