Skip to content

Commit 930b984

Browse files
committed
excelé
1 parent eb15c53 commit 930b984

File tree

3 files changed

+51
-6
lines changed

3 files changed

+51
-6
lines changed

_unittests/ut_helpers/test_log_helper.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,13 @@ def test_cube_logs_load_df(self):
5757
self.assertEqual((3, df.shape[1] + 1), cube.shape)
5858
self.assertEqual(set(cube.columns), {*df.columns, "speedup"})
5959

60+
@hide_stdout()
61+
def test_cube_logs_load_dfdf(self):
62+
df = self.df1()
63+
cube = CubeLogs([df, df], recent=True)
64+
cube.load(verbose=1)
65+
self.assertEqual((3, 10), cube.shape)
66+
6067
@hide_stdout()
6168
def test_cube_logs_load_list(self):
6269
cube = CubeLogs(
@@ -174,6 +181,11 @@ def test_enumerate_csv_files(self):
174181
for df in dfs:
175182
open_dataframe(df)
176183

184+
cube = CubeLogs(data, recent=True)
185+
cube.load(verbose=1)
186+
self.assertEqual((3, 11), cube.shape)
187+
self.assertIn("RAWFILENAME", cube.data.columns)
188+
177189

178190
if __name__ == "__main__":
179191
unittest.main(verbosity=2)

_unittests/ut_torch_models/test_tiny_llms_onnx.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,9 @@ def test_bypass_onnx_export_tiny_llm_official_full(self):
102102
self.assertEqual(
103103
{"attention_mask", "past_key_values", "input_ids", "position_ids"}, set(inputs)
104104
)
105-
with torch_export_patches(patch_transformers=True, verbose=1) as modificator:
105+
with torch_export_patches(
106+
patch_transformers=True, verbose=1, stop_if_static=1
107+
) as modificator:
106108
new_inputs = modificator(copy.deepcopy(inputs))
107109
ep = torch.onnx.export(
108110
model,

onnx_diagnostic/helpers/log_helper.py

Lines changed: 36 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,10 @@
44
import re
55
import zipfile
66
from typing import Any, Callable, Dict, Iterator, List, Optional, Sequence, Tuple, Union
7-
from .helper import string_sig
7+
import numpy as np
88
import pandas
99
from pandas.api.types import is_numeric_dtype
10+
from .helper import string_sig
1011

1112

1213
def enumerate_csv_files(
@@ -197,6 +198,27 @@ def load(self, verbose: int = 0):
197198
if verbose:
198199
print(f"[CubeLogs.load] load from list of dicts, n={len(self._data)}")
199200
self.data = pandas.DataFrame(self._data)
201+
elif isinstance(self._data, list) and all(
202+
isinstance(r, pandas.DataFrame) for r in self._data
203+
):
204+
if verbose:
205+
print(f"[CubeLogs.load] load from list of DataFrame, n={len(self._data)}")
206+
self.data = pandas.concat(self._data, axis=0)
207+
elif isinstance(self._data, list):
208+
cubes = []
209+
for item in enumerate_csv_files(self._data, verbose=verbose):
210+
df = open_dataframe(item)
211+
cube = CubeLogs(
212+
df,
213+
time=self._time,
214+
keys=self._keys,
215+
values=self._values,
216+
ignored=self._ignored,
217+
recent=self.recent,
218+
)
219+
cube.load()
220+
cubes.append(cube.data)
221+
self.data = pandas.concat(cubes, axis=0)
200222
else:
201223
raise NotImplementedError(
202224
f"Not implemented with the provided data (type={type(self._data)})"
@@ -281,16 +303,25 @@ def _preprocess(self):
281303
last = self.values[0]
282304
gr = self.data[[self.time, *self.keys, last]].groupby([self.time, *self.keys]).count()
283305
gr = gr[gr[last] > 1]
284-
assert gr.shape[0] == 0, f"There are duplicated rows:\n{gr}"
285306
if self.recent:
286-
gr = self.data[[*self.keys, self.time]].groupby(self.keys, as_index=False).max()
287-
filtered = pandas.merge(self.data, gr, on=[self.time, *self.keys])
307+
cp = self.data.copy()
308+
assert (
309+
"__index__" not in cp.columns
310+
), f"'__index__' should not be a column in {cp.columns}"
311+
cp["__index__"] = np.arange(cp.shape[0])
312+
gr = (
313+
cp[[*self.keys, self.time, "__index__"]]
314+
.groupby(self.keys, as_index=False)
315+
.max()
316+
)
317+
filtered = pandas.merge(cp, gr, on=[self.time, "__index__", *self.keys])
288318
assert filtered.shape[0] <= self.data.shape[0], (
289319
f"Keeping the latest row brings more row {filtered.shape} "
290320
f"(initial is {self.data.shape})."
291321
)
292-
self.data = filtered
322+
self.data = filtered.drop("__index__", axis=1)
293323
else:
324+
assert gr.shape[0] == 0, f"There are duplicated rows:\n{gr}"
294325
gr = self.data[[*self.keys, self.time]].groupby(self.keys).count()
295326
gr = gr[gr[self.time] > 1]
296327
assert (

0 commit comments

Comments
 (0)