Skip to content

Commit ca25dc6

Browse files
authored
Adds code to generate automated export sample for a model id (#270)
* add function to generate dummy code * changes * fix dtype * fix dynamo * fix issues * fixes * fix examples * status
1 parent 283b2cd commit ca25dc6

File tree

7 files changed

+581
-27
lines changed

7 files changed

+581
-27
lines changed

CHANGELOGS.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ Change Logs
44
0.7.16
55
++++++
66

7+
* :pr:`270`: add export sample code to export a specific model id with the appropriate inputs
78
* :pr:`269`: adds one unit test to track a patch fixing broadcast output shape
89
* :pr:`267`: patches ``sdpa_attention_forward`` because of a control flow (``transformers>=5.0``)
910
* :pr:`266`: makes ``patch_torch`` an integer in ``torch_export_patches`` to enable more patches
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
2+
onnx_diagnostic.torch_models.code_sample
3+
========================================
4+
5+
.. automodule:: onnx_diagnostic.torch_models.code_sample
6+
:members:
7+
:no-undoc-members:

_doc/api/torch_models/index.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ onnx_diagnostic.torch_models
55
:maxdepth: 1
66
:caption: submodules
77

8+
code_sample
89
hghub/index
910
llms
1011
validate
Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
import unittest
2+
import subprocess
3+
import sys
4+
import torch
5+
from onnx_diagnostic.ext_test_case import (
6+
ExtTestCase,
7+
hide_stdout,
8+
requires_torch,
9+
requires_experimental,
10+
requires_transformers,
11+
)
12+
from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache
13+
from onnx_diagnostic.torch_models.code_sample import code_sample, make_code_for_inputs
14+
15+
16+
class TestCodeSample(ExtTestCase):
    """Checks the code generated by ``code_sample`` and ``make_code_for_inputs``."""

    def _write_and_run(self, code, script_name):
        """Dumps *code* into a script, runs it in a subprocess
        and checks that no traceback shows up on its standard error."""
        script = self.get_dump_file(script_name)
        with open(script, "w") as stream:
            stream.write(code)
        completed = subprocess.run(
            [sys.executable, "-u", script],
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
        )
        stderr_text = completed.stderr.decode("ascii", errors="ignore")
        self.assertNotIn("Traceback", stderr_text)

    @requires_transformers("4.53")
    @requires_torch("2.9")
    @requires_experimental()
    @hide_stdout()
    def test_code_sample_tiny_llm_custom(self):
        # Export sample built for the custom exporter.
        options = dict(
            verbose=2,
            exporter="custom",
            patch=True,
            dump_folder="dump_test/validate_tiny_llm_custom",
            dtype="float16",
            device="cpu",
            optimization="default",
        )
        generated = code_sample("arnir0/Tiny-LLM", **options)
        self._write_and_run(generated, "test_code_sample_tiny_llm_custom.py")

    @requires_transformers("4.53")
    @requires_torch("2.9")
    @requires_experimental()
    @hide_stdout()
    def test_code_sample_tiny_llm_dynamo(self):
        # Export sample built for the onnx-dynamo exporter.
        options = dict(
            verbose=2,
            exporter="onnx-dynamo",
            patch=True,
            dump_folder="dump_test/validate_tiny_llm_dynamo",
            dtype="float16",
            device="cpu",
            optimization="ir",
        )
        generated = code_sample("arnir0/Tiny-LLM", **options)
        self._write_and_run(generated, "test_code_sample_tiny_llm_dynamo.py")

    def test_make_code_for_inputs(self):
        # Pairs (expected generated code, inputs given to make_code_for_inputs).
        int_tensor = torch.tensor([2, 3], dtype=torch.int64)
        half_tensor = torch.tensor([2, 3], dtype=torch.float16)
        cases = [
            ("dict(a=True)", {"a": True}),
            ("dict(a=1)", {"a": 1}),
            ("dict(a=torch.randint(3, size=(2,), dtype=torch.int64))", {"a": int_tensor}),
            ("dict(a=torch.rand((2,), dtype=torch.float16))", {"a": half_tensor}),
        ]
        for expected, inputs in cases:
            self.assertEqual(expected, make_code_for_inputs(inputs))

        # A dynamic cache holding two (key, value) pairs of float32 tensors.
        layers = []
        for _ in range(2):
            layers.append((torch.randn(2, 2, 2, 2), torch.randn(2, 2, 2, 2)))
        cache_code = make_code_for_inputs({"cc": make_dynamic_cache(layers)})
        self.assertEqual(
            "dict(cc=make_dynamic_cache([(torch.rand((2, 2, 2, 2), "
            "dtype=torch.float32),torch.rand((2, 2, 2, 2), dtype=torch.float32)), "
            "(torch.rand((2, 2, 2, 2), dtype=torch.float32),"
            "torch.rand((2, 2, 2, 2), dtype=torch.float32))]))",
            cache_code,
        )
97+
98+
99+
if __name__ == "__main__":
100+
unittest.main(verbosity=2)

onnx_diagnostic/_command_lines_parser.py

Lines changed: 101 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -557,7 +557,7 @@ def get_parser_validate() -> ArgumentParser:
557557
"--quiet-input-sets",
558558
default="",
559559
help="Avoids raising an exception when an input sets does not work with "
560-
"the exported model, example: --quiet-input-sets=inputs,inputs22",
560+
"the exported model.\nExample: --quiet-input-sets=inputs,inputs22",
561561
)
562562
return parser
563563

