@@ -148,14 +148,31 @@ def make_custom_loop_for(
148148 custom_def ._abstract_fn = lambda * _args , _o = body_outputs : (
149149 tuple ([torch .empty_like (s ) for s in _o ]) if len (_o ) > 1 else torch .empty_like (_o [0 ])
150150 )
151- onx = convert_into_onnx (body_gm , args )
151+
152+ def _make_onx (
153+ body_gm = body_gm , args = args , target_opset = None , verbose = 0 , exporter_kwargs = None
154+ ):
155+ return convert_into_onnx (
156+ body_gm ,
157+ args ,
158+ exporter_kwargs = exporter_kwargs ,
159+ target_opset = target_opset ,
160+ verbose = verbose ,
161+ )
162+
152163 to_register = (
153164 custom_def ,
154- onx ,
165+ _make_onx ,
155166 (
156- lambda g , sts , outputs , * args , body = onx , reduction_dim = reduction_dim , name = name : (
167+ lambda g , sts , outputs , * args , bc = _make_onx , rd = reduction_dim , name = name : (
157168 convert_custom_loop_into_onnx (
158- g , sts , outputs , * args , body = body , reduction_dim = reduction_dim , name = name
169+ g ,
170+ sts ,
171+ outputs ,
172+ * args ,
173+ body_callable = bc ,
174+ reduction_dim = rd ,
175+ name = name ,
159176 )
160177 )
161178 ),
@@ -173,7 +190,7 @@ def convert_custom_loop_into_onnx(
173190 sts : Dict [str , Any ],
174191 outputs : List [str ],
175192 * args : str ,
176- body : onnx .GraphProto ,
193+ body_callable : Callable [..., onnx .ModelProto ] ,
177194 reduction_dim : Optional [Sequence [int ]] = None ,
178195 name : str = "loop_for" ,
179196) -> Union [str , List [str ]]:
@@ -190,6 +207,14 @@ def convert_custom_loop_into_onnx(
190207 :param name: to give the onnx nodes a name
191208 :return: output names
192209 """
210+ assert body_callable is not None , "body_callable cannot be None"
211+ # This should be part of a public API.
212+ body = body_callable (
213+ target_opset = g .main_opset ,
214+ verbose = g .verbose ,
215+ exporter_kwargs = {"options" : g .optimization_options },
216+ )
217+
193218 graph = body .graph if isinstance (body , onnx .ModelProto ) else body
194219 assert isinstance (
195220 graph , onnx .GraphProto
@@ -261,19 +286,33 @@ def convert_custom_loop_into_onnx(
261286
262287
263288def convert_into_onnx (
264- body_gm : torch .fx .GraphModule , args : Sequence [torch .Tensor ]
289+ body_gm : torch .fx .GraphModule ,
290+ args : Sequence [torch .Tensor ],
291+ target_opset : Optional [int ] = None ,
292+ verbose : int = 0 ,
293+ exporter_kwargs : Optional [Dict [str , Any ]] = None ,
265294) -> onnx .ModelProto :
266295 """
267296 Converts a torch.fx.GraphModule into ONNX.
268297 It returns a ModelProto.
269298
270299 :param body_gm: a torch.fx.GraphModule
271300 :param args: arguments known at export time
301+ :param target_opset: targetted opset
302+ :param verbose: verbosity level
303+ :param exporter_kwargs: additional exporter arguments
272304 :return: a ModelProto
273305 """
274306 # This does not work with onnx-dynamo.
275307 # opset still needs to be defined
276- container = to_onnx (body_gm , args , exporter = "custom" )
308+ container = to_onnx (
309+ body_gm ,
310+ args ,
311+ exporter = "custom" ,
312+ exporter_kwargs = exporter_kwargs ,
313+ target_opset = target_opset ,
314+ verbose = verbose ,
315+ )
277316 return container .model_proto
278317
279318
0 commit comments