Skip to content

Commit 675a01e

Browse files
committed
export a method to onnx
1 parent 598d5ea commit 675a01e

File tree

3 files changed

+340
-3
lines changed

3 files changed

+340
-3
lines changed

_unittests/ut_export/test_api.py

Lines changed: 132 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,15 @@
77
has_transformers,
88
ignore_warnings,
99
requires_transformers,
10+
requires_experimental_experiment,
1011
)
1112
from onnx_diagnostic.helpers import max_diff
1213
from onnx_diagnostic.helpers.torch_helper import torch_deepcopy
1314
from onnx_diagnostic.helpers.rt_helper import make_feeds
1415
from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache
1516
from onnx_diagnostic.torch_models.hghub import get_untrained_model_with_inputs
1617
from onnx_diagnostic.torch_export_patches import torch_export_patches
17-
from onnx_diagnostic.export.api import to_onnx
18+
from onnx_diagnostic.export.api import to_onnx, method_to_onnx
1819

1920

2021
class TestValidate(ExtTestCase):
@@ -114,6 +115,136 @@ def test_tiny_llm_to_onnx(self):
114115

115116
self.clean_dump()
116117

118+
@requires_experimental_experiment("0.1")
119+
def test_method_to_onnx_args(self):
    """Exports ``forward`` when it is only called with positional tensors.

    The wrapper is called twice with different shapes, the saved model is
    then run through the session returned by ``check_ort`` and compared
    with the eager outputs.
    """

    class Model(torch.nn.Module):
        def forward(self, x, y):
            return x + y

    onnx_path = self.get_dump_file("test_method_to_onnx_args.onnx")
    samples = [
        (torch.randn((5, 6)), torch.randn((1, 6))),
        (torch.randn((7, 7)), torch.randn((1, 7))),
    ]
    method_to_call = method_to_onnx(Model(), exporter="custom", filename=onnx_path)
    # The default number of calls before conversion is two: the second
    # call is the one writing the file.
    expecteds = [method_to_call(*sample) for sample in samples]
    self.assertExists(onnx_path)

    # The generated wrapper must reproduce the exact signature of forward.
    generated = method_to_call._method_src
    for fragment in ("f(self, x, y):", "return self._call(x=x, y=y)"):
        self.assertIn(fragment, generated)
    self.assertEqual(len(list(method_to_call.named_modules())), 2)

    sess = self.check_ort(onnx_path)
    input_names = [i.name for i in sess.get_inputs()]
    for expected, sample in zip(expecteds, samples):
        got = sess.run(None, make_feeds(input_names, sample, use_numpy=True))
        self.assertEqualArray(expected, got[0])
    self.clean_dump()
146+
147+
@requires_experimental_experiment("0.1")
148+
def test_method_to_onnx_kwargs(self):
    """Exports ``forward`` when it is only called with named tensors.

    The wrapper is called twice with different shapes, the saved model is
    then run through the session returned by ``check_ort`` and compared
    with the eager outputs.
    """

    class Model(torch.nn.Module):
        def forward(self, x=None, y=None):
            return x + y

    onnx_path = self.get_dump_file("test_method_to_onnx_kwargs.onnx")
    samples = [
        dict(x=torch.randn((5, 6)), y=torch.randn((1, 6))),
        dict(x=torch.randn((7, 7)), y=torch.randn((1, 7))),
    ]
    method_to_call = method_to_onnx(Model(), exporter="custom", filename=onnx_path)
    # The default number of calls before conversion is two: the second
    # call is the one writing the file.
    expecteds = [method_to_call(**sample) for sample in samples]
    self.assertExists(onnx_path)

    # The generated wrapper must keep the default values of forward.
    generated = method_to_call._method_src
    for fragment in ("f(self, x=None, y=None):", "return self._call(x=x, y=y)"):
        self.assertIn(fragment, generated)
    self.assertEqual(len(list(method_to_call.named_modules())), 2)

    sess = self.check_ort(onnx_path)
    input_names = [i.name for i in sess.get_inputs()]
    for expected, sample in zip(expecteds, samples):
        got = sess.run(None, make_feeds(input_names, sample, use_numpy=True))
        self.assertEqualArray(expected, got[0])
    self.clean_dump()
