Skip to content

Commit 4452fc1

Browse files
committed
Add option to disable patches for torch in command line validate
1 parent bb123af commit 4452fc1

File tree

3 files changed

+60
-6
lines changed

3 files changed

+60
-6
lines changed

onnx_diagnostic/_command_lines_parser.py

Lines changed: 56 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -306,7 +306,7 @@ def __call__(self, parser, namespace, values, option_string=None):
306306
value = split_items[1]
307307

308308
if value in ("True", "true", "False", "false"):
309-
d[key] = bool(value)
309+
d[key] = value in ("True", "true")
310310
continue
311311
try:
312312
d[key] = int(value)
@@ -323,6 +323,54 @@ def __call__(self, parser, namespace, values, option_string=None):
323323
setattr(namespace, self.dest, d)
324324

325325

326+
class _BoolOrParseDictPatch(argparse.Action):
327+
def __call__(self, parser, namespace, values, option_string=None):
328+
329+
if not values:
330+
return
331+
if len(values) == 1 and values[0] in (
332+
"True",
333+
"False",
334+
"true",
335+
"false",
336+
"0",
337+
"1",
338+
0,
339+
1,
340+
):
341+
setattr(namespace, self.dest, values[0] in ("True", "true", 1, "1"))
342+
return
343+
d = getattr(namespace, self.dest) or {}
344+
if not isinstance(d, dict):
345+
d = {
346+
"patch_sympy": d,
347+
"patch_torch": d,
348+
"patch_transformers": d,
349+
"patch_diffusers": d,
350+
}
351+
for item in values:
352+
split_items = item.split("=", 1)
353+
key = split_items[0].strip() # we remove blanks around keys, as is logical
354+
value = split_items[1]
355+
356+
if value in ("True", "true", "False", "false"):
357+
d[key] = value in ("True", "true")
358+
continue
359+
try:
360+
d[key] = int(value)
361+
continue
362+
except (TypeError, ValueError):
363+
pass
364+
try:
365+
d[key] = float(value)
366+
continue
367+
except (TypeError, ValueError):
368+
pass
369+
d[key] = _parse_json(value)
370+
371+
setattr(namespace, self.dest, d)
372+
373+
326374
def get_parser_validate() -> ArgumentParser:
327375
parser = ArgumentParser(
328376
prog="validate",
@@ -383,8 +431,13 @@ def get_parser_validate() -> ArgumentParser:
383431
parser.add_argument(
384432
"--patch",
385433
default=True,
386-
action=BooleanOptionalAction,
387-
help="Applies patches before exporting.",
434+
action=_BoolOrParseDictPatch,
435+
nargs="*",
436+
help="Applies patches before exporting, it can be a boolean "
437+
"to enable to disable the patches or be more finetuned. It is possible to "
438+
"disable patch for torch by adding "
439+
'--patch "patch_sympy=False" --patch "patch_torch=False", '
440+
"default is True.",
388441
)
389442
parser.add_argument(
390443
"--rewrite",

onnx_diagnostic/torch_export_patches/onnx_export_errors.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -361,7 +361,7 @@ def torch_export_patches(
361361
torch._meta_registrations._broadcast_shapes = patched__broadcast_shapes
362362

363363
# torch._export.non_strict_utils.produce_guards_and_solve_constraints
364-
if catch_constraints:
364+
if patch_torch and catch_constraints:
365365
if verbose:
366366
print("[torch_export_patches] modifies shape constraints")
367367
f_produce_guards_and_solve_constraints = (
@@ -513,7 +513,7 @@ def torch_export_patches(
513513
if verbose:
514514
print("[torch_export_patches] restored pytorch functions")
515515

516-
if stop_if_static:
516+
if patch_torch and stop_if_static:
517517
if verbose:
518518
print("[torch_export_patches] restored ShapeEnv._set_replacement")
519519

@@ -529,7 +529,7 @@ def torch_export_patches(
529529
print("[torch_export_patches] restored ShapeEnv._check_frozen")
530530
ShapeEnv._check_frozen = f_shape_env__check_frozen
531531

532-
if catch_constraints:
532+
if patch_torch and catch_constraints:
533533
# to catch or skip dynamic_shapes issues
534534
torch._export.non_strict_utils.produce_guards_and_solve_constraints = (
535535
f_produce_guards_and_solve_constraints

onnx_diagnostic/torch_models/validate.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -426,6 +426,7 @@ def validate_model(
426426
print(f"[validate_model] validate model id {model_id!r}, subfolder={subfolder!r}")
427427
else:
428428
print(f"[validate_model] validate model id {model_id!r}")
429+
print(f"[validate_model] patch={patch!r}")
429430
if model_options:
430431
print(f"[validate_model] model_options={model_options!r}")
431432
print(f"[validate_model] get dummy inputs with input_options={input_options}...")

0 commit comments

Comments
 (0)