Skip to content

Commit 2aeb3ff

Browse files
authored
Fixes --patch argument (#232)
* Fix --patch argument * doc
1 parent 7a3158e commit 2aeb3ff

File tree

3 files changed

+6
-3
lines changed

3 files changed

+6
-3
lines changed

CHANGELOGS.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@ Change Logs
44
0.7.12
55
++++++
66

7+
* :pr:`232`: fixes ``--patch`` argument so that ``--patch=0`` works
8+
* :pr:`231`: better statistics about fusions
79
* :pr:`227`: better support for ``model_id//pretrained``, adds speed up when running command validate
810
* :pr:`226`: fix input order for models created with modelbuilder
911

onnx_diagnostic/_command_lines_parser.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -581,6 +581,7 @@ def _cmd_validate(argv: List[Any]):
581581
):
582582
print(f"validate - unsupported args: export={args.export!r}, opt={args.opt!r}")
583583
return
584+
patch_dict = args.patch if isinstance(args.patch, dict) else {"patch": args.patch}
584585
summary, _data = validate_model(
585586
model_id=args.mid,
586587
task=args.task,
@@ -591,8 +592,8 @@ def _cmd_validate(argv: List[Any]):
591592
use_pretrained=args.trained,
592593
dtype=args.dtype,
593594
device=args.device,
594-
patch=args.patch,
595-
rewrite=args.rewrite,
595+
patch=patch_dict,
596+
rewrite=args.rewrite and patch_dict.get("patch", True),
596597
stop_if_static=args.stop_if_static,
597598
optimization=args.opt,
598599
exporter=args.export,

onnx_diagnostic/torch_models/validate.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -412,7 +412,7 @@ def validate_model(
412412
assert not rewrite or patch_kwargs.get("patch", False), (
413413
f"rewrite={rewrite}, patch={patch}, patch_kwargs={patch_kwargs} "
414414
f"patch must be True to enable rewriting, "
415-
f"if --no-patch was specified on the command line, --no-rewrite must be added."
415+
f"if --patch=0 was specified on the command line, rewrites are disabled."
416416
)
417417
summary = version_summary()
418418
summary.update(

0 commit comments

Comments
 (0)