Skip to content

Commit d54283f

Browse files
committed
fix test
1 parent 089b1d3 commit d54283f

File tree

2 files changed

+19
-9
lines changed

2 files changed

+19
-9
lines changed

onnx_diagnostic/helpers/bench_run.py

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -86,13 +86,13 @@ def _cmd_line(script_name: str, **kwargs: Dict[str, Union[str, int, float]]) ->
8686
return args
8787

8888

89-
def _extract_metrics(text: str) -> Dict[str, str]:
89+
def _extract_metrics(text: str) -> Dict[str, Union[str, int, float]]:
9090
reg = re.compile(":(.*?),(.*.?);")
9191
res = reg.findall(text)
9292
if len(res) == 0:
9393
return {}
9494
kw = dict(res)
95-
new_kw = {}
95+
new_kw: Dict[str, Any] = {}
9696
for k, w in kw.items():
9797
assert isinstance(k, str) and isinstance(
9898
w, str
@@ -159,7 +159,7 @@ def run_benchmark(
159159
summary: Optional[Callable] = None,
160160
timeout: int = 600,
161161
missing: Optional[Dict[str, Union[str, Callable]]] = None,
162-
) -> List[Dict[str, Union[str, int, float, Tuple[int, int]]]]:
162+
) -> List[Dict[str, Union[str, int, float]]]:
163163
"""
164164
Runs a script multiple times and extract information from the output
165165
following the pattern ``:<metric>,<value>;``.
@@ -188,7 +188,7 @@ def run_benchmark(
188188
else:
189189
loop = configs
190190

191-
data: List[Dict[str, Union[str, int, float, Tuple[int, int]]]] = []
191+
data: List[Dict[str, Union[str, int, float]]] = []
192192
for iter_loop, config in enumerate(loop):
193193
if iter_loop < start:
194194
continue
@@ -266,23 +266,32 @@ def run_benchmark(
266266
metrics.update(config)
267267
if filename_out and os.path.exists(filename_out):
268268
if "model_name" in metrics:
269+
assert isinstance(
270+
metrics["model_name"], str
271+
), f"unexpected type {type(metrics['model_name'])}"
269272
new_name = f"{filename_out}.{_clean_string(metrics['model_name'])}"
270273
os.rename(filename_out, new_name)
271274
filename_out = new_name
272275
metrics["file.stdout"] = filename_out
273276
if filename_err and os.path.exists(filename_err):
274277
if "model_name" in metrics:
278+
assert isinstance(
279+
metrics["model_name"], str
280+
), f"unexpected type {type(metrics['model_name'])}"
275281
new_name = f"{filename_err}.{_clean_string(metrics['model_name'])}"
276282
os.rename(filename_err, new_name)
277283
filename_err = new_name
278284
metrics["file.stderr"] = filename_err
279285
metrics["DATE"] = f"{datetime.now():%Y-%m-%d}"
280-
metrics["ITER"] = iter_loop
286+
metrics["ITER"] = str(iter_loop)
281287
metrics["TIME_ITER"] = time.perf_counter() - begin
282288
metrics["ERROR"] = _clean_string(serr)
283289
metrics["ERR_stdout"] = _clean_string(sout)
284290
if metrics["ERROR"]:
285291
metrics["ERR_std"] = metrics["ERROR"]
292+
assert isinstance(
293+
metrics["ERROR"], str
294+
), f"unexpected type {type(metrics['ERROR'])}"
286295
if "CUDA out of memory" in metrics["ERROR"]:
287296
metrics["ERR_CUDA_OOM"] = 1
288297
if "Cannot access gated repo for url" in metrics["ERROR"]:
@@ -348,8 +357,8 @@ def make_configs(
348357
drop: Optional[Set[str]] = None,
349358
replace: Optional[Dict[str, str]] = None,
350359
last: Optional[List[str]] = None,
351-
filter_function: Optional[Callable[Dict[str, Any], bool]] = None,
352-
) -> List[Dict[str, Any]]:
360+
filter_function: Optional[Callable[[Dict[str, Union[str, int, float]]], bool]] = None,
361+
) -> List[Dict[str, Union[str, int, float]]]:
353362
"""
354363
Creates all the configurations based on the command line arguments.
355364
@@ -383,14 +392,14 @@ def make_configs(
383392
for k in last:
384393
if k not in kwargs_:
385394
continue
386-
v = kwargs[k]
395+
v = kwargs[k] # type: ignore
387396
if isinstance(v, str):
388397
args.append([(k, s) for s in v.split(",")])
389398
else:
390399
args.append([(k, v)])
391400

392401
configs = list(itertools.product(*args))
393-
confs = [dict(c) for c in configs]
402+
confs: List[Dict[str, Union[int, float, str]]] = [dict(c) for c in configs]
394403
if filter_function:
395404
confs = [c for c in confs if filter_function(c)]
396405
return confs

requirements-dev.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
accelerate # transformers/src/transformers/modeling_utils.py -> init_empty_weights missing if this package is not installed
12
black
23
diffusers>=0.30.0
34
furo

0 commit comments

Comments
 (0)