@@ -27,7 +27,7 @@ class VerifyResult:
2727 """
2828
2929 eager_outputs : TUPLE_TENSORS
30- onnx_output : TUPLE_TENSORS
30+ onnx_outputs : TUPLE_TENSORS
3131 diffs : Tuple [Dict [str , float ], ...]
3232
3333
@@ -238,20 +238,30 @@ def _register(self):
238238 custom_def .register_kernel (None )(self .eager_fn )
239239 custom_def ._abstract_fn = self .shape_fn
240240
241- def verify (self , * args , engine : Optional [Callable ] = None ) -> VerifyResult :
241+ def verify (
242+ self ,
243+ * args ,
244+ engine : Optional [Callable ] = None ,
245+ dump_onnx_model : Optional [str ] = None ,
246+ ** kwargs ,
247+ ) -> VerifyResult :
242248 """
243249 Verifies that the eager mode is equivalent to the onnx function given
244250 as a replacements. This function evaluates `eager_fn`, checks that the shapes
245251 are equivalent to the ones given by `shape_fn`, and finally evaluates the
246252 onnx translation if the previous did not fail.
247253
248254 :param args: function inputs
255+ :param kwargs: arguments for eager_fn
249256 :param engine: by default an instance of
250257 :class:`onnx_diagnostic.reference.OnnxruntimeEvaluator`.
258+ :param dump_onnx_model: to dump the onnx model used to verify
259+ eager and onnx produce the same results
260+ :param kwargs: additional arguments to the function
251261 :return: outputs of :func:`onnx_diagnostic.helpers.max_diff`
252262 """
253- expected = self .eager_fn (* args )
254- shapes = self .shape_fn (* args )
263+ expected = self .eager_fn (* args , ** kwargs )
264+ shapes = self .shape_fn (* args , ** kwargs )
255265 if isinstance (expected , torch .Tensor ):
256266 expected = (expected ,)
257267 assert isinstance (shapes , torch .Tensor ), (
@@ -279,11 +289,23 @@ def verify(self, *args, engine: Optional[Callable] = None) -> VerifyResult:
279289
280290 # Now the ONNX execution.
281291 assert engine is None , f"Not implemented yet with engine={ engine !r} "
282- sess = OnnxruntimeEvaluator (self .function_proto )
283- feeds = dict (zip (sess .input_names , args ))
292+ ags , kws = self ._make_args_kwargs (* args , ** kwargs )
293+ sess = OnnxruntimeEvaluator (
294+ self .function_proto ,
295+ whole = True ,
296+ dump_onnx_model = dump_onnx_model ,
297+ function_kwargs = kws ,
298+ )
299+ feeds = dict (zip (sess .input_names , ags ))
284300 got = sess .run (None , feeds )
285- diffs = tuple (max_diff (e , g ) for e , g in zip (expected , got ))
286- return VerifyResult (eager_outputs = expected , onnx_output = tuple (got ), diffs = diffs ) # type: ignore[arg-type]
301+ diffs = tuple (max_diff (e , g , hist = [0.1 , 0.01 ]) for e , g in zip (expected , got ))
302+ return VerifyResult (eager_outputs = expected , onnx_outputs = tuple (got ), diffs = diffs ) # type: ignore[arg-type]
303+
304+ def _make_args_kwargs (self , * args , ** kwargs ):
305+ ags = args [: len (self .args_name )]
306+ kws = dict (zip (self .kwargs_name , args [len (self .args_name ) :]))
307+ kws .update (kwargs )
308+ return ags , kws
287309
288310 def custom_converter (
289311 self ,
@@ -306,9 +328,7 @@ def converter(
306328 self .function_proto .name , domain = self .function_proto .domain
307329 ):
308330 g .add_function (self .function_proto )
309- ags = args [: len (self .args_name )]
310- kws = dict (zip (self .kwargs_name , args [len (self .args_name ) :]))
311- kws .update (kwargs )
331+ ags , kws = self ._make_args_kwargs (* args , ** kwargs )
312332 res = g .make_node (
313333 self .function_proto .name ,
314334 ags ,
@@ -369,7 +389,8 @@ def onnx_dynamo_converter(self) -> Callable:
369389 onnx .defs .register_schema (schema )
370390 op = onnxscript .values .Op (onnx_plug_op , self .function_proto .name , schema )
371391
372- def converter (* cargs ):
373- return op (* cargs , n_outputs = self .n_outputs )
392+ def converter (* cargs , ** ckwargs ):
393+ ags , kws = self ._make_args_kwargs (* cargs , ** ckwargs )
394+ return op (* ags , n_outputs = self .n_outputs , ** kws )
374395
375396 return onnxscript .values .TracedOnnxFunction (onnx_plug_op , converter )
0 commit comments