Skip to content

Commit 34b798c

Browse files
committed
fixes
1 parent a9a1fd9 commit 34b798c

File tree

3 files changed

+105
-23
lines changed

3 files changed

+105
-23
lines changed

onnx_diagnostic/_command_lines_parser.py

Lines changed: 100 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -559,12 +559,6 @@ def get_parser_validate() -> ArgumentParser:
559559
help="Avoids raising an exception when an input sets does not work with "
560560
"the exported model.\nExample: --quiet-input-sets=inputs,inputs22",
561561
)
562-
parser.add_argument(
563-
"--sample-code",
564-
default="",
565-
help="Generates a sample code to export a model without "
566-
"without this package.\nExample --sample-code=export_sample",
567-
)
568562
return parser
569563

570564

@@ -630,14 +624,101 @@ def _cmd_validate(argv: List[Any]):
630624
output_names=(
631625
None if len(args.outnames.strip()) < 2 else args.outnames.strip().split(",")
632626
),
633-
sample_code=args.sample_code,
634627
)
635628
print("")
636629
print("-- summary --")
637630
for k, v in sorted(summary.items()):
638631
print(f":{k},{v};")
639632

640633

634+
def _cmd_export_sample(argv: List[Any]):
635+
from .helpers import string_type
636+
from .torch_models.validate import get_inputs_for_task, _make_folder_name
637+
from .torch_models.code_sample import code_sample
638+
from .tasks import supported_tasks
639+
640+
parser = get_parser_validate()
641+
args = parser.parse_args(argv[1:])
642+
if not args.task and not args.mid:
643+
print("-- list of supported tasks:")
644+
print("\n".join(supported_tasks()))
645+
elif not args.mid:
646+
data = get_inputs_for_task(args.task)
647+
if args.verbose:
648+
print(f"task: {args.task}")
649+
max_length = max(len(k) for k in data["inputs"]) + 1
650+
print("-- inputs")
651+
for k, v in data["inputs"].items():
652+
print(f" + {k.ljust(max_length)}: {string_type(v, with_shape=True)}")
653+
print("-- dynamic_shapes")
654+
for k, v in data["dynamic_shapes"].items():
655+
print(f" + {k.ljust(max_length)}: {string_type(v)}")
656+
else:
657+
# Let's skip any invalid combination if known to be unsupported
658+
if (
659+
"onnx" not in (args.export or "")
660+
and "custom" not in (args.export or "")
661+
and (args.opt or "")
662+
):
663+
print(f"code-sample - unsupported args: export={args.export!r}, opt={args.opt!r}")
664+
return
665+
patch_dict = args.patch if isinstance(args.patch, dict) else {"patch": args.patch}
666+
code = code_sample(
667+
model_id=args.mid,
668+
task=args.task,
669+
do_run=args.run,
670+
verbose=args.verbose,
671+
quiet=args.quiet,
672+
same_as_pretrained=args.same_as_trained,
673+
use_pretrained=args.trained,
674+
dtype=args.dtype,
675+
device=args.device,
676+
patch=patch_dict,
677+
rewrite=args.rewrite and patch_dict.get("patch", True),
678+
stop_if_static=args.stop_if_static,
679+
optimization=args.opt,
680+
exporter=args.export,
681+
dump_folder=args.dump_folder,
682+
drop_inputs=None if not args.drop else args.drop.split(","),
683+
input_options=args.iop,
684+
model_options=args.mop,
685+
subfolder=args.subfolder,
686+
opset=args.opset,
687+
runtime=args.runtime,
688+
output_names=(
689+
None if len(args.outnames.strip()) < 2 else args.outnames.strip().split(",")
690+
),
691+
)
692+
if args.dump_folder:
693+
os.makedirs(args.dump_folder, exist_ok=True)
694+
name = (
695+
_make_folder_name(
696+
model_id=args.model_id,
697+
exporter=args.exporter,
698+
optimization=args.optimization,
699+
dtype=args.dtype,
700+
device=args.device,
701+
subfolder=args.subfolder,
702+
opset=args.opset,
703+
drop_inputs=None if not args.drop else args.drop.split(","),
704+
same_as_pretrained=args.same_as_pretrained,
705+
use_pretrained=args.use_pretrained,
706+
task=args.task,
707+
).replace("/", "-")
708+
+ ".py"
709+
)
710+
fullname = os.path.join(args.dump_folder, name)
711+
if args.verbose:
712+
print(f"-- prints code in {fullname!r}")
713+
print("--")
714+
with open(fullname, "w") as f:
715+
f.write(code)
716+
if args.verbose:
717+
print("-- done")
718+
else:
719+
print(code)
720+
721+
641722
def get_parser_stats() -> ArgumentParser:
642723
parser = ArgumentParser(
643724
prog="stats",
@@ -967,14 +1048,15 @@ def get_main_parser() -> ArgumentParser:
9671048
Type 'python -m onnx_diagnostic <cmd> --help'
9681049
to get help for a specific command.
9691050
970-
agg - aggregates statistics from multiple files
971-
config - prints a configuration for a model id
972-
find - find node consuming or producing a result
973-
lighten - makes an onnx model lighter by removing the weights,
974-
print - prints the model on standard output
975-
stats - produces statistics on a model
976-
unlighten - restores an onnx model produces by the previous experiment
977-
validate - validate a model
1051+
agg - aggregates statistics from multiple files
1052+
config - prints a configuration for a model id
1053+
exportsample - produces a code to export a model
1054+
find - find node consuming or producing a result
1055+
lighten - makes an onnx model lighter by removing the weights,
1056+
print - prints the model on standard output
1057+
stats - produces statistics on a model
1058+
unlighten - restores an onnx model produces by the previous experiment
1059+
validate - validate a model
9781060
"""
9791061
),
9801062
)
@@ -983,6 +1065,7 @@ def get_main_parser() -> ArgumentParser:
9831065
choices=[
9841066
"agg",
9851067
"config",
1068+
"exportsample",
9861069
"find",
9871070
"lighten",
9881071
"print",
@@ -1005,6 +1088,7 @@ def main(argv: Optional[List[Any]] = None):
10051088
validate=_cmd_validate,
10061089
stats=_cmd_stats,
10071090
agg=_cmd_agg,
1091+
exportsample=_cmd_export_sample,
10081092
)
10091093

10101094
if argv is None:
@@ -1027,6 +1111,7 @@ def main(argv: Optional[List[Any]] = None):
10271111
validate=get_parser_validate,
10281112
stats=get_parser_stats,
10291113
agg=get_parser_agg,
1114+
exportsample=get_parser_validate,
10301115
)
10311116
cmd = argv[0]
10321117
if cmd not in parsers:

onnx_diagnostic/torch_models/code_sample.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -80,13 +80,13 @@ def make_code_for_inputs(inputs: Dict[str, torch.Tensor]) -> str:
8080
elif v.dtype in (torch.float32, torch.float16, torch.bfloat16):
8181
code = f"{k}=torch.rand({shape}, dtype={v.dtype})"
8282
else:
83-
raise ValueError(f"Unexpeted dtype = {v.dtype} for k={k!r}")
83+
raise ValueError(f"Unexpected dtype = {v.dtype} for k={k!r}")
8484
elif v.__class__.__name__ == "DynamicCache":
8585
obj = flatten_object(v)
8686
cc = [f"torch.rand({tuple(map(int,_.shape))}, dtype={_.dtype})" for _ in obj]
8787
va = [f"({a},{b})" for a, b in zip(cc[: len(cc) // 2], cc[len(cc) // 2 :])]
88-
vas = ", ".join(va)
89-
code = f"{k}=make_dynamic_cache([{vas}])"
88+
va2 = ", ".join(va)
89+
code = f"{k}=make_dynamic_cache([{va2}])"
9090
else:
9191
raise ValueError(f"Unexpected type {type(v)} for k={k!r}")
9292
codes.append(code)
@@ -146,7 +146,7 @@ def make_export_code(
146146

147147
imports.append("from onnx_diagnostic.torch_export_patches import torch_export_patches")
148148
if stop_if_static:
149-
patch_kwargs["patch_kwargs"] = stop_if_static
149+
patch_kwargs["stop_if_static"] = stop_if_static
150150
sargs = ", ".join(f"{k}={v}" for k, v in patch_kwargs.items())
151151
code = [f"with torch_export_patches({sargs}):", *[" " + _ for _ in code]]
152152
return "\n".join(imports), "\n".join(code)
@@ -331,7 +331,7 @@ def code_sample(
331331
f"inputs = {input_code}",
332332
exporter_code,
333333
]
334-
code = "\n".join(pieces)
334+
code = "\n".join(pieces) # type: ignore[arg-type]
335335
try:
336336
import black
337337
except ImportError:

onnx_diagnostic/torch_models/validate.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -350,7 +350,6 @@ def validate_model(
350350
output_names: Optional[List[str]] = None,
351351
ort_logs: bool = False,
352352
quiet_input_sets: Optional[Set[str]] = None,
353-
sample_code: Optional[str] = None,
354353
) -> Tuple[Dict[str, Union[int, float, str]], Dict[str, Any]]:
355354
"""
356355
Validates a model.
@@ -407,8 +406,6 @@ def validate_model(
407406
:param ort_logs: increases onnxruntime verbosity when creating the session
408407
:param quiet_input_sets: avoid raising an exception if the inputs belongs to that set
409408
even if quiet is False
410-
:param sample_code: if specified, the function generates a code
411-
which exports this model id without this package.
412409
:return: two dictionaries, one with some metrics,
413410
another one with whatever the function produces
414411

0 commit comments

Comments
 (0)