Skip to content

Commit 63d2aa5

Browse files
committed
more refac
1 parent 8bd3103 commit 63d2aa5

File tree

1 file changed

+22
-24
lines changed

1 file changed

+22
-24
lines changed

onnx_diagnostic/torch_models/test_helper.py

Lines changed: 22 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -302,6 +302,8 @@ def validate_model(
302302
)
303303
),
304304
)
305+
if "ERR_create" in summary:
306+
return summary, data
305307

306308
if drop_inputs:
307309
if verbose:
@@ -364,18 +366,14 @@ def validate_model(
364366
# We make a copy of the input just in case the model modifies them inplace
365367
hash_inputs = string_type(data["inputs"], with_shape=True)
366368
inputs = torch_deepcopy(data["inputs"])
367-
begin = time.perf_counter()
368-
if quiet:
369-
try:
370-
expected = data["model"](**inputs)
371-
except Exception as e:
372-
summary["ERR_run"] = str(e)
373-
data["ERR_run"] = e
374-
summary["time_run"] = time.perf_counter() - begin
375-
return summary, data
376-
else:
377-
expected = data["model"](**inputs)
378-
summary["time_run"] = time.perf_counter() - begin
369+
model = data["model"]
370+
371+
expected = _quiet_or_not_quiet(
372+
quiet, "run", summary, data, (lambda m=model, inp=inputs: m(**inp))
373+
)
374+
if "ERR_run" in summary:
375+
return summary, data
376+
379377
summary["model_expected"] = string_type(expected, with_shape=True)
380378
if verbose:
381379
print("[validate_model] done (run)")
@@ -417,18 +415,18 @@ def validate_model(
417415

418416
# We make a copy of the input just in case the model modifies them inplace
419417
inputs = torch_deepcopy(data["inputs_export"])
420-
begin = time.perf_counter()
421-
if quiet:
422-
try:
423-
expected = data["model"](**inputs)
424-
except Exception as e:
425-
summary["ERR_run_patched"] = str(e)
426-
data["ERR_run_patched"] = e
427-
summary["time_run_patched"] = time.perf_counter() - begin
428-
return summary, data
429-
else:
430-
expected = data["model"](**inputs)
431-
summary["time_run_patched"] = time.perf_counter() - begin
418+
model = data["model"]
419+
420+
expected = _quiet_or_not_quiet(
421+
quiet,
422+
"run_patched",
423+
summary,
424+
data,
425+
(lambda m=model, inp=inputs: m(**inp)),
426+
)
427+
if "ERR_run_patched" in summary:
428+
return summary, data
429+
432430
disc = max_diff(data["expected"], expected)
433431
for k, v in disc.items():
434432
summary[f"disc_patched_{k}"] = v

0 commit comments

Comments
 (0)