11from typing import Any , Dict , List , Optional , Sequence , Tuple , Union
22import torch
3+ from .onnx_plug import EagerDirectReplacementWithOnnx
34
45
56def to_onnx (
@@ -18,6 +19,8 @@ def to_onnx(
1819 save_ep : Optional [str ] = None ,
1920 optimize : bool = True ,
2021 use_control_flow_dispatcher : bool = False ,
22+ onnx_plugs : Optional [List [EagerDirectReplacementWithOnnx ]] = None ,
23+ inline : bool = True ,
2124) -> Any :
2225 """
2326 Common API for exporters. By default, the models are optimized to use the
@@ -41,6 +44,8 @@ def to_onnx(
4144 :param optimize: optimizes the model
4245 :param use_control_flow_dispatcher: use the dispatcher created to supported
4346 custom loops (see :func:`onnx_diagnostic.export.control_flow.loop_for`)
47+ :param onnx_plugs: the code was modified to replace some parts with onnx translation
48+ :param inline: inline local functions
4449 :return: the output of the selected exporter, usually a structure including
4550 an onnx model
4651
@@ -55,24 +60,73 @@ def to_onnx(
5560 exporter=exporter,
5661 filename=filename,
5762 )
63+
64+ Some examples using control flows are available in
65+ :func:`onnx_diagnostic.export.control_flow.loop_for` or
66+ :class:`onnx_diagnostic.export.onnx_plug.EagerDirectReplacementWithOnnx`.
5867 """
68+ if exporter_kwargs and "inline" in exporter_kwargs :
69+ assert (
70+ inline == exporter_kwargs ["inline" ]
71+ ), f"Mismatch between inline={ inline } and exporter_kwargs={ exporter_kwargs } "
72+ exporter_kwargs .pop ("inline" )
5973 if exporter == "custom" :
6074 from experimental_experiment .torch_interpreter import (
6175 to_onnx as _to_onnx ,
6276 ExportOptions ,
6377 )
6478 from experimental_experiment .xbuilder import OptimizationOptions
6579
66- if use_control_flow_dispatcher :
67- from .control_flow import create_global_dispatcher
68-
69- dispatcher = create_global_dispatcher ()
70-
7180 options = None
7281 if exporter_kwargs is not None :
7382 options = exporter_kwargs .pop ("options" , None )
7483 if options is None :
7584 options = OptimizationOptions (patterns = "default+onnxruntime" )
85+ if onnx_plugs or use_control_flow_dispatcher :
86+ from experimental_experiment .torch_interpreter import Dispatcher
87+
88+ if use_control_flow_dispatcher :
89+ from .control_flow import create_global_dispatcher
90+
91+ control_flow_dispatcher = create_global_dispatcher ()
92+ else :
93+ control_flow_dispatcher = None
94+
95+ class MainDispatcher (Dispatcher ):
96+ def __init__ (self , previous_dispatcher = None ):
97+ super ().__init__ ({})
98+ self .previous_dispatcher = previous_dispatcher
99+
100+ @property
101+ def supported (self ):
102+ if self .previous_dispatcher :
103+ return (
104+ set (self .registered_functions ) | self .previous_dispatcher .supported
105+ )
106+ return set (self .registered_functions )
107+
108+ def find_function (self , name : Any ):
109+ if self .previous_dispatcher :
110+ find = self .previous_dispatcher .find_function (name )
111+ if find :
112+ return find
113+ return Dispatcher .find_function (self , name )
114+
115+ def find_method (self , name : Any ):
116+ if self .previous_dispatcher :
117+ find = self .previous_dispatcher .find_method (name )
118+ if find :
119+ return find
120+ return Dispatcher .find_method (self , name )
121+
122+ main_dispatcher = MainDispatcher (control_flow_dispatcher )
123+ if onnx_plugs :
124+ for plug in onnx_plugs :
125+ main_dispatcher .registered_functions [plug .target_name ] = (
126+ plug .custom_converter ()
127+ )
128+ else :
129+ main_dispatcher = None
76130
77131 return _to_onnx (
78132 mod ,
@@ -88,8 +142,9 @@ def to_onnx(
88142 output_dynamic_shapes = output_dynamic_shapes ,
89143 export_options = ExportOptions (save_ep = save_ep ),
90144 options = options ,
145+ inline = inline ,
146+ dispatcher = main_dispatcher ,
91147 ** (exporter_kwargs or {}),
92- dispatcher = dispatcher if use_control_flow_dispatcher else None ,
93148 )
94149
95150 if exporter in ("dynamo" , "onnx-dynamo" ):
@@ -99,6 +154,10 @@ def to_onnx(
99154 assert (
100155 not output_dynamic_shapes
101156 ), f"output_dynamic_shapes not supported for exporter={ exporter !r} "
157+ custom_translation_table = {}
158+ if onnx_plugs :
159+ for plug in onnx_plugs :
160+ custom_translation_table [plug .torch_op ] = plug .onnx_dynamo_converter ()
102161 epo = torch .onnx .export (
103162 mod ,
104163 args = args or tuple (),
@@ -111,9 +170,24 @@ def to_onnx(
111170 verbose = verbose ,
112171 dump_exported_program = bool (save_ep ),
113172 artifacts_dir = os .path .dirname (filename ) if filename else "." ,
173+ custom_translation_table = custom_translation_table ,
114174 ** (exporter_kwargs or {}),
115175 )
116- if optimize :
176+ if not inline and optimize :
177+ ort_fusions .optimize_for_ort (epo .model )
178+
179+ if onnx_plugs :
180+ import onnx_ir as ir
181+ import onnx_ir .passes .common as common_passes
182+
183+ irfunctions = [ir .from_proto (plug .function_proto ) for plug in onnx_plugs ]
184+ for func in irfunctions :
185+ epo .model .functions [func .identifier ()] = func
186+ if inline :
187+ common_passes .InlinePass ()(epo .model )
188+ common_passes .RemoveUnusedOpsetsPass ()(epo .model )
189+
190+ if inline and optimize :
117191 ort_fusions .optimize_for_ort (epo .model )
118192 if filename :
119193 epo .save (filename , external_data = True )
0 commit comments