Skip to content

Commit 039243f

Browse files
committed
fix
1 parent cd4889f commit 039243f

File tree

5 files changed

+105
-15
lines changed

5 files changed

+105
-15
lines changed

_doc/examples/plot_export_tiny_llm_method_generate.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@ def generate_text(
6868
patch_kwargs=dict(patch_transformers=True),
6969
verbose=1,
7070
convert_after_n_calls=3,
71+
skip_kwargs_names={"kwargs", "use_cache", "return_dict"},
7172
)
7273

7374
# %%

_unittests/ut_export/test_api.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,7 @@ def forward(self, x, y):
134134
self.assertExists(filename)
135135
src = method_to_call._method_src
136136
self.assertIn("f(self, x, y):", src)
137-
self.assertIn("return self._call(x=x, y=y)", src)
137+
self.assertIn("return self._method_call(x=x, y=y)", src)
138138
self.assertEqual(len(list(method_to_call.named_modules())), 2)
139139
sess = self.check_ort(filename)
140140
input_names = [i.name for i in sess.get_inputs()]
@@ -163,7 +163,7 @@ def forward(self, x=None, y=None):
163163
self.assertExists(filename)
164164
src = method_to_call._method_src
165165
self.assertIn("f(self, x=None, y=None):", src)
166-
self.assertIn("return self._call(x=x, y=y)", src)
166+
self.assertIn("return self._method_call(x=x, y=y)", src)
167167
self.assertEqual(len(list(method_to_call.named_modules())), 2)
168168
sess = self.check_ort(filename)
169169
input_names = [i.name for i in sess.get_inputs()]
@@ -197,7 +197,7 @@ def forward(self, x=None, y=None):
197197
self.assertExists(filename)
198198
src = method_to_call._method_src
199199
self.assertIn("f(self, x=None, y=None):", src)
200-
self.assertIn("return self._call(x=x, y=y)", src)
200+
self.assertIn("return self._method_call(x=x, y=y)", src)
201201
self.assertEqual(len(list(method_to_call.named_modules())), 2)
202202
sess = self.check_ort(filename)
203203
input_names = [i.name for i in sess.get_inputs()]
@@ -235,7 +235,7 @@ def forward(self, x, y=None):
235235
self.assertExists(filename)
236236
src = method_to_call._method_src
237237
self.assertIn("f(self, x, y=None):", src)
238-
self.assertIn("return self._call(x=x, y=y)", src)
238+
self.assertIn("return self._method_call(x=x, y=y)", src)
239239
self.assertEqual(len(list(method_to_call.named_modules())), 2)
240240
sess = self.check_ort(filename)
241241
input_names = [i.name for i in sess.get_inputs()]

_unittests/ut_export/test_dynamic_shapes.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -937,6 +937,31 @@ def test_invalid_dimensions_for_export(self):
937937
backed_size_oblivious = cpl.invalid_dimensions_for_export()
938938
self.assertFalse(backed_size_oblivious)
939939

940+
def test_guess_dynamic_shapes_missing(self):
941+
class Model(torch.nn.Module):
942+
def forward(self, x, y=None):
943+
if y is None:
944+
return x.abs()
945+
return x.abs() + y
946+
947+
model = Model()
948+
x = torch.randn((5, 6))
949+
y = model(x=x)
950+
self.assertNotEmpty(y)
951+
952+
inputs = [
953+
(tuple(), {"x": x}),
954+
(tuple(), {"x": torch.randn((6, 6)), "y": torch.randn((6, 6))}),
955+
(tuple(), {"x": torch.randn((7, 6)), "y": torch.randn((7, 6))}),
956+
]
957+
958+
mi = ModelInputs(model, inputs)
959+
ds = mi.guess_dynamic_shapes()
960+
DYN = torch.export.Dim.DYNAMIC
961+
self.assertEqual(ds, ((), {"x": {0: DYN}, "y": {0: DYN}}))
962+
_a, _kw, ds = mi.move_to_kwargs(*mi.inputs[-1], ds)
963+
self.assertEqual(ds, (tuple(), {"x": {0: DYN}, "y": {0: DYN}}))
964+
940965

941966
if __name__ == "__main__":
942967
unittest.main(verbosity=2)

onnx_diagnostic/export/api.py

