1+ import inspect
12import os
2- from typing import Any , Dict , List , Optional , Sequence , Tuple , Union
3+ import textwrap
4+ from typing import Any , Callable , Dict , List , Optional , Sequence , Tuple , Union
35import torch
6+ from .dynamic_shapes import ModelInputs
47from .onnx_plug import EagerDirectReplacementWithOnnx
8+ from ..helpers import string_type
59
610
711def get_main_dispatcher (
@@ -71,6 +75,7 @@ def to_onnx(
7175 inline : bool = True ,
7276) -> Any :
7377 """
78+ Exports one model into ONNX.
7479 Common API for exporters. By default, the models are optimized to use the
7580 most efficient kernels implemented in :epkg:`onnxruntime`.
7681
@@ -127,8 +132,12 @@ def to_onnx(
127132 from experimental_experiment .xbuilder import OptimizationOptions
128133
129134 options = None
135+ export_options = None
130136 if exporter_kwargs is not None :
131137 options = exporter_kwargs .pop ("options" , None )
138+ export_options = exporter_kwargs .pop ("export_options" , None )
139+ if export_options is None :
140+ export_options = ExportOptions (save_ep = save_ep )
132141 if options is None and optimize :
133142 options = OptimizationOptions (
134143 patterns = "default+onnxruntime" if optimizer_for_ort else "default"
@@ -151,7 +160,7 @@ def to_onnx(
151160 dynamic_shapes = dynamic_shapes ,
152161 large_model = True ,
153162 output_dynamic_shapes = output_dynamic_shapes ,
154- export_options = ExportOptions ( save_ep = save_ep ) ,
163+ export_options = export_options ,
155164 options = options ,
156165 inline = inline ,
157166 dispatcher = main_dispatcher ,
@@ -303,3 +312,196 @@ def to_onnx(
303312 return onx
304313
305314 raise ValueError (f"Unknown exporter={ exporter !r} " )
315+
316+
317+ class _WrapperToExportMethodToOnnx (torch .nn .Module ):
318+ """
319+ Wraps an existing models in order to spy on inputs.
320+ This is used by :func:`onnx_diagnostic.export.api.method_to_onnx`.
321+ """
322+
323+ def __init__ (
324+ self ,
325+ mod : "torch.nn.Module" ,
326+ method_name : str = "forward" ,
327+ input_names : Optional [Sequence [str ]] = None ,
328+ target_opset : Optional [Union [int , Dict [str , int ]]] = None ,
329+ verbose : int = 0 ,
330+ filename : Optional [str ] = None ,
331+ output_names : Optional [List [str ]] = None ,
332+ output_dynamic_shapes : Optional [Union [Dict [str , Any ], Tuple [Any ]]] = None ,
333+ exporter : str = "onnx-dynamo" ,
334+ exporter_kwargs : Optional [Dict [str , Any ]] = None ,
335+ save_ep : Optional [str ] = None ,
336+ optimize : bool = True ,
337+ optimizer_for_ort : bool = True ,
338+ use_control_flow_dispatcher : bool = False ,
339+ onnx_plugs : Optional [List [EagerDirectReplacementWithOnnx ]] = None ,
340+ inline : bool = True ,
341+ convert_after_n_calls : int = 2 ,
342+ patch_kwargs : Optional [Dict [str , Any ]] = None ,
343+ ):
344+ super ().__init__ ()
345+ self ._model_to_call = mod
346+ self ._method_name = method_name
347+ self ._call = (
348+ self ._model_to_call if method_name == "forward" else getattr (mod , method_name )
349+ )
350+ self ._inputs = []
351+ self ._convert_after_n_calls = convert_after_n_calls
352+ self ._patch_kwargs = patch_kwargs
353+ self ._method_src = None
354+ self .verbose = verbose
355+ self ._to_onnx_kwargs = dict (
356+ input_names = input_names ,
357+ target_opset = target_opset ,
358+ verbose = verbose ,
359+ filename = filename ,
360+ output_names = output_names ,
361+ output_dynamic_shapes = output_dynamic_shapes ,
362+ exporter = exporter ,
363+ exporter_kwargs = exporter_kwargs ,
364+ save_ep = save_ep ,
365+ optimize = optimize ,
366+ optimizer_for_ort = optimizer_for_ort ,
367+ use_control_flow_dispatcher = use_control_flow_dispatcher ,
368+ onnx_plugs = onnx_plugs ,
369+ inline = inline ,
370+ )
371+
372+ def forward (self , * args , ** kwargs ):
373+ self ._inputs .append ((args , kwargs ))
374+ if self .verbose :
375+ print (
376+ f"[method_to_onnx] input{ len (self ._inputs )} : "
377+ f"{ string_type ((args , kwargs ), with_shape = True )} "
378+ )
379+ if len (self ._inputs ) >= self ._convert_after_n_calls :
380+ self ._convert_method_to_onnx ()
381+ return self ._call (* args , ** kwargs )
382+
383+ def _convert_method_to_onnx (self ):
384+
385+ def make_method (self ):
386+ sig = inspect .signature (getattr (self ._model_to_call , self ._method_name ))
387+ args = str (sig )[1 :- 1 ]
388+ calls_args = ", " .join (f"{ p } ={ p } " for p in sig .parameters )
389+ src = textwrap .dedent (
390+ f"""
391+ def f(self, { args } ):
392+ return self._call({ calls_args } )
393+ """
394+ )
395+ self ._method_src = src
396+ ns = {}
397+ exec (src , ns )
398+ return ns ["f" ]
399+
400+ class WrapWithExactSignature (torch .nn .Module ):
401+ def __init__ (self , parent ):
402+ super ().__init__ ()
403+ self ._model_to_call = parent ._model_to_call
404+ self ._call = parent ._call
405+
406+ forward = make_method (self )
407+
408+ compiled_model = WrapWithExactSignature (self )
409+ mi = ModelInputs (compiled_model , self ._inputs )
410+ ds = mi .guess_dynamic_shapes ()
411+ if self .verbose :
412+ print (f"[method_to_onnx] guess_dynamic_shapes={ string_type (ds )} " )
413+ a , kw , nds = mi .move_to_kwargs (* self ._inputs [- 1 ], ds )
414+ if self .verbose :
415+ print (f"[method_to_onnx] export args={ string_type (a , with_shape = True )} " )
416+ print (f"[method_to_onnx] export kwargs={ string_type (kw , with_shape = True )} " )
417+ print (f"[method_to_onnx] dynamic_shapes={ string_type (nds )} " )
418+ if self ._patch_kwargs is None :
419+ to_onnx (
420+ compiled_model ,
421+ args = a ,
422+ kwargs = kw ,
423+ dynamic_shapes = nds [- 1 ],
424+ ** self ._to_onnx_kwargs ,
425+ )
426+ return
427+ from ..torch_export_patches import torch_export_patches
428+
429+ with torch_export_patches (** self ._patch_kwargs ):
430+ to_onnx (
431+ compiled_model ,
432+ args = a ,
433+ kwargs = kw ,
434+ dynamic_shapes = nds [- 1 ],
435+ ** self ._to_onnx_kwargs ,
436+ )
437+
438+
439+ def method_to_onnx (
440+ mod : "torch.nn.Module" ,
441+ method_name : str = "forward" ,
442+ input_names : Optional [Sequence [str ]] = None ,
443+ target_opset : Optional [Union [int , Dict [str , int ]]] = None ,
444+ verbose : int = 0 ,
445+ filename : Optional [str ] = None ,
446+ output_names : Optional [List [str ]] = None ,
447+ output_dynamic_shapes : Optional [Union [Dict [str , Any ], Tuple [Any ]]] = None ,
448+ exporter : str = "onnx-dynamo" ,
449+ exporter_kwargs : Optional [Dict [str , Any ]] = None ,
450+ save_ep : Optional [str ] = None ,
451+ optimize : bool = True ,
452+ optimizer_for_ort : bool = True ,
453+ use_control_flow_dispatcher : bool = False ,
454+ onnx_plugs : Optional [List [EagerDirectReplacementWithOnnx ]] = None ,
455+ inline : bool = True ,
456+ convert_after_n_calls : int = 2 ,
457+ patch_kwargs : Optional [Dict [str , Any ]] = None ,
458+ ) -> Callable :
459+ """
460+ Exports one method into ONNX for a module into ONNX.
461+ It returns a new method which must be called by the user
462+ at least twice with different values for the dynamic dimension
463+ between triggering the conversion into ONNX.
464+
465+ :param mod_meth: function to export into ONNX
466+ :param input_names: input names for the onnx model (optional)
467+ :param target_opset: opset to target, if not specified, each converter
468+ keeps its default value
469+ :param verbose: verbosity level
470+ :param filename: output filename, mandatory, the onnx model is saved on disk
471+ :param output_names: to change the output of the onnx model
472+ :param output_dynamic_shapes: to overwrite the dynamic shapes names
473+ :param exporter: exporter to use (``onnx-dynamo``, ``modelbuilder``, ``custom``)
474+ :param exporter_kwargs: additional parameters sent to the exporter
475+ :param save_ep: saves the exported program
476+ :param optimize: optimizes the model
477+ :param optimizer_for_ort: optimizes the model for onnxruntime
478+ :param use_control_flow_dispatcher: use the dispatcher created to supported
479+ custom loops (see :func:`onnx_diagnostic.export.control_flow_onnx.loop_for_onnx`)
480+ :param onnx_plugs: the code was modified to replace some parts with onnx translation
481+ :param inline: inline local functions
482+ :param convert_after_n_calls: convets the model after this number of calls.
483+ :param patch_kwargs: patch arguments
484+ :return: the output of the selected exporter, usually a structure including
485+ an onnx model
486+ """
487+ wrapped_model = _WrapperToExportMethodToOnnx (
488+ mod = mod ,
489+ method_name = method_name ,
490+ input_names = input_names ,
491+ target_opset = target_opset ,
492+ verbose = verbose ,
493+ filename = filename ,
494+ output_names = output_names ,
495+ output_dynamic_shapes = output_dynamic_shapes ,
496+ exporter = exporter ,
497+ exporter_kwargs = exporter_kwargs ,
498+ save_ep = save_ep ,
499+ optimize = optimize ,
500+ optimizer_for_ort = optimizer_for_ort ,
501+ use_control_flow_dispatcher = use_control_flow_dispatcher ,
502+ onnx_plugs = onnx_plugs ,
503+ inline = inline ,
504+ convert_after_n_calls = convert_after_n_calls ,
505+ patch_kwargs = patch_kwargs ,
506+ )
507+ return wrapped_model
0 commit comments