Skip to content

Commit ac301d1

Browse files
authored
Supports ModelBuilder (#111)
* Supports ModelBuilder * pymp * t * mb * mypy * fix modelbuilder inputs * fix issues * test' * page * publish
1 parent 66f77c2 commit ac301d1

File tree

16 files changed

+685
-21
lines changed

16 files changed

+685
-21
lines changed

.github/workflows/python-publish.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,9 @@ jobs:
2121
runs-on: ubuntu-latest
2222

2323
steps:
24-
- uses: actions/checkout@v3
24+
- uses: actions/checkout@v4
2525
- name: Set up Python
26-
uses: actions/setup-python@v3
26+
uses: actions/setup-python@v5
2727
with:
2828
python-version: '3.x'
2929

CHANGELOGS.rst

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,9 @@ Change Logs
44
0.6.0
55
+++++
66

7-
* :pr:`108`: first version of an algorithm rendering small onnx graph in ascii,
8-
patch for ``torch.vmap``
7+
* :pr:`111`: support ModelBuilder with command line validatz
8+
* :pr:`108`, :pr:`109`, :pr:`110`: first version of an algorithm rendering
9+
small onnx graph in ascii, patch for ``torch.vmap``
910

1011
0.5.0
1112
+++++

_doc/api/helpers/index.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ onnx_diagnostic.helpers
1414
helper
1515
memory_peak
1616
mini_onnx_builder
17+
model_builder_helper
1718
onnx_helper
1819
ort_session
1920
rt_helper
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
2+
onnx_diagnostic.helpers.model_builder_helper
3+
============================================
4+
5+
.. automodule:: onnx_diagnostic.helpers.model_builder_helper
6+
:members:
7+
:no-undoc-members:

_doc/conf.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,7 @@ def linkcode_resolve(domain, info):
116116
("py:class", "Argument"),
117117
("py:class", "default=sklearn.utils.metadata_routing.UNCHANGED"),
118118
("py:class", "ModelProto"),
119+
("py:class", "Model"),
119120
("py:class", "Module"),
120121
("py:class", "np.ndarray"),
121122
("py:class", "onnxscript.ir.Tuple"),
@@ -211,6 +212,7 @@ def linkcode_resolve(domain, info):
211212
"huggingface_hub": "https://github.com/huggingface/huggingface_hub",
212213
"Linux": "https://www.linux.org/",
213214
"ml_dtypes": "https://github.com/jax-ml/ml_dtypes",
215+
"ModelBuilder": "https://onnxruntime.ai/docs/genai/howto/build-model.html",
214216
"monai": "https://monai.io/",
215217
"numpy": "https://numpy.org/",
216218
"onnx": "https://onnx.ai/onnx/",

_doc/status/exported_program_dynamic.rst

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,11 @@ to the original model.
4343
print()
4444
print("::")
4545
print()
46-
print(textwrap.indent(textwrap.dedent(inspect.getsource(cls_model.forward)), " "))
46+
src = inspect.getsource(cls_model.forward)
47+
if src:
48+
print(textwrap.indent(textwrap.dedent(src), " "))
49+
else:
50+
print(" # code is missing")
4751
print()
4852
for exporter in (
4953
"export-strict",
@@ -74,7 +78,11 @@ to the original model.
7478
print()
7579
print("::")
7680
print()
77-
print(textwrap.indent(str(res["error"]), " "))
81+
err = str(res["error"])
82+
if err:
83+
print(textwrap.indent(err, " "))
84+
else:
85+
print(" # no error found for the failure")
7886
print()
7987
obs.append(dict(case=case_ref, error="FAIL", exporter=expo))
8088

_doc/status/patches_coverage.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ for transformers.
2828
import onnx_diagnostic.torch_export_patches.patches.patch_transformers as p
2929

3030
for name, cls in p.__dict__.items():
31-
if name.startswith("patched_"):
31+
if name.startswith("patched_") and hasattr(cls, "_PATCHES_"):
3232
print(f"{cls._PATCHED_CLASS_.__name__}: {', '.join(cls._PATCHES_)}")
3333

3434
Half Automated Rewrites for Control Flows

_unittests/ut_export/test_dynamic_shapes.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache
77
from onnx_diagnostic.export import ModelInputs, CoupleInputsDynamicShapes
88
from onnx_diagnostic.torch_export_patches import torch_export_patches
9+
from onnx_diagnostic.torch_models.hghub.model_inputs import get_untrained_model_with_inputs
910

1011

1112
class TestDynamicShapes(ExtTestCase):
@@ -512,7 +513,6 @@ def forward(self, cache, z):
512513
mi = ModelInputs(Model(), inputs)
513514
self.assertIn("DynamicCache", string_type(mi.inputs, with_shape=True))
514515
ds = mi.guess_dynamic_shapes(auto="dim")
515-
print(ds)
516516
self.assertEqual(
517517
ds,
518518
(
@@ -845,6 +845,23 @@ def test_dynamic_cache_replace_by_string(self):
845845
as_string,
846846
)
847847

848+
@requires_transformers("4.51")
849+
def test_unbatch_inputs(self):
850+
data = get_untrained_model_with_inputs("arnir0/Tiny-LLM")
851+
cpl = CoupleInputsDynamicShapes(
852+
None, data["inputs"], dynamic_shapes=data["dynamic_shapes"]
853+
)
854+
new_dims = cpl.change_dynamic_dimensions(
855+
desired_values=dict(batch=1), only_desired=True
856+
)
857+
s = self.string_type(new_dims, with_shape=True)
858+
self.assertEqual(
859+
"dict(input_ids:T7s1x3,attention_mask:T7s1x33,position_ids:T7s1x3,"
860+
"past_key_values:DynamicCache("
861+
"key_cache=#1[T1s1x1x30x96], value_cache=#1[T1s1x1x30x96]))",
862+
s,
863+
)
864+
848865

849866
if __name__ == "__main__":
850867
unittest.main(verbosity=2)
Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
import os
2+
import unittest
3+
from onnx_diagnostic.ext_test_case import (
4+
ExtTestCase,
5+
requires_torch,
6+
requires_transformers,
7+
hide_stdout,
8+
)
9+
from onnx_diagnostic.helpers.model_builder_helper import (
10+
download_model_builder_to_cache,
11+
import_model_builder,
12+
create_model_builder,
13+
save_model_builder,
14+
)
15+
from onnx_diagnostic.torch_models.hghub import (
16+
get_untrained_model_with_inputs,
17+
)
18+
from onnx_diagnostic.helpers.rt_helper import make_feeds
19+
20+
21+
class TestModelBuilderHelper(ExtTestCase):
22+
# This is to limit impact on CI.
23+
@requires_transformers("4.52")
24+
@requires_torch("2.7.99")
25+
def test_download_model_builder(self):
26+
path = download_model_builder_to_cache()
27+
self.assertExists(path)
28+
builder = import_model_builder()
29+
self.assertHasAttr(builder, "create_model")
30+
31+
# This is to limit impact on CI.
32+
@requires_transformers("4.52")
33+
@requires_torch("2.7.99")
34+
@hide_stdout()
35+
def test_model_builder_id(self):
36+
# clear&&python ~/.cache/onnx-diagnostic/builder.py
37+
# --model arnir0/Tiny-LLM -p fp16 -c dump_cache -e cpu -o dump_model
38+
folder = self.get_dump_folder("test_model_builder_id")
39+
data = get_untrained_model_with_inputs("arnir0/Tiny-LLM")
40+
onnx_model = create_model_builder(
41+
data["configuration"],
42+
data["model"],
43+
precision="fp32",
44+
execution_provider="cpu",
45+
cache_dir=folder,
46+
verbose=1,
47+
)
48+
self.assertGreater(len(onnx_model.nodes), 5)
49+
50+
proto = save_model_builder(onnx_model, verbose=1)
51+
import onnxruntime
52+
53+
onnxruntime.InferenceSession(
54+
proto.SerializeToString(), providers=["CPUExecutionProvider"]
55+
)
56+
57+
# We need to start again.
58+
onnx_model = create_model_builder(
59+
data["configuration"],
60+
data["model"],
61+
precision="fp32",
62+
execution_provider="cpu",
63+
cache_dir=folder,
64+
verbose=1,
65+
)
66+
save_model_builder(onnx_model, folder, verbose=1)
67+
model_name = os.path.join(folder, "model.onnx")
68+
self.assertExists(model_name)
69+
70+
feeds = make_feeds(proto, data["inputs"], use_numpy=True)
71+
expected = data["model"](**data["inputs"])
72+
73+
sess = onnxruntime.InferenceSession(model_name, providers=["CPUExecutionProvider"])
74+
try:
75+
got = sess.run(None, feeds)
76+
except onnxruntime.capi.onnxruntime_pybind11_state.InvalidArgument as e:
77+
if "batch_size must be 1 when sequence_length > 1" in str(e):
78+
raise unittest.SkipTest("batch_size must be 1 when sequence_length > 1")
79+
self.assertEqualAny(expected, got)
80+
81+
82+
if __name__ == "__main__":
83+
unittest.main(verbosity=2)

_unittests/ut_torch_models/test_test_helpers.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
requires_torch,
1010
requires_experimental,
1111
requires_onnxscript,
12+
requires_transformers,
1213
)
1314
from onnx_diagnostic.torch_models.test_helper import (
1415
get_inputs_for_task,
@@ -184,6 +185,25 @@ def test_filter_inputs(self):
184185
ni, nd = filter_inputs(inputs, dynamic_shapes=ds, drop_names=["a"], model=["a", "b"])
185186
self.assertEqual((ni, nd), (((None,), {"b": 4}), {"b": 30}))
186187

188+
@requires_torch("2.7")
189+
@hide_stdout()
190+
@ignore_warnings(FutureWarning)
191+
@requires_transformers("4.51")
192+
def test_validate_model_modelbuilder(self):
193+
mid = "arnir0/Tiny-LLM"
194+
summary, data = validate_model(
195+
mid,
196+
do_run=True,
197+
verbose=10,
198+
exporter="modelbuilder",
199+
dump_folder="dump_test_validate_model_onnx_dynamo",
200+
)
201+
self.assertIsInstance(summary, dict)
202+
self.assertIsInstance(data, dict)
203+
self.assertLess(summary["disc_onnx_ort_run_abs"], 1e-4)
204+
onnx_filename = data["onnx_filename"]
205+
self.assertExists(onnx_filename)
206+
187207

188208
if __name__ == "__main__":
189209
unittest.main(verbosity=2)

0 commit comments

Comments
 (0)