175+
176+
@requires_experimental_experiment("0.1")
177+
def test_method_to_onnx_kwargs_patch(self):
    """Exports ``forward`` called with named tensors, with ``patch_kwargs`` set.

    Same scenario as the keyword-only test except that
    ``patch_kwargs=dict(patch_transformers=True)`` is given to
    ``method_to_onnx``.
    """

    class Model(torch.nn.Module):
        def forward(self, x=None, y=None):
            return x + y

    onnx_path = self.get_dump_file("test_method_to_onnx_kwargs_patch.onnx")
    samples = [
        dict(x=torch.randn((5, 6)), y=torch.randn((1, 6))),
        dict(x=torch.randn((7, 7)), y=torch.randn((1, 7))),
    ]
    export_settings = dict(
        exporter="custom",
        filename=onnx_path,
        patch_kwargs=dict(patch_transformers=True),
    )
    method_to_call = method_to_onnx(Model(), **export_settings)
    # The default number of calls before conversion is two: the second
    # call is the one writing the file.
    expecteds = [method_to_call(**sample) for sample in samples]
    self.assertExists(onnx_path)

    # The generated wrapper must keep the default values of forward.
    generated = method_to_call._method_src
    for fragment in ("f(self, x=None, y=None):", "return self._call(x=x, y=y)"):
        self.assertIn(fragment, generated)
    self.assertEqual(len(list(method_to_call.named_modules())), 2)

    sess = self.check_ort(onnx_path)
    input_names = [i.name for i in sess.get_inputs()]
    for expected, sample in zip(expecteds, samples):
        got = sess.run(None, make_feeds(input_names, sample, use_numpy=True))
        self.assertEqualArray(expected, got[0])
    self.clean_dump()
209+
210+
@requires_experimental_experiment("0.1")
211+
@hide_stdout()
212+
def test_method_to_onnx_mixed(self):
    """Exports ``forward`` called with one positional and one named tensor.

    The export is verbose and receives custom ``ExportOptions`` through
    ``exporter_kwargs``.
    """
    from experimental_experiment.torch_interpreter import ExportOptions

    class Model(torch.nn.Module):
        def forward(self, x, y=None):
            return x + y

    onnx_path = self.get_dump_file("test_method_to_onnx_mixed.onnx")
    # Every sample is a pair (positional arguments, named arguments).
    samples = [
        ((torch.randn((5, 6)),), dict(y=torch.randn((1, 6)))),
        ((torch.randn((7, 7)),), dict(y=torch.randn((1, 7)))),
    ]
    model = Model()
    export_options = ExportOptions(backed_size_oblivious=False)
    method_to_call = method_to_onnx(
        model,
        exporter="custom",
        filename=onnx_path,
        verbose=10,
        exporter_kwargs=dict(export_options=export_options),
    )
    # The default number of calls before conversion is two: the second
    # call is the one writing the file.
    expecteds = [
        method_to_call(*positional, **named) for positional, named in samples
    ]
    self.assertExists(onnx_path)

    # The generated wrapper must reproduce the exact signature of forward.
    generated = method_to_call._method_src
    for fragment in ("f(self, x, y=None):", "return self._call(x=x, y=y)"):
        self.assertIn(fragment, generated)
    self.assertEqual(len(list(method_to_call.named_modules())), 2)

    sess = self.check_ort(onnx_path)
    input_names = [i.name for i in sess.get_inputs()]
    for expected, (positional, named) in zip(expecteds, samples):
        feeds = make_feeds(input_names, (positional, named), use_numpy=True)
        got = sess.run(None, feeds)
        self.assertEqualArray(expected, got[0])
    self.clean_dump()
247+
117248

118249
if __name__ == "__main__":
119250
unittest.main(verbosity=2)

onnx_diagnostic/export/api.py

