 from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 import onnx
 import torch
-from ..helpers import max_diff
+from ..helpers import max_diff, string_type
 from ..helpers.torch_helper import torch_dtype_to_onnx_dtype
 from ..reference import OnnxruntimeEvaluator
 
@@ -50,6 +50,7 @@ class EagerDirectReplacementWithOnnx:
         only tensors must be counted
     :param name: the name of the custom op, the function name if not specified
     :param kwargs: constants parameters with their default values
+    :param version_selector: selects the version based on the arguments
     :param verbose: verbose level
 
     Here is an example:
@@ -139,21 +140,28 @@ def __init__(
         self,
         eager_fn: Callable[[TUPLE_TENSORS], TUPLE_TENSORS],
         shape_fn: Callable[[TUPLE_TENSORS], TUPLE_TENSORS],
-        function_proto: onnx.FunctionProto,
+        function_proto: Union[onnx.FunctionProto, Dict[Any, onnx.FunctionProto]],
         n_inputs: Optional[int] = None,
         n_outputs: Optional[int] = None,
         name: Optional[str] = None,
         kwargs: Optional[Dict[str, Union[int, float]]] = None,
         verbose: int = 0,
+        version_selector: Optional[Callable[[Any], Any]] = None,
     ):
-        assert isinstance(
-            function_proto, onnx.FunctionProto
+        assert isinstance(function_proto, onnx.FunctionProto) or (
+            isinstance(function_proto, dict)
+            and all(isinstance(v, onnx.FunctionProto) for v in function_proto.values())
         ), f"Unexpected type {type(function_proto)} for function_proto"
         assert isinstance(n_inputs, int), f"not implemented yet when n_inputs={n_inputs}"
-        assert isinstance(n_outputs, int), f"not implemented yet when n_inputs={n_outputs}"
+        assert isinstance(n_outputs, int), f"not implemented yet when n_outputs={n_outputs}"
         self.eager_fn = eager_fn
         self.shape_fn = shape_fn
-        self.function_proto = function_proto
+        self._function_proto = (
+            function_proto if isinstance(function_proto, onnx.FunctionProto) else None
+        )
+        self._function_proto_versioned = (
+            function_proto if isinstance(function_proto, dict) else {}
+        )
         self.n_inputs = n_inputs
         self.n_outputs = n_outputs
         self.name = name or (
@@ -170,24 +178,72 @@ def __init__(
         )
         sig = inspect.signature(self.eager_fn)
         params = list(sig.parameters)
-        assert (
-            len(params) >= n_inputs
-        ), f"{self.eager_fn} accepts {params} as parameters < n_inputs={n_inputs}"
-        assert n_inputs == len(function_proto.input), (
-            f"Input mismatch n_inputs={n_inputs} but "
-            f"function_proto.input={function_proto.input}"
-        )
-        assert n_outputs == len(function_proto.output), (
-            f"Output mismatch n_outputs={n_outputs} but "
-            f"function_proto.output={function_proto.output}"
-        )
-        assert (
-            function_proto.domain == self.domain
-        ), f"Function domain must be {self.domain!r} but it is {function_proto.domain!r}"
         self.args_name = [p for p in params if p not in self.kwargs]
         self.kwargs_name = [p for p in params if p in self.kwargs]
         self.verbose = verbose
         self.custom_op = self._register()
+        self.version_selector = version_selector
+        self._check_protos(params)
+
+    def _check_protos(self, params):
+        assert (
+            len(params) >= self.n_inputs
+        ), f"{self.eager_fn} accepts {params} as parameters < n_inputs={self.n_inputs}"
+
+        # one proto
+        assert self._function_proto is None or self.n_inputs == len(
+            self._function_proto.input
+        ), (
+            f"Input mismatch n_inputs={self.n_inputs} but "
+            f"function_proto.input={self._function_proto.input}"
+        )
+        assert self._function_proto is None or self.n_outputs == len(
+            self._function_proto.output
+        ), (
+            f"Output mismatch n_outputs={self.n_outputs} but "
+            f"function_proto.output={self._function_proto.output}"
+        )
+        assert self._function_proto is None or (
+            self._function_proto.domain == self.domain
+        ), f"Function domain must be {self.domain!r} but it is {self._function_proto.domain!r}"
+
+        # multiple protos
+        assert all(
+            self.n_inputs == len(v.input) for v in self._function_proto_versioned.values()
+        ), f"Input mismatch n_inputs={self.n_inputs} but one version is wrong"
+        assert all(
+            self.n_outputs == len(v.output) for v in self._function_proto_versioned.values()
+        ), f"Output mismatch n_outputs={self.n_outputs} but one version is wrong"
+        assert all(
+            v.domain == self.domain for v in self._function_proto_versioned.values()
+        ), f"Function domain must be {self.domain!r} but it is different in one version"
+        assert (
+            not self._function_proto_versioned or self.version_selector
+        ), "version_selector is needed when multiple protos are given."
+
+    def get_function_proto(self, *args) -> onnx.FunctionProto:
+        """Returns the correct version based on the inputs."""
+        if self._function_proto:
+            return self._function_proto
+        if (
+            len(args) == 1
+            and isinstance(args[0], (int, str))
+            and args[0] in self._function_proto_versioned
+        ):
+            return self._function_proto_versioned[args[0]]
+        try:
+            key = self.version_selector(*args)
+        except (ValueError, AttributeError) as e:
+            raise AssertionError(
+                f"Unable to select a version, failed to get a key, available="
+                f"{set(self._function_proto_versioned)}, "
+                f"args={string_type(args, with_shape=True)}"
+            ) from e
+        assert key in self._function_proto_versioned, (
+            f"Unable to select a version, key={key}, available="
+            f"{set(self._function_proto_versioned)}, args={string_type(args, with_shape=True)}"
+        )
+        return self._function_proto_versioned[key]
 
     @property
     def domain(self) -> str:
@@ -291,7 +347,7 @@ def verify(
         assert engine is None, f"Not implemented yet with engine={engine!r}"
         ags, kws = self._make_args_kwargs(*args, **kwargs)
         sess = OnnxruntimeEvaluator(
-            self.function_proto,
+            self.get_function_proto(*args),
             whole=True,
             dump_onnx_model=dump_onnx_model,
             function_kwargs=kws,
@@ -324,16 +380,15 @@ def converter(
         *args,
         **kwargs,
     ) -> Any:
-        if not g.has_local_function(
-            self.function_proto.name, domain=self.function_proto.domain
-        ):
-            g.add_function(self.function_proto)
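+        # pick the proto version based on the type returned by g.get_type for the first input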
+        function_proto = self.get_function_proto(g.get_type(args[0]))
+        if not g.has_local_function(function_proto.name, domain=function_proto.domain):
+            g.add_function(function_proto)
         ags, kws = self._make_args_kwargs(*args, **kwargs)
         res = g.make_node(
-            self.function_proto.name,
+            function_proto.name,
             ags,
             outputs,
-            domain=self.function_proto.domain,
+            domain=function_proto.domain,
             name=self.target_name,
             **kws,
         )
@@ -356,41 +411,46 @@ def onnx_dynamo_converter(self) -> Callable:
         """
         import onnxscript
 
-        onnx_plug_op = onnxscript.values.Opset(domain=self.function_proto.domain, version=1)
-        schema = onnx_plug_op[self.function_proto.name]
-        if schema is None:
-            all_types = [
-                "tensor(float)",
-                "tensor(float16)",
-                "tensor(bfloat16)",
-                "tensor(double)",
-                "tensor(int64)",
-                "tensor(int32)",
-            ]
-            type_constraints = []
-            for i in range(self.n_inputs):
-                type_constraints.append((f"T{i}", all_types, ""))
-            for i in range(self.n_outputs):
-                type_constraints.append((f"U{i}", all_types, ""))
-            schema = onnx.defs.OpSchema(
-                self.function_proto.name,
-                self.function_proto.domain,
-                1,
-                inputs=[
-                    onnx.defs.OpSchema.FormalParameter(f"arg_{i}", f"T{i}")
-                    for i in range(self.n_inputs)
-                ],
-                outputs=[
-                    onnx.defs.OpSchema.FormalParameter(f"res_{i}", f"U{i}")
-                    for i in range(self.n_outputs)
-                ],
-                type_constraints=type_constraints,
-            )
-            onnx.defs.register_schema(schema)
-        op = onnxscript.values.Op(onnx_plug_op, self.function_proto.name, schema)
+        onnx_plug_op = onnxscript.values.Opset(domain=self.domain, version=1)
+
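+        # resolved lazily inside the converter: the selected proto, and therefore the
+        # registered schema, can depend on the call arguments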
+        def get_proto(*args):
+            function_proto = self.get_function_proto(*args)
+            schema = onnx_plug_op[function_proto.name]
+            if schema is None:
+                all_types = [
+                    "tensor(float)",
+                    "tensor(float16)",
+                    "tensor(bfloat16)",
+                    "tensor(double)",
+                    "tensor(int64)",
+                    "tensor(int32)",
+                ]
+                type_constraints = []
+                for i in range(self.n_inputs):
+                    type_constraints.append((f"T{i}", all_types, ""))
+                for i in range(self.n_outputs):
+                    type_constraints.append((f"U{i}", all_types, ""))
+                schema = onnx.defs.OpSchema(
+                    function_proto.name,
+                    function_proto.domain,
+                    1,
+                    inputs=[
+                        onnx.defs.OpSchema.FormalParameter(f"arg_{i}", f"T{i}")
+                        for i in range(self.n_inputs)
+                    ],
+                    outputs=[
+                        onnx.defs.OpSchema.FormalParameter(f"res_{i}", f"U{i}")
+                        for i in range(self.n_outputs)
+                    ],
+                    type_constraints=type_constraints,
+                )
+                onnx.defs.register_schema(schema)
+            op = onnxscript.values.Op(onnx_plug_op, function_proto.name, schema)
+            return op
 
         def converter(*cargs, **ckwargs):
             ags, kws = self._make_args_kwargs(*cargs, **ckwargs)
+            op = get_proto(*cargs)
             return op(*ags, n_outputs=self.n_outputs, **kws)
 
         return onnxscript.values.TracedOnnxFunction(onnx_plug_op, converter)
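
Usage sketch (illustration only, not part of this commit): one way the versioned dictionary of FunctionProto and the new version_selector argument fit together. The eager function, shape function, proto builder and domain string below are hypothetical; the domain must match the class's domain property, and the dictionary keys mirror what the converter passes to get_function_proto (assuming g.get_type returns the onnx element type of the first input).

import onnx
import onnx.helper as oh
import torch

DOMAIN = "onnx_plug"  # hypothetical, must match EagerDirectReplacementWithOnnx.domain


def times_two_add(x, y):  # hypothetical eager implementation
    return x * 2 + y


def times_two_add_shape(x, y):  # hypothetical shape function
    return torch.empty_like(x)


def make_proto(itype: int) -> onnx.FunctionProto:
    # one FunctionProto per element type, the constant is created directly in that type
    return oh.make_function(
        DOMAIN,
        "times_two_add",
        ["x", "y"],
        ["z"],
        [
            oh.make_node("Constant", [], ["two"], value=oh.make_tensor("two", itype, [], [2.0])),
            oh.make_node("Mul", ["x", "two"], ["xx"]),
            oh.make_node("Add", ["xx", "y"], ["z"]),
        ],
        opset_imports=[oh.make_opsetid("", 18)],
    )


# EagerDirectReplacementWithOnnx is the class modified by this commit.
replacement = EagerDirectReplacementWithOnnx(
    times_two_add,
    times_two_add_shape,
    # keys are onnx element types, the same values the converter obtains from g.get_type
    {
        onnx.TensorProto.FLOAT: make_proto(onnx.TensorProto.FLOAT),
        onnx.TensorProto.FLOAT16: make_proto(onnx.TensorProto.FLOAT16),
    },
    n_inputs=2,
    n_outputs=1,
    # maps the runtime arguments to one key of the dictionary above
    version_selector=lambda x, y: (
        onnx.TensorProto.FLOAT16 if x.dtype == torch.float16 else onnx.TensorProto.FLOAT
    ),
)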