@@ -631,6 +631,94 @@ def _cmd_validate(argv: List[Any]):
631631
print(f":{k},{v};")
632632

633633

634+
def _cmd_export_sample(argv: List[Any]):
    """
    Implements command ``exportsample``.

    The arguments are parsed with the parser returned by
    :func:`get_parser_validate`. Three cases are handled:

    * neither a task nor a model id is given: the supported tasks are printed,
    * only a task is given: the inputs and dynamic shapes returned by
      ``get_inputs_for_task`` for that task are printed,
    * a model id is given: ``code_sample`` produces the export code, which is
      written into ``args.dump_folder`` when that option is set and
      printed on standard output otherwise.

    :param argv: command line, ``argv[0]`` is the command name and is skipped
    """
    from .helpers import string_type
    from .torch_models.validate import get_inputs_for_task, _make_folder_name
    from .torch_models.code_sample import code_sample
    from .tasks import supported_tasks

    parser = get_parser_validate()
    args = parser.parse_args(argv[1:])
    if not args.task and not args.mid:
        print("-- list of supported tasks:")
        print("\n".join(supported_tasks()))
    elif not args.mid:
        data = get_inputs_for_task(args.task)
        if args.verbose:
            print(f"task: {args.task}")
        max_length = max(len(k) for k in data["inputs"]) + 1
        print("-- inputs")
        for k, v in data["inputs"].items():
            print(f" + {k.ljust(max_length)}: {string_type(v, with_shape=True)}")
        print("-- dynamic_shapes")
        for k, v in data["dynamic_shapes"].items():
            print(f" + {k.ljust(max_length)}: {string_type(v)}")
    else:
        # Let's skip any invalid combination if known to be unsupported
        if (
            "onnx" not in (args.export or "")
            and "custom" not in (args.export or "")
            and (args.opt or "")
        ):
            print(f"code-sample - unsupported args: export={args.export!r}, opt={args.opt!r}")
            return
        patch_dict = args.patch if isinstance(args.patch, dict) else {"patch": args.patch}
        # Computed once, it is needed to generate the code and to name the file.
        drop_inputs = None if not args.drop else args.drop.split(",")
        code = code_sample(
            model_id=args.mid,
            task=args.task,
            do_run=args.run,
            verbose=args.verbose,
            quiet=args.quiet,
            same_as_pretrained=args.same_as_trained,
            use_pretrained=args.trained,
            dtype=args.dtype,
            device=args.device,
            patch=patch_dict,
            rewrite=args.rewrite and patch_dict.get("patch", True),
            stop_if_static=args.stop_if_static,
            optimization=args.opt,
            exporter=args.export,
            dump_folder=args.dump_folder,
            drop_inputs=drop_inputs,
            input_options=args.iop,
            model_options=args.mop,
            subfolder=args.subfolder,
            opset=args.opset,
            runtime=args.runtime,
            output_names=(
                None if len(args.outnames.strip()) < 2 else args.outnames.strip().split(",")
            ),
        )
        if args.dump_folder:
            os.makedirs(args.dump_folder, exist_ok=True)
            # The namespace attributes must be the ones used to call code_sample
            # (mid, export, opt, same_as_trained, trained): the parsed namespace
            # is never read elsewhere in this function under the keyword names
            # model_id, exporter, optimization, same_as_pretrained, use_pretrained.
            name = (
                _make_folder_name(
                    model_id=args.mid,
                    exporter=args.export,
                    optimization=args.opt,
                    dtype=args.dtype,
                    device=args.device,
                    subfolder=args.subfolder,
                    opset=args.opset,
                    drop_inputs=drop_inputs,
                    same_as_pretrained=args.same_as_trained,
                    use_pretrained=args.trained,
                    task=args.task,
                ).replace("/", "-")
                + ".py"
            )
            fullname = os.path.join(args.dump_folder, name)
            if args.verbose:
                print(f"-- prints code in {fullname!r}")
                print("--")
            # Python reads source files as UTF-8 by default, the script is
            # written with the same encoding whatever the platform is.
            with open(fullname, "w", encoding="utf-8") as f:
                f.write(code)
            if args.verbose:
                print("-- done")
        else:
            print(code)