Lines changed: 204 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,11 @@
1+
import inspect
12
import os
2-
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
3+
import textwrap
4+
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
35
import torch
6+
from .dynamic_shapes import ModelInputs
47
from .onnx_plug import EagerDirectReplacementWithOnnx
8+
from ..helpers import string_type
59

610

711
def get_main_dispatcher(
@@ -71,6 +75,7 @@ def to_onnx(
7175
inline: bool = True,
7276
) -> Any:
7377
"""
78+
Exports one model into ONNX.
7479
Common API for exporters. By default, the models are optimized to use the
7580
most efficient kernels implemented in :epkg:`onnxruntime`.
7681
@@ -127,8 +132,12 @@ def to_onnx(
127132
from experimental_experiment.xbuilder import OptimizationOptions
128133

129134
options = None
135+
export_options = None
130136
if exporter_kwargs is not None:
131137
options = exporter_kwargs.pop("options", None)
138+
export_options = exporter_kwargs.pop("export_options", None)
139+
if export_options is None:
140+
export_options = ExportOptions(save_ep=save_ep)
132141
if options is None and optimize:
133142
options = OptimizationOptions(
134143
patterns="default+onnxruntime" if optimizer_for_ort else "default"
@@ -151,7 +160,7 @@ def to_onnx(
151160
dynamic_shapes=dynamic_shapes,
152161
large_model=True,
153162
output_dynamic_shapes=output_dynamic_shapes,
154-
export_options=ExportOptions(save_ep=save_ep),
163+
export_options=export_options,
155164
options=options,
156165
inline=inline,
157166
dispatcher=main_dispatcher,
@@ -303,3 +312,196 @@ def to_onnx(
303312
return onx
304313

305314
raise ValueError(f"Unknown exporter={exporter!r}")
315+
316+
317+
class _WrapperToExportMethodToOnnx(torch.nn.Module):
    """
    Wraps an existing model in order to spy on its inputs.
    This is used by :func:`onnx_diagnostic.export.api.method_to_onnx`.

    Every call is forwarded to the wrapped method. As long as the export
    did not happen, the inputs of every call are recorded. Once
    ``convert_after_n_calls`` calls were recorded, dynamic shapes are guessed
    from the recorded inputs and the method is exported with :func:`to_onnx`.
    The export happens only once, following calls are simply forwarded.

    :param mod: module owning the method to export
    :param method_name: name of the method to export
    :param convert_after_n_calls: number of recorded calls triggering the export
    :param patch_kwargs: if not None, the export runs inside
        ``torch_export_patches(**patch_kwargs)``
    :param verbose: verbosity, also forwarded to :func:`to_onnx`

    All the other parameters are stored and forwarded to :func:`to_onnx`
    without any change.
    """

    def __init__(
        self,
        mod: "torch.nn.Module",
        method_name: str = "forward",
        input_names: Optional[Sequence[str]] = None,
        target_opset: Optional[Union[int, Dict[str, int]]] = None,
        verbose: int = 0,
        filename: Optional[str] = None,
        output_names: Optional[List[str]] = None,
        output_dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
        exporter: str = "onnx-dynamo",
        exporter_kwargs: Optional[Dict[str, Any]] = None,
        save_ep: Optional[str] = None,
        optimize: bool = True,
        optimizer_for_ort: bool = True,
        use_control_flow_dispatcher: bool = False,
        onnx_plugs: Optional[List[EagerDirectReplacementWithOnnx]] = None,
        inline: bool = True,
        convert_after_n_calls: int = 2,
        patch_kwargs: Optional[Dict[str, Any]] = None,
    ):
        super().__init__()
        self._model_to_call = mod
        self._method_name = method_name
        # For ``forward`` the module itself is called (it goes through
        # ``torch.nn.Module.__call__``), any other method is called directly.
        self._call = (
            self._model_to_call if method_name == "forward" else getattr(mod, method_name)
        )
        # Recorded calls, a list of pairs (args, kwargs).
        self._inputs = []
        self._convert_after_n_calls = convert_after_n_calls
        self._patch_kwargs = patch_kwargs
        # Source code of the generated method, filled by the export.
        self._method_src = None
        # Becomes True after a successful export, inputs are not recorded anymore.
        self._export_done = False
        self.verbose = verbose
        self._to_onnx_kwargs = dict(
            input_names=input_names,
            target_opset=target_opset,
            verbose=verbose,
            filename=filename,
            output_names=output_names,
            output_dynamic_shapes=output_dynamic_shapes,
            exporter=exporter,
            exporter_kwargs=exporter_kwargs,
            save_ep=save_ep,
            optimize=optimize,
            optimizer_for_ort=optimizer_for_ort,
            use_control_flow_dispatcher=use_control_flow_dispatcher,
            onnx_plugs=onnx_plugs,
            inline=inline,
        )

    def forward(self, *args, **kwargs):
        """
        Records the inputs until the export is done, triggers the export
        once enough calls were recorded, then calls the wrapped method
        and returns its output.
        """
        if not self._export_done:
            self._inputs.append((args, kwargs))
            if self.verbose:
                print(
                    f"[method_to_onnx] input{len(self._inputs)}: "
                    f"{string_type((args, kwargs), with_shape=True)}"
                )
            if len(self._inputs) >= self._convert_after_n_calls:
                self._convert_method_to_onnx()
                # Only set after a successful export: without this flag,
                # every following call would export the model again and
                # the list of recorded inputs would grow forever.
                self._export_done = True
        return self._call(*args, **kwargs)

    def _convert_method_to_onnx(self):
        """
        Exports the wrapped method with :func:`to_onnx`, the dynamic shapes
        are guessed from the recorded inputs, the last recorded call
        is used as the example given to the exporter.
        """

        def make_method(self):
            # Builds a function with the same parameters as the wrapped
            # method, calling ``self._call`` with every parameter passed by name.
            sig = inspect.signature(getattr(self._model_to_call, self._method_name))
            # Annotations are dropped: the code is executed in an empty
            # namespace where the annotated names cannot be resolved, and
            # a return annotation would leave `` -> ...`` in the parameter list.
            bare_sig = sig.replace(
                parameters=[
                    p.replace(annotation=inspect.Parameter.empty)
                    for p in sig.parameters.values()
                ],
                return_annotation=inspect.Signature.empty,
            )
            # NOTE(review): default values are written with their repr, they
            # must be evaluable in an empty namespace (None, numbers, strings).
            args = str(bare_sig)[1:-1]
            # NOTE(review): ``*args`` / ``**kwargs`` or positional-only
            # parameters are not supported by this call.
            calls_args = ", ".join(f"{p}={p}" for p in sig.parameters)
            src = textwrap.dedent(
                f"""
                def f(self, {args}):
                    return self._call({calls_args})
                """
            )
            self._method_src = src
            ns = {}
            exec(src, ns)
            return ns["f"]

        class WrapWithExactSignature(torch.nn.Module):
            def __init__(self, parent):
                super().__init__()
                self._model_to_call = parent._model_to_call
                self._call = parent._call

            # ``self`` is the wrapper running the export (enclosing scope).
            forward = make_method(self)

        compiled_model = WrapWithExactSignature(self)
        mi = ModelInputs(compiled_model, self._inputs)
        ds = mi.guess_dynamic_shapes()
        if self.verbose:
            print(f"[method_to_onnx] guess_dynamic_shapes={string_type(ds)}")
        a, kw, nds = mi.move_to_kwargs(*self._inputs[-1], ds)
        if self.verbose:
            print(f"[method_to_onnx] export args={string_type(a, with_shape=True)}")
            print(f"[method_to_onnx] export kwargs={string_type(kw, with_shape=True)}")
            print(f"[method_to_onnx] dynamic_shapes={string_type(nds)}")

        # to_onnx pops entries from ``exporter_kwargs``: a copy is given
        # to leave the dictionary received from the user untouched.
        to_onnx_kwargs = dict(self._to_onnx_kwargs)
        if to_onnx_kwargs["exporter_kwargs"] is not None:
            to_onnx_kwargs["exporter_kwargs"] = dict(to_onnx_kwargs["exporter_kwargs"])

        def run_export():
            # presumably nds[-1] holds the dynamic shapes of the named
            # arguments returned by move_to_kwargs -- TODO confirm
            to_onnx(
                compiled_model,
                args=a,
                kwargs=kw,
                dynamic_shapes=nds[-1],
                **to_onnx_kwargs,
            )

        if self._patch_kwargs is None:
            run_export()
            return
        # Imported only when patches are requested.
        from ..torch_export_patches import torch_export_patches

        with torch_export_patches(**self._patch_kwargs):
            run_export()
437+
438+
439+
def method_to_onnx(
    mod: "torch.nn.Module",
    method_name: str = "forward",
    input_names: Optional[Sequence[str]] = None,
    target_opset: Optional[Union[int, Dict[str, int]]] = None,
    verbose: int = 0,
    filename: Optional[str] = None,
    output_names: Optional[List[str]] = None,
    output_dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
    exporter: str = "onnx-dynamo",
    exporter_kwargs: Optional[Dict[str, Any]] = None,
    save_ep: Optional[str] = None,
    optimize: bool = True,
    optimizer_for_ort: bool = True,
    use_control_flow_dispatcher: bool = False,
    onnx_plugs: Optional[List[EagerDirectReplacementWithOnnx]] = None,
    inline: bool = True,
    convert_after_n_calls: int = 2,
    patch_kwargs: Optional[Dict[str, Any]] = None,
) -> Callable:
    """
    Exports one method of a module into ONNX.
    It returns a callable the user must call in place of the method,
    at least *convert_after_n_calls* times (twice by default) with different
    values for the dynamic dimensions, before the conversion into ONNX
    is triggered.

    :param mod: module owning the method to export
    :param method_name: name of the method to export
    :param input_names: input names for the onnx model (optional)
    :param target_opset: opset to target, if not specified, each converter
        keeps its default value
    :param verbose: verbosity level
    :param filename: output filename, the onnx model is saved on disk
        (documented as mandatory although the default value is None)
    :param output_names: to change the output of the onnx model
    :param output_dynamic_shapes: to overwrite the dynamic shapes names
    :param exporter: exporter to use (``onnx-dynamo``, ``modelbuilder``, ``custom``)
    :param exporter_kwargs: additional parameters sent to the exporter
    :param save_ep: saves the exported program
    :param optimize: optimizes the model
    :param optimizer_for_ort: optimizes the model for onnxruntime
    :param use_control_flow_dispatcher: use the dispatcher created to support
        custom loops (see :func:`onnx_diagnostic.export.control_flow_onnx.loop_for_onnx`)
    :param onnx_plugs: the code was modified to replace some parts with onnx translation
    :param inline: inline local functions
    :param convert_after_n_calls: converts the model after this number of calls
    :param patch_kwargs: patch arguments
    :return: a :class:`torch.nn.Module` wrapping *mod*, it records the inputs,
        forwards every call to the method and runs the export
    """
    # Every argument is handed over unchanged, the wrapper does all the work.
    return _WrapperToExportMethodToOnnx(
        mod=mod,
        method_name=method_name,
        convert_after_n_calls=convert_after_n_calls,
        patch_kwargs=patch_kwargs,
        verbose=verbose,
        filename=filename,
        input_names=input_names,
        output_names=output_names,
        output_dynamic_shapes=output_dynamic_shapes,
        target_opset=target_opset,
        exporter=exporter,
        exporter_kwargs=exporter_kwargs,
        save_ep=save_ep,
        optimize=optimize,
        optimizer_for_ort=optimizer_for_ort,
        use_control_flow_dispatcher=use_control_flow_dispatcher,
        onnx_plugs=onnx_plugs,
        inline=inline,
    )

0 commit comments

Comments
 (0)