@@ -20,6 +20,7 @@ def to_onnx(
2020 optimize : bool = True ,
2121 use_control_flow_dispatcher : bool = False ,
2222 onnx_plugs : Optional [List [EagerDirectReplacementWithOnnx ]] = None ,
23+ inline : bool = True ,
2324) -> Any :
2425 """
2526 Common API for exporters. By default, the models are optimized to use the
@@ -44,6 +45,7 @@ def to_onnx(
4445 :param use_control_flow_dispatcher: use the dispatcher created to supported
4546 custom loops (see :func:`onnx_diagnostic.export.control_flow.loop_for`)
4647 :param onnx_plugs: the code was modified to replace some parts with onnx translation
48+ :param inline: inline local functions
4749 :return: the output of the selected exporter, usually a structure including
4850 an onnx model
4951
@@ -63,6 +65,11 @@ def to_onnx(
6365 :func:`onnx_diagnostic.export.control_flow.loop_for` or
6466 :class:`onnx_diagnostic.export.onnx_plug.EagerDirectReplacementWithOnnx`.
6567 """
68+ if exporter_kwargs and "inline" in exporter_kwargs :
69+ assert (
70+ inline == exporter_kwargs ["inline" ]
71+ ), f"Mismatch between inline={ inline } and exporter_kwargs={ exporter_kwargs } "
72+ exporter_kwargs .pop ("inline" )
6673 if exporter == "custom" :
6774 from experimental_experiment .torch_interpreter import (
6875 to_onnx as _to_onnx ,
@@ -90,7 +97,7 @@ def __init__(self):
9097 super ().__init__ ({})
9198
9299 main_dispatcher = MainDispatcher ()
93- if control_flow_dispatcher :
100+ if use_control_flow_dispatcher :
94101 main_dispatcher .registered_functions .update (
95102 control_flow_dispatcher .registered_functions
96103 )
@@ -99,7 +106,6 @@ def __init__(self):
99106 main_dispatcher .registered_functions [plug .target_name ] = (
100107 plug .custom_converter ()
101108 )
102-
103109 else :
104110 main_dispatcher = None
105111
@@ -117,8 +123,9 @@ def __init__(self):
117123 output_dynamic_shapes = output_dynamic_shapes ,
118124 export_options = ExportOptions (save_ep = save_ep ),
119125 options = options ,
120- ** ( exporter_kwargs or {}) ,
126+ inline = inline ,
121127 dispatcher = main_dispatcher ,
128+ ** (exporter_kwargs or {}),
122129 )
123130
124131 if exporter in ("dynamo" , "onnx-dynamo" ):
@@ -147,7 +154,21 @@ def __init__(self):
147154 custom_translation_table = custom_translation_table ,
148155 ** (exporter_kwargs or {}),
149156 )
150- if optimize :
157+ if not inline and optimize :
158+ ort_fusions .optimize_for_ort (epo .model )
159+
160+ if onnx_plugs :
161+ import onnx_ir as ir
162+ import onnx_ir .passes .common as common_passes
163+
164+ irfunctions = [ir .from_proto (plug .function_proto ) for plug in onnx_plugs ]
165+ for func in irfunctions :
166+ epo .model .functions [func .identifier ()] = func
167+ if inline :
168+ common_passes .InlinePass ()(epo .model )
169+ common_passes .RemoveUnusedOpsetsPass ()(epo .model )
170+
171+ if inline and optimize :
151172 ort_fusions .optimize_for_ort (epo .model )
152173 if filename :
153174 epo .save (filename , external_data = True )
0 commit comments