@@ -50,6 +50,7 @@ class EagerDirectReplacementWithOnnx:
5050 only tensors must be counted
5151 :param name: the name of the custom op, the function name if not specified
5252 :param kwargs: constants
53+ :param verbose: verbose level
5354
5455 Here is an example:
5556
@@ -143,6 +144,7 @@ def __init__(
143144 n_outputs : Optional [int ] = None ,
144145 name : Optional [str ] = None ,
145146 kwargs : Optional [Dict [str , Union [int , float ]]] = None ,
147+ verbose : int = 0 ,
146148 ):
147149 assert isinstance (
148150 function_proto , onnx .FunctionProto
@@ -154,9 +156,13 @@ def __init__(
154156 self .function_proto = function_proto
155157 self .n_inputs = n_inputs
156158 self .n_outputs = n_outputs
157- self .name = name or eager_fn .__qualname__ .replace ("<locals>" , "L" ).replace (
158- "<lambda>" , "l"
159- ).replace ("." , "_" )
159+ self .name = name or (
160+ eager_fn .__name__
161+ if "<" not in eager_fn .__name__
162+ else eager_fn .__qualname__ .replace ("<locals>" , "L" )
163+ .replace ("<lambda>" , "l" )
164+ .replace ("." , "_" )
165+ )
160166 self .kwargs = kwargs
161167 assert kwargs is None or all (isinstance (v , (int , float )) for v in kwargs .values ()), (
162168 f"Only int or floats are allowed for kwargs={ kwargs } , one of them "
@@ -179,7 +185,8 @@ def __init__(
179185 function_proto .domain == self .domain
180186 ), f"Function domain must be { self .domain !r} but it is { function_proto .domain !r} "
181187 self .arg_names = params
182- self .custom_op = self ._registers ()
188+ self .verbose = verbose
189+ self .custom_op = self ._register ()
183190
184191 @property
185192 def domain (self ) -> str :
@@ -202,12 +209,18 @@ def __call__(self, *args):
202209 return self .torch_op (* args )
203210 return self .eager_fn (* args )
204211
205- def _registers (self ):
212+ def _register (self ):
206213 """Registers the custom op."""
207214 inputs = ", " .join ([f"Tensor { p } " for p in self .arg_names ])
208215 schema = f"({ inputs } ) -> Tensor"
209216 if self .n_outputs > 1 :
210217 schema += "[]"
218+ if self .verbose :
219+ print (
220+ f"[EagerDirectReplacementWithOnnx._register] "
221+ f"'torch.ops.{ self .domain } .{ self .name } "
222+ )
223+ print (f"[EagerDirectReplacementWithOnnx._register] schema={ schema } " )
211224 custom_def = torch .library .CustomOpDef (self .domain , self .name , schema , self .eager_fn )
212225 custom_def .register_kernel (None )(self .eager_fn )
213226 custom_def ._abstract_fn = self .shape_fn
0 commit comments