Skip to content

Commit 46b9337

Browse files
committed
fix modelbuilder inputs
1 parent 6b61f1c commit 46b9337

File tree

7 files changed

+77
-6
lines changed

7 files changed

+77
-6
lines changed

_doc/api/helpers/index.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ onnx_diagnostic.helpers
1414
helper
1515
memory_peak
1616
mini_onnx_builder
17+
model_builder_helper
1718
onnx_helper
1819
ort_session
1920
rt_helper
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
2+
onnx_diagnostic.helpers.model_builder_helper
3+
============================================
4+
5+
.. automodule:: onnx_diagnostic.helpers.model_builder_helper
6+
:members:
7+
:no-undoc-members:

_doc/conf.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -211,6 +211,7 @@ def linkcode_resolve(domain, info):
211211
"huggingface_hub": "https://github.com/huggingface/huggingface_hub",
212212
"Linux": "https://www.linux.org/",
213213
"ml_dtypes": "https://github.com/jax-ml/ml_dtypes",
214+
"ModelBuilder": "https://onnxruntime.ai/docs/genai/howto/build-model.html",
214215
"monai": "https://monai.io/",
215216
"numpy": "https://numpy.org/",
216217
"onnx": "https://onnx.ai/onnx/",

_unittests/ut_export/test_dynamic_shapes.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache
77
from onnx_diagnostic.export import ModelInputs, CoupleInputsDynamicShapes
88
from onnx_diagnostic.torch_export_patches import torch_export_patches
9+
from onnx_diagnostic.torch_models.hghub.model_inputs import get_untrained_model_with_inputs
910

1011

