Skip to content

Commit f3436e8

Browse files
committed
Support option --trained
1 parent a5759f0 commit f3436e8

File tree

1 file changed

+6
-3
lines changed

1 file changed

+6
-3
lines changed

onnx_diagnostic/torch_models/test_helper.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -228,7 +228,6 @@ def validate_model(
228228
:return: two dictionaries, one with some metrics,
229229
another one with whatever the function produces
230230
"""
231-
assert not trained, f"trained={trained} not supported yet"
232231
summary = version_summary()
233232

234233
summary.update(
@@ -270,14 +269,18 @@ def validate_model(
270269
begin = time.perf_counter()
271270
if quiet:
272271
try:
273-
data = get_untrained_model_with_inputs(model_id, verbose=verbose, task=task)
272+
data = get_untrained_model_with_inputs(
273+
model_id, verbose=verbose, task=task, same_as_pretrained=trained
274+
)
274275
except Exception as e:
275276
summary["ERR_create"] = str(e)
276277
data["ERR_create"] = e
277278
summary["time_create"] = time.perf_counter() - begin
278279
return summary, {}
279280
else:
280-
data = get_untrained_model_with_inputs(model_id, verbose=verbose, task=task)
281+
data = get_untrained_model_with_inputs(
282+
model_id, verbose=verbose, task=task, same_as_pretrained=trained
283+
)
281284

282285
if drop_inputs:
283286
if verbose:

0 commit comments

Comments
 (0)