Skip to content

Commit 3ee9d8d

Browse files
committed
tentative
1 parent 00c6d07 commit 3ee9d8d

File tree

2 files changed

+93
-11
lines changed

2 files changed

+93
-11
lines changed

_unittests/ut_helpers/test_log_helper.py

Lines changed: 57 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -653,7 +653,32 @@ def test_fix_non_consistent_historical_data_no_change(self):
653653
)
654654
self.assertEqual(expected, view.to_dict())
655655

656-
def test_fix_non_consistent_historical_data_mixed_values(self):
656+
def test_fix_non_consistent_historical_data_mixed_values1(self):
657+
df = pandas.DataFrame(
658+
[
659+
dict(date="2025/01/01", time_p=0.51, exporter="E1", model_s="O", model="M"),
660+
dict(date="2025/01/02", time_p=0.51, exporter="E1", model_s="O", model="M"),
661+
dict(date="2025/01/03", time_p=0.53, exporter="E1", model_s="A", model="M"),
662+
]
663+
)
664+
cube = CubeLogs(
665+
df, keys=["^model*", "exporter", "opt"], values=["time_p"], time="date"
666+
).load()
667+
view, _view_def = cube.view(
668+
CubeViewDef(["^model.*"], ["^time_.*"], fix_aggregation_change=["model_s"]),
669+
return_view_def=True,
670+
)
671+
raw = view.to_dict()
672+
self.assertEqual(
673+
{
674+
("time_p", pandas.Timestamp("2025-01-01 00:00:00")): {"A-O": 0.51},
675+
("time_p", pandas.Timestamp("2025-01-02 00:00:00")): {"A-O": 0.51},
676+
("time_p", pandas.Timestamp("2025-01-03 00:00:00")): {"A-O": 0.53},
677+
},
678+
raw,
679+
)
680+
681+
def test_fix_non_consistent_historical_data_mixed_values2(self):
657682
df = pandas.DataFrame(
658683
[
659684
dict(date="2025/01/01", time_p=0.51, exporter="E1", model_s="O", model="M"),
@@ -703,6 +728,37 @@ def test_fix_non_consistent_historical_data_mixed_nan(self):
703728
raw,
704729
)
705730

731+
def test_fix_non_consistent_historical_data_nan(self):
732+
df = pandas.DataFrame(
733+
[
734+
dict(date="2025/01/01", time_p=0.51, exporter="E1", model_s="O", model="M"),
735+
dict(date="2025/01/02", time_p=0.51, exporter="E1", model_s="O", model="M"),
736+
dict(date="2025/01/03", time_p=0.53, exporter="E1", model_s="A", model="M"),
737+
dict(date="2025/01/01", time_p=0.51, exporter="E2", model="M"),
738+
dict(date="2025/01/02", time_p=0.51, exporter="E2", model="M"),
739+
dict(date="2025/01/03", time_p=0.53, exporter="E2", model="M"),
740+
]
741+
)
742+
cube = CubeLogs(
743+
df, keys=["^model*", "exporter", "opt"], values=["time_p"], time="date"
744+
).load()
745+
view, _view_def = cube.view(
746+
CubeViewDef(["^model.*"], ["^time_.*"], fix_aggregation_change=["model_s"]),
747+
return_view_def=True,
748+
)
749+
raw = view.reset_index(drop=True).fillna("NAN").to_dict(orient="list")
750+
self.assertEqual(
751+
{
752+
("time_p", "E1", pandas.Timestamp("2025-01-01 00:00:00")): ["NAN", 0.51],
753+
("time_p", "E1", pandas.Timestamp("2025-01-02 00:00:00")): ["NAN", 0.51],
754+
("time_p", "E1", pandas.Timestamp("2025-01-03 00:00:00")): ["NAN", 0.53],
755+
("time_p", "E2", pandas.Timestamp("2025-01-01 00:00:00")): [0.51, "NAN"],
756+
("time_p", "E2", pandas.Timestamp("2025-01-02 00:00:00")): [0.51, "NAN"],
757+
("time_p", "E2", pandas.Timestamp("2025-01-03 00:00:00")): [0.53, "NAN"],
758+
},
759+
raw,
760+
)
761+
706762

707763
if __name__ == "__main__":
708764
unittest.main(verbosity=2)

onnx_diagnostic/helpers/log_helper.py

Lines changed: 36 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -754,14 +754,13 @@ def view(
754754
f"values={sorted(self.values)}"
755755
)
756756

757-
if view_def.fix_aggregation_change:
757+
if view_def.fix_aggregation_change and (
758+
set(view_def.fix_aggregation_change) & set(self.keys_no_time)
759+
):
758760
# before aggregation, let's fix some keys whose values changed over time
759-
assert set(view_def.fix_aggregation_change) <= set(self.keys_no_time), (
760-
f"view_def.fix_aggregation_change={view_def.fix_aggregation_change} is not "
761-
f"included in the keys {self.keys_no_time}"
762-
)
763761
data_to_process = self._fix_aggregation_change(
764-
self.data, view_def.fix_aggregation_change
762+
self.data,
763+
list(set(view_def.fix_aggregation_change) & set(self.keys_no_time)),
765764
)
766765
else:
767766
data_to_process = self.data
@@ -908,7 +907,7 @@ def view(
908907
f"key={sorted(key_columns)}, key_agg={key_agg}, values={sorted(values)}, "
909908
f"columns={sorted(data.columns)}, ignored={view_def.ignore_columns}, "
910909
f"not unique={set(data.columns) - unique}"
911-
f"\n--\n{not_unique.head()}"
910+
f"\n--\n{not_unique.head(10)}"
912911
)
913912

914913
# pivot
@@ -978,10 +977,20 @@ def view(
978977
return (piv, view_def) if return_view_def else piv
979978

980979
def _fix_aggregation_change(
981-
self, data: pandas.DataFrame, columns_to_fix: Union[str, List[str]]
980+
self,
981+
data: pandas.DataFrame,
982+
columns_to_fix: Union[str, List[str]],
983+
overwrite_or_merge: bool = True,
982984
) -> pandas.DataFrame:
983985
"""
984986
Fixes columns used to aggregate values because their meaning changed over time.
987+
988+
:param data: data to fix
989+
:param columns_to_fix: list of columns to fix
990+
:param overwrite_or_merge: if True, overwrite all values by the concatenation
991+
of all existing values, if merge, merges existing values found
992+
and grouped by the other keys
993+
:return: fixed data
985994
"""
986995
if not isinstance(columns_to_fix, str):
987996
for c in columns_to_fix:
@@ -995,6 +1004,10 @@ def _fix_aggregation_change(
9951004
f"Column {columns_to_fix!r} has two distinct values at least for one date\n"
9961005
f"{select_agg[select_agg[columns_to_fix] > 1]}"
9971006
)
1007+
1008+
# unique value (to fill NaN)
1009+
unique = "-".join(sorted(set(data[columns_to_fix].dropna())))
1010+
9981011
keys = set(self.keys_no_time) - {columns_to_fix}
9991012
select = data[self.keys_no_time]
10001013
select_agg = select.groupby(list(keys), as_index=True).apply(
@@ -1008,9 +1021,22 @@ def _fix_aggregation_change(
10081021
left_on=list(keys),
10091022
right_index=True,
10101023
)
1011-
assert data.shape == res.shape, (
1024+
val = f"?{unique}?"
1025+
res[columns_to_fix] = res[columns_to_fix].fillna(val).replace("", val)
1026+
assert (
1027+
data.shape == res.shape
1028+
and sorted(data.columns) == sorted(res.columns)
1029+
and sorted(data.index) == sorted(res.index)
1030+
), (
10121031
f"Shape should match, data.shape={data.shape}, res.shape={res.shape}, "
1013-
f"columns={data.columns} but it is now {res.columns}"
1032+
f"lost={set(data.columns) - set(res.columns)}, "
1033+
f"added={set(res.columns) - set(data.columns)}"
1034+
)
1035+
res = res[data.columns]
1036+
assert data.columns.equals(res.columns) and data.index.equals(res.index), (
1037+
f"Columns or index mismatch "
1038+
f"data.columns.equals(res.columns)={data.columns.equals(res.columns)}, "
1039+
f"data.index.equals(res.columns)={data.index.equals(res.columns)}, "
10141040
)
10151041
return res
10161042

0 commit comments

Comments
 (0)