721+
634722
def get_parser_stats() -> ArgumentParser:
635723
parser = ArgumentParser(
636724
prog="stats",
@@ -960,14 +1048,15 @@ def get_main_parser() -> ArgumentParser:
9601048
Type 'python -m onnx_diagnostic <cmd> --help'
9611049
to get help for a specific command.
9621050
963-
agg - aggregates statistics from multiple files
964-
config - prints a configuration for a model id
965-
find - find node consuming or producing a result
966-
lighten - makes an onnx model lighter by removing the weights,
967-
print - prints the model on standard output
968-
stats - produces statistics on a model
969-
unlighten - restores an onnx model produces by the previous experiment
970-
validate - validate a model
1051+
agg - aggregates statistics from multiple files
1052+
config - prints a configuration for a model id
1053+
exportsample - produces a code to export a model
1054+
find - find node consuming or producing a result
1055+
lighten - makes an onnx model lighter by removing the weights,
1056+
print - prints the model on standard output
1057+
stats - produces statistics on a model
1058+
unlighten - restores an onnx model produces by the previous experiment
1059+
validate - validate a model
9711060
"""
9721061
),
9731062
)
@@ -976,6 +1065,7 @@ def get_main_parser() -> ArgumentParser:
9761065
choices=[
9771066
"agg",
9781067
"config",
1068+
"exportsample",
9791069
"find",
9801070
"lighten",
9811071
"print",
@@ -998,6 +1088,7 @@ def main(argv: Optional[List[Any]] = None):
9981088
validate=_cmd_validate,
9991089
stats=_cmd_stats,
10001090
agg=_cmd_agg,
1091+
exportsample=_cmd_export_sample,
10011092
)
10021093

10031094
if argv is None:
@@ -1020,6 +1111,7 @@ def main(argv: Optional[List[Any]] = None):
10201111
validate=get_parser_validate,
10211112
stats=get_parser_stats,
10221113
agg=get_parser_agg,
1114+
exportsample=get_parser_validate,
10231115
)
10241116
cmd = argv[0]
10251117
if cmd not in parsers:

0 commit comments

Comments
 (0)