Skip to content

Commit cc64994

Browse files
committed
Handle more models
1 parent 15208c6 commit cc64994

File tree

2 files changed

+18
-1
lines changed

2 files changed

+18
-1
lines changed

onnx_diagnostic/_command_lines_parser.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -542,6 +542,12 @@ def get_parser_validate() -> ArgumentParser:
542542
"the onnx exporter should use.",
543543
default="",
544544
)
545+
parser.add_argument(
546+
"--ort-logs",
547+
default=False,
548+
action=BooleanOptionalAction,
549+
help="Enables onnxruntime logging when the session is created",
550+
)
545551
return parser
546552

547553

@@ -601,6 +607,7 @@ def _cmd_validate(argv: List[Any]):
601607
repeat=args.repeat,
602608
warmup=args.warmup,
603609
inputs2=args.inputs2,
610+
ort_logs=args.ort_logs,
604611
output_names=(
605612
None if len(args.outnames.strip()) < 2 else args.outnames.strip().split(",")
606613
),

onnx_diagnostic/torch_models/validate.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -292,6 +292,7 @@ def validate_model(
292292
warmup: int = 0,
293293
inputs2: int = 1,
294294
output_names: Optional[List[str]] = None,
295+
ort_logs: bool = False,
295296
) -> Tuple[Dict[str, Union[int, float, str]], Dict[str, Any]]:
296297
"""
297298
Validates a model.
@@ -344,6 +345,7 @@ def validate_model(
344345
this ensures that the model does support dynamism, the value is used
345346
as an increment to the first set of values (added to dimensions)
346347
:param output_names: output names the onnx exporter should use
348+
:param ort_logs: increases onnxruntime verbosity when creating the session
347349
:return: two dictionaries, one with some metrics,
348350
another one with whatever the function produces
349351
@@ -758,6 +760,7 @@ def validate_model(
758760
repeat=repeat,
759761
warmup=warmup,
760762
inputs2=inputs2,
763+
ort_logs=ort_logs,
761764
)
762765
summary.update(summary_valid)
763766

@@ -1158,6 +1161,7 @@ def validate_onnx_model(
11581161
repeat: int = 1,
11591162
warmup: int = 0,
11601163
inputs2: int = 1,
1164+
ort_logs: bool = False,
11611165
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
11621166
"""
11631167
Verifies that an onnx model produces the same
@@ -1176,6 +1180,7 @@ def validate_onnx_model(
11761180
:param inputs2: to validate the model on the second input set
11771181
to make sure the exported model supports dynamism, the value is
11781182
used as an increment added to the first set of inputs (added to dimensions)
1183+
:param ort_logs: triggers the logs for onnxruntime
11791184
:return: two dictionaries, one with some metrics,
11801185
another one with whatever the function produces
11811186
"""
@@ -1232,8 +1237,13 @@ def _mk(key, flavour=flavour):
12321237

12331238
if verbose:
12341239
print("[validate_onnx_model] runtime is onnxruntime")
1235-
cls_runtime = lambda model, providers: onnxruntime.InferenceSession(
1240+
sess_opts = onnxruntime.SessionOptions()
1241+
if ort_logs:
1242+
sess_opts.log_severity_level = 0
1243+
sess_opts.log_verbosity_level = 4
1244+
cls_runtime = lambda model, providers, _o=sess_opts: onnxruntime.InferenceSession(
12361245
(model.SerializeToString() if isinstance(model, onnx.ModelProto) else model),
1246+
_o,
12371247
providers=providers,
12381248
)
12391249
elif runtime == "torch":

0 commit comments

Comments
 (0)