Skip to content

Commit 35977fe

Browse files
committed
better design
1 parent e96d7ec commit 35977fe

File tree

1 file changed

+30
-16
lines changed

1 file changed

+30
-16
lines changed

onnx_diagnostic/torch_models/test_helper.py

Lines changed: 30 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import datetime
22
import inspect
33
import os
4-
from typing import Any, Dict, List, Optional, Tuple, Union
4+
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
55
import time
66
import onnx
77
import torch
@@ -180,6 +180,23 @@ def version_summary() -> Dict[str, Union[int, float, str]]:
180180
return summary
181181

182182

183+
def _quiet_or_not_quiet(
184+
quiet: bool, suffix: str, summary: Dict[str, Any], data: Dict[str, Any], fct: Callable
185+
) -> Any:
186+
begin = time.perf_counter()
187+
if quiet:
188+
try:
189+
return fct()
190+
except Exception as e:
191+
summary[f"ERR_{suffix}"] = str(e)
192+
data[f"ERR_{suffix}"] = e
193+
summary[f"time_{suffix}"] = time.perf_counter() - begin
194+
return summary, {}
195+
res = fct()
196+
summary[f"time_{suffix}"] = time.perf_counter() - begin
197+
return res
198+
199+
183200
def validate_model(
184201
model_id: str,
185202
task: Optional[str] = None,
@@ -266,21 +283,19 @@ def validate_model(
266283
print("[validate_model] get dummy inputs...")
267284
summary["model_id"] = model_id
268285

269-
begin = time.perf_counter()
270-
if quiet:
271-
try:
272-
data = get_untrained_model_with_inputs(
273-
model_id, verbose=verbose, task=task, same_as_pretrained=trained
286+
data = _quiet_or_not_quiet(
287+
quiet,
288+
"create",
289+
summary,
290+
None,
291+
(
292+
lambda mid=model_id, v=verbose, task=task, tr=trained: (
293+
get_untrained_model_with_inputs(
294+
mid, verbose=v, task=task, same_as_pretrained=tr
295+
)
274296
)
275-
except Exception as e:
276-
summary["ERR_create"] = str(e)
277-
data["ERR_create"] = e
278-
summary["time_create"] = time.perf_counter() - begin
279-
return summary, {}
280-
else:
281-
data = get_untrained_model_with_inputs(
282-
model_id, verbose=verbose, task=task, same_as_pretrained=trained
283-
)
297+
),
298+
)
284299

285300
if drop_inputs:
286301
if verbose:
@@ -316,7 +331,6 @@ def validate_model(
316331
data["inputs"] = to_any(data["inputs"], device) # type: ignore
317332
summary["model_device"] = str(device)
318333

319-
summary["time_create"] = time.perf_counter() - begin
320334
for k in ["task", "size", "n_weights"]:
321335
summary[f"model_{k.replace('_','')}"] = data[k]
322336
summary["model_inputs"] = string_type(data["inputs"], with_shape=True)

0 commit comments

Comments
 (0)