1112
class TestDynamicShapes(ExtTestCase):
@@ -512,7 +513,6 @@ def forward(self, cache, z):
512513
mi = ModelInputs(Model(), inputs)
513514
self.assertIn("DynamicCache", string_type(mi.inputs, with_shape=True))
514515
ds = mi.guess_dynamic_shapes(auto="dim")
515-
print(ds)
516516
self.assertEqual(
517517
ds,
518518
(
@@ -845,6 +845,22 @@ def test_dynamic_cache_replace_by_string(self):
845845
as_string,
846846
)
847847

848+
def test_unbatch_inputs(self):
849+
data = get_untrained_model_with_inputs("arnir0/Tiny-LLM")
850+
cpl = CoupleInputsDynamicShapes(
851+
None, data["inputs"], dynamic_shapes=data["dynamic_shapes"]
852+
)
853+
new_dims = cpl.change_dynamic_dimensions(
854+
desired_values=dict(batch=1), only_desired=True
855+
)
856+
s = self.string_type(new_dims, with_shape=True)
857+
self.assertEqual(
858+
"dict(input_ids:T7s1x3,attention_mask:T7s1x33,position_ids:T7s1x3,"
859+
"past_key_values:DynamicCache("
860+
"key_cache=#1[T1s1x1x30x96], value_cache=#1[T1s1x1x30x96]))",
861+
s,
862+
)
863+
848864

849865
if __name__ == "__main__":
850866
unittest.main(verbosity=2)

_unittests/ut_torch_models/test_test_helpers.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,24 @@ def test_filter_inputs(self):
184184
ni, nd = filter_inputs(inputs, dynamic_shapes=ds, drop_names=["a"], model=["a", "b"])
185185
self.assertEqual((ni, nd), (((None,), {"b": 4}), {"b": 30}))
186186

187+
@requires_torch("2.7")
188+
@hide_stdout()
189+
@ignore_warnings(FutureWarning)
190+
def test_validate_model_modelbuilder(self):
191+
mid = "arnir0/Tiny-LLM"
192+
summary, data = validate_model(
193+
mid,
194+
do_run=True,
195+
verbose=10,
196+
exporter="modelbuilder",
197+
dump_folder="dump_test_validate_model_onnx_dynamo",
198+
)
199+
self.assertIsInstance(summary, dict)
200+
self.assertIsInstance(data, dict)
201+
self.assertLess(summary["disc_onnx_ort_run_abs"], 1e-4)
202+
onnx_filename = data["onnx_filename"]
203+
self.assertExists(onnx_filename)
204+
187205

188206
if __name__ == "__main__":
189207
unittest.main(verbosity=2)

onnx_diagnostic/export/dynamic_shapes.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -379,8 +379,9 @@ def _generic_walker_step(
379379
return torch.utils._pytree.tree_unflatten(res, spec)
380380

381381
class ChangeDimensionProcessor:
382-
def __init__(self, desired_values):
382+
def __init__(self, desired_values, only_desired):
383383
self.mapping = desired_values or {}
384+
self.only_desired = only_desired
384385

385386
def _build_new_shape(
386387
self, shape: Tuple[int, ...], ds: Dict[int, Any]
@@ -397,14 +398,16 @@ def _build_new_shape(
397398
torch.export.dynamic_shapes._Dim,
398399
),
399400
):
400-
d = str(ds[i])
401+
d = ds[i].__name__
401402
elif not isinstance(ds[i], int):
402403
raise NotImplementedError(f"Unable to handle type {ds[i]} in {ds}")
403404
if d in self.mapping:
404405
new_dim = self.mapping[d]
405-
else:
406+
elif not self.only_desired:
406407
new_dim = shape[i] + 1
407408
self.mapping[d] = new_dim
409+
else:
410+
new_dim = shape[i]
408411
new_shape[i] = new_dim
409412
return tuple(new_shape)
410413

@@ -447,7 +450,10 @@ def __call__(self, inputs, ds):
447450
return self._build_new_tensor(inputs, new_shape)
448451

449452
def change_dynamic_dimensions(
450-
self, desired_values: Optional[Dict[str, int]] = None, args_kwargs: bool = False
453+
self,
454+
desired_values: Optional[Dict[str, int]] = None,
455+
args_kwargs: bool = False,
456+
only_desired: bool = False,
451457
):
452458
"""
453459
A model exported with dynamic shapes is not necessarily dynamic
@@ -460,6 +466,8 @@ def change_dynamic_dimensions(
460466
461467
:param desired_values: to fixed named dimension to have the desired value
462468
:param args_kwargs: return both args, kwargs even if empty
469+
:param only_desired: if True, only change the dimension specified in
470+
``desired_values``
463471
:return: new inputs
464472
465473
Example:
@@ -483,7 +491,8 @@ def change_dynamic_dimensions(
483491
print("-after:", string_type(new_kwargs, with_shape=True))
484492
"""
485493
return self._generic_walker(
486-
self.ChangeDimensionProcessor(desired_values), args_kwargs=args_kwargs
494+
self.ChangeDimensionProcessor(desired_values, only_desired=only_desired),
495+
args_kwargs=args_kwargs,
487496
)
488497

489498

onnx_diagnostic/torch_models/test_helper.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -345,6 +345,25 @@ def validate_model(
345345
)
346346
),
347347
)
348+
349+
if exporter == "modelbuilder":
350+
# Models used with ModelBuilder do not like batch size > 1.
351+
# Let's change that.
352+
for k in ["inputs", "inputs2"]:
353+
if k not in data:
354+
continue
355+
if verbose:
356+
print(f"[validate_model] set batch=1 for data[{k!r}]")
357+
print(f"[validate_model] batch=1 === {string_type(data[k], with_shape=True)}")
358+
cpl = CoupleInputsDynamicShapes(
359+
None, data[k], dynamic_shapes=data["dynamic_shapes"]
360+
)
361+
data[k] = cpl.change_dynamic_dimensions(
362+
desired_values=dict(batch=1), only_desired=True
363+
)
364+
if verbose:
365+
print(f"[validate_model] batch=1 --> {string_type(data[k], with_shape=True)}")
366+
348367
data["input_options"] = iop
349368
data["model_options"] = mop
350369
data["model_dump_folder"] = dump_folder

0 commit comments

Comments
 (0)