@@ -1,6 +1,6 @@
 import inspect
 from dataclasses import dataclass
-from typing import Any, Callable, Dict, List, Optional, Tuple
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 import onnx
 import torch
 from ..helpers import max_diff
@@ -49,6 +49,7 @@ class EagerDirectReplacementWithOnnx:
     :param n_outputs: same for the number of outputs,
         only tensors must be counted
     :param name: the name of the custom op, the function name if not specified
+    :param kwargs: additional constant arguments, only int or float values are allowed
 
     Here is an example:
 
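A hedged usage sketch for the new ``kwargs`` parameter, assuming the constructor takes the eager function and the ``FunctionProto`` positionally as the hunks below suggest; the ``onnx_plug`` domain and the import of ``EagerDirectReplacementWithOnnx`` (omitted here) are illustrative, not verified against the full file:

```python
import onnx.helper as oh

def scale(x):  # eager implementation
    return x * 0.5

# An ONNX function implementing the same computation.
scale_fct = oh.make_function(
    "onnx_plug", "scale", ["x"], ["y"],
    [
        oh.make_node("Constant", [], ["half"], value_float=0.5),
        oh.make_node("CastLike", ["half", "x"], ["half_c"]),
        oh.make_node("Mul", ["x", "half_c"], ["y"]),
    ],
    opset_imports=[oh.make_opsetid("", 18)],
)

rep = EagerDirectReplacementWithOnnx(
    scale, scale_fct, n_inputs=1, n_outputs=1,
    kwargs={"alpha": 0.5},  # int or float values only, per the new assertion
)
```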
@@ -141,6 +142,7 @@ def __init__(
         n_inputs: Optional[int] = None,
         n_outputs: Optional[int] = None,
         name: Optional[str] = None,
+        kwargs: Optional[Dict[str, Union[int, float]]] = None,
     ):
         assert isinstance(
             function_proto, onnx.FunctionProto
@@ -152,7 +154,14 @@ def __init__(
         self.function_proto = function_proto
         self.n_inputs = n_inputs
         self.n_outputs = n_outputs
-        self.name = name or eager_fn.__name__
+        self.name = name or eager_fn.__qualname__.replace("<locals>", "L").replace(
+            "<lambda>", "l"
+        ).replace(".", "_")
+        self.kwargs = kwargs
+        assert kwargs is None or all(isinstance(v, (int, float)) for v in kwargs.values()), (
+            f"Only ints or floats are allowed in kwargs={kwargs}, one of them "
+            "does not respect that constraint."
+        )
         sig = inspect.signature(self.eager_fn)
         params = list(sig.parameters)
         assert (
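The switch from ``__name__`` to ``__qualname__`` makes the default op name unique for nested functions and lambdas, but the qualified name has to be sanitized, presumably because it is also used to register a torch custom op and those names must be valid identifiers. Note that CPython spells the token ``<locals>`` (with an ``s``), hence the spelling in the hunk above. A quick check:

```python
def outer():
    def inner(x):
        return x
    return inner

fn = outer()
assert fn.__qualname__ == "outer.<locals>.inner"
name = fn.__qualname__.replace("<locals>", "L").replace("<lambda>", "l").replace(".", "_")
assert name == "outer_L_inner"
```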
@@ -190,7 +199,7 @@ def torch_op(self) -> Callable:
     def __call__(self, *args):
-        """Calls eager_fn, or shape_fn if the model is being exported."""
+        """Calls eager_fn, or torch_op if the model is being exported."""
         if is_exporting():
-            return self.shape_fn(*args)
+            return self.torch_op(*args)
         return self.eager_fn(*args)
 
     def _registers(self):
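Returning ``torch_op`` rather than ``shape_fn`` during export matters: the exporter needs to trace the registered custom operator, while ``shape_fn`` only fabricates outputs with the right metadata. A hedged standalone illustration of that division of labour with ``torch.library`` (names hypothetical, torch >= 2.4 assumed; ``_registers()`` may do this differently):

```python
import torch

@torch.library.custom_op("onnx_plug::scale", mutates_args=())
def scale_op(x: torch.Tensor) -> torch.Tensor:
    return x * 0.5  # real eager kernel

@scale_op.register_fake
def _(x):
    # shape_fn's role: outputs with the right shape/dtype, no real computation
    return torch.empty_like(x)
```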
@@ -266,10 +275,16 @@ def converter(
             outputs: List[str],
             *args,
         ) -> Any:
-            if not g.has_local_function(self.name, self.domain):
+            if not g.has_local_function(
+                self.function_proto.name, domain=self.function_proto.domain
+            ):
                 g.add_function(self.function_proto)
             res = g.make_node(
-                self.name, args, outputs, domain=self.domain, name=self.target_name
+                self.function_proto.name,
+                args,
+                outputs,
+                domain=self.function_proto.domain,
+                name=self.target_name,
             )
             if not sts:
                 new_shapes = self.shape_fn(*args)
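Using ``function_proto.name`` and ``function_proto.domain`` for both the registration and the call node is what makes the local function resolvable: ONNX matches a node to a local function by the exact ``(domain, op_type)`` pair, so a sanitized ``self.name`` that differs from the proto would leave a dangling node. A minimal sketch of that invariant (illustrative names):

```python
import onnx
import onnx.helper as oh

fct = oh.make_function(
    "onnx_plug", "scale", ["x"], ["y"],
    [oh.make_node("Mul", ["x", "x"], ["y"])],
    opset_imports=[oh.make_opsetid("", 18)],
)
# The call node must reuse fct.name and fct.domain, as the hunk now does.
node = oh.make_node(fct.name, ["x"], ["y"], domain=fct.domain)
graph = oh.make_graph(
    [node], "g",
    [oh.make_tensor_value_info("x", onnx.TensorProto.FLOAT, [2])],
    [oh.make_tensor_value_info("y", onnx.TensorProto.FLOAT, [2])],
)
model = oh.make_model(
    graph, opset_imports=[oh.make_opsetid("", 18), oh.make_opsetid("onnx_plug", 1)]
)
model.functions.extend([fct])
onnx.checker.check_model(model)  # passes only if the identities line up
```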
@@ -290,8 +305,8 @@ def onnx_dynamo_converter(self) -> Callable:
         """
         import onnxscript
 
-        onnx_plug_op = onnxscript.values.Opset(domain=self.domain, version=1)
-        schema = onnx_plug_op[self.name]
+        onnx_plug_op = onnxscript.values.Opset(domain=self.function_proto.domain, version=1)
+        schema = onnx_plug_op[self.function_proto.name]
         if schema is None:
             all_types = [
                 "tensor(float)",
@@ -307,8 +322,8 @@ def onnx_dynamo_converter(self) -> Callable:
             for i in range(self.n_outputs):
                 type_constraints.append((f"U{i}", all_types, ""))
             schema = onnx.defs.OpSchema(
-                self.name,
-                self.domain,
+                self.function_proto.name,
+                self.function_proto.domain,
                 1,
                 inputs=[
                     onnx.defs.OpSchema.FormalParameter(f"arg_{i}", f"T{i}")
@@ -321,7 +336,7 @@ def onnx_dynamo_converter(self) -> Callable:
                 type_constraints=type_constraints,
             )
             onnx.defs.register_schema(schema)
-        op = onnxscript.values.Op(onnx_plug_op, self.name, schema)
+        op = onnxscript.values.Op(onnx_plug_op, self.function_proto.name, schema)
 
         def converter(*cargs):
             return op(*cargs, n_outputs=self.n_outputs)
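The schema is registered once; on later exports ``onnx_plug_op[self.function_proto.name]`` finds it again and the ``if schema is None`` branch is skipped. A small demonstration of that registry round-trip with throwaway names:

```python
import onnx

schema = onnx.defs.OpSchema(
    "DemoOp", "demo_domain", 1,
    inputs=[onnx.defs.OpSchema.FormalParameter("x", "T")],
    outputs=[onnx.defs.OpSchema.FormalParameter("y", "T")],
    type_constraints=[("T", ["tensor(float)"], "")],
)
onnx.defs.register_schema(schema)
assert onnx.defs.get_schema("DemoOp", domain="demo_domain").name == "DemoOp"
```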