Lines changed: 31 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import inspect
22
import os
33
import textwrap
4-
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
4+
from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple, Union
55
import torch
66
from .dynamic_shapes import ModelInputs
77
from .onnx_plug import EagerDirectReplacementWithOnnx
@@ -340,6 +340,7 @@ def __init__(
340340
inline: bool = True,
341341
convert_after_n_calls: int = 2,
342342
patch_kwargs: Optional[Dict[str, Any]] = None,
343+
skip_kwargs_names: Optional[Set[str]] = None,
343344
):
344345
super().__init__()
345346
self._model_to_call = mod
@@ -354,6 +355,7 @@ def __init__(
354355
self._patch_kwargs = patch_kwargs
355356
self._method_src = None
356357
self.verbose = verbose
358+
self.skip_kwargs_names = skip_kwargs_names
357359
self._to_onnx_kwargs = dict(
358360
input_names=input_names,
359361
target_opset=target_opset,
@@ -370,6 +372,7 @@ def __init__(
370372
onnx_plugs=onnx_plugs,
371373
inline=inline,
372374
)
375+
self._export_done = False
373376

374377
def __str__(self) -> str:
375378
return self.__repr__()
@@ -381,14 +384,28 @@ def __repr__(self) -> str:
381384
)
382385

383386
def forward(self, *args, **kwargs):
384-
self._inputs.append((args, kwargs))
385-
if self.verbose:
386-
print(
387-
f"[method_to_onnx] input[{len(self._inputs)-1}]: "
388-
f"{string_type((args, kwargs), with_shape=True)}"
387+
if not self._export_done:
388+
self._inputs.append(
389+
(
390+
args,
391+
(
392+
kwargs
393+
if not kwargs or not self.skip_kwargs_names
394+
else {
395+
k: v for k, v in kwargs.items() if k not in self.skip_kwargs_names
396+
}
397+
),
398+
)
389399
)
390-
if len(self._inputs) >= self._convert_after_n_calls:
391-
self._convert_method_to_onnx()
400+
if self.verbose:
401+
print(
402+
f"[method_to_onnx] input[{len(self._inputs)-1}]: "
403+
f"{string_type(self._inputs[-1], with_shape=True)}"
404+
)
405+
if len(self._inputs) >= self._convert_after_n_calls:
406+
self._convert_method_to_onnx()
407+
del self._inputs[:]
408+
self._export_done = True
392409
return self._method_call(*args, **kwargs)
393410

394411
def _convert_method_to_onnx(self):
@@ -473,6 +490,7 @@ def method_to_onnx(
473490
inline: bool = True,
474491
convert_after_n_calls: int = 2,
475492
patch_kwargs: Optional[Dict[str, Any]] = None,
493+
skip_kwargs_names: Optional[Set[str]] = None,
476494
) -> Callable:
477495
"""
478496
Exports one method into ONNX for a module into ONNX.
@@ -499,8 +517,12 @@ def method_to_onnx(
499517
:param inline: inline local functions
500518
:param convert_after_n_calls: converts the model after this number of calls.
501519
:param patch_kwargs: patch arguments
520+
:param skip_kwargs_names: use default values for these parameters part of
521+
the signature of the method to export
502522
:return: the output of the selected exporter, usually a structure including
503523
an onnx model
524+
525+
See :ref:`l-plot-tiny-llm-export-method-generate` for an example.
504526
"""
505527
wrapped_model = _WrapperToExportMethodToOnnx(
506528
mod=mod,
@@ -521,5 +543,6 @@ def method_to_onnx(
521543
inline=inline,
522544
convert_after_n_calls=convert_after_n_calls,
523545
patch_kwargs=patch_kwargs,
546+
skip_kwargs_names=skip_kwargs_names,
524547
)
525548
return wrapped_model

onnx_diagnostic/export/dynamic_shapes.py

Lines changed: 44 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -352,6 +352,19 @@ def _generic_walker_step(
352352
else None
353353
)
354354
assert type(inputs) is dict, f"Unexpected type for inputs {type(inputs)}"
355+
if set(inputs) != set(ds):
356+
not_in_ds = {k for k in inputs if k not in ds}
357+
not_in_inputs = {k for k in ds if k not in inputs}
358+
assert not_in_inputs == {"kwargs"} and set(ds["kwargs"]) == not_in_ds, (
359+
f"Keys mismatch between inputs {set(inputs)} and ds={set(ds)}, "
360+
f"inputs={string_type(inputs, with_shape=True)}, ds={ds}, "
361+
f"not_in_ds={not_in_ds}, not_in_inputs={not_in_inputs}"
362+
)
363+
# Tweak...
364+
kws = ds["kwargs"]
365+
del ds["kwargs"]
366+
ds.update(kws)
367+
355368
assert set(inputs) == set(ds), (
356369
f"Keys mismatch between inputs {set(inputs)} and ds={set(ds)}, "
357370
f"inputs={string_type(inputs, with_shape=True)}, ds={ds}"
@@ -366,13 +379,15 @@ def _generic_walker_step(
366379
return dvalue if dvalue else None
367380

368381
# A custom class.
369-
assert inputs.__class__ in torch.utils._pytree.SUPPORTED_NODES, (
382+
assert inputs is None or inputs.__class__ in torch.utils._pytree.SUPPORTED_NODES, (
370383
f"Class {inputs.__class__.__name__!r} was not registered using "
371384
f"torch.utils._pytree.register_pytree_node, it is not possible to "
372385
f"map this class with the given dynamic shapes."
373386
)
374387
if flatten_unflatten:
375388
flatunflat = flatten_unflatten_for_dynamic_shapes(inputs)
389+
if isinstance(flatunflat, (list, tuple, dict)) and len(flatunflat) == 0:
390+
return flatunflat
376391
res = cls._generic_walker_step(
377392
processor, flatunflat, ds, flatten_unflatten=flatten_unflatten
378393
)
@@ -667,6 +682,9 @@ def __init__(
667682
if self.signature
668683
else None
669684
)
685+
self.forward_parameters_kinds = {
686+
p.name: p.kind for p in self.signature.parameters.values()
687+
}
670688
self.forward_ordered_parameter_names = (
671689
list(self.signature.parameters) if self.signature else None
672690
)
@@ -973,7 +991,13 @@ def guess_dynamic_shapes(self, auto: Union[bool, str] = False) -> DYNAMIC_SHAPES
973991
len(s1) == 1
974992
), f"Different numbers of positional arguments {s1} for {self.full_name}"
975993
s2 = set(tuple(sorted(set(i[1]))) for i in self.inputs)
976-
assert len(s2) == 1, f"Different named arguments {s2} for {self.full_name}"
994+
assert len(s2) > 0, f"empty {s2} for {self.full_name}"
995+
if len(s2) > 1:
996+
# We need to keep the largest set of inputs, the one including all the others.
997+
sum_s2 = set()
998+
for s in s2:
999+
sum_s2 |= set(s)
1000+
s2 = {tuple(sum_s2)}
9771001
args = []
9781002
kwargs = {}
9791003
for i in range(s1.pop()):
@@ -993,7 +1017,7 @@ def guess_dynamic_shapes(self, auto: Union[bool, str] = False) -> DYNAMIC_SHAPES
9931017
f"\ninputs[1]={string_type(self.inputs[1], with_shape=True)}"
9941018
)
9951019

996-
objs = [_[1][name] for _ in self.inputs]
1020+
objs = [_[1][name] for _ in self.inputs if name in _[1]]
9971021
kwargs[name] = self.guess_dynamic_shape_object(
9981022
*objs,
9991023
auto=auto if isinstance(auto, bool) else f"{auto}_{i}I",
@@ -1049,6 +1073,23 @@ def move_to_kwargs(
10491073
_kw_dyn = kw_dyn
10501074
kw_dyn = {}
10511075
for name in self.forward_ordered_parameter_names:
1076+
if (
1077+
self.forward_parameters_kinds[name] == inspect.Parameter.VAR_KEYWORD
1078+
and name not in _kwargs
1079+
and name in _kw_dyn
1080+
):
1081+
f = _kw_dyn[name]
1082+
assert isinstance(
1083+
f, dict
1084+
), f"Unexpected type for name={name!r}, _kw_dyn={_kw_dyn}"
1085+
for _k, _w in f.items():
1086+
assert (
1087+
_k in _kwargs
1088+
), f"Parameter {_k!r} not in found in kwargs: {set(_kwargs)}"
1089+
kwargs[_k] = _kwargs[_k]
1090+
kw_dyn[_k] = f[_k]
1091+
continue
1092+
10521093
if name in _kwargs:
10531094
kwargs[name] = _kwargs[name]
10541095
if name in _kw_dyn:

0 commit comments

Comments
 (0)