99
1010
1111@functools .lru_cache
12- def get_kernels () -> Dict [Tuple [str , str , int ], type [torch_ops .OpRun ]]:
12+ def get_kernels () -> Dict [Tuple [str , str , int ], type [torch_ops .OpRunKernel ]]:
1313 """
1414 Retrieves all the available kernels class :class:`TorchOnnxEvaluator`
1515 can use. The full list is the following.
@@ -28,7 +28,7 @@ def get_kernels() -> Dict[Tuple[str, str, int], type[torch_ops.OpRun]]:
2828 """
2929 res = {}
3030 for _k , v in torch_ops .__dict__ .items ():
31- if isinstance (v , type ) and issubclass (v , torch_ops .OpRun ) and "_" in v .__name__ :
31+ if isinstance (v , type ) and issubclass (v , torch_ops .OpRunKernel ) and "_" in v .__name__ :
3232 name , version = v .__name__ .split ("_" )
3333 domain = getattr (v , "domain" , "" )
3434 res [domain , name , int (version )] = v
@@ -161,11 +161,11 @@ class TorchOnnxEvaluator:
161161 from onnx_diagnostic.helpers import string_type
162162 from onnx_diagnostic.helpers.torch_helper import onnx_dtype_to_torch_dtype
163163 from onnx_diagnostic.reference import TorchOnnxEvaluator
164- from onnx_diagnostic.reference.torch_ops import OpRun , OpRunTensor
164+ from onnx_diagnostic.reference.torch_ops import OpRunKernel , OpRunTensor
165165
166166 TFLOAT16 = onnx.TensorProto.FLOAT16
167167
168- class LayerNormalizationOrt(OpRun ):
168+ class LayerNormalizationOrt(OpRunKernel ):
169169 "LayerNormalization based on onnxruntime"
170170
171171 def __init__(self, node: onnx.NodeProto, version=None):
@@ -284,11 +284,11 @@ def __init__(
284284 opsets : Optional [Dict [str , int ]] = None ,
285285 local_functions : Optional [Dict [Tuple [str , str ], "TorchOnnxEvaluator" ]] = None ,
286286 verbose : int = 0 ,
287- custom_kernels : Optional [Dict [Tuple [str , str ], type [torch_ops .OpRun ]]] = None ,
287+ custom_kernels : Optional [Dict [Tuple [str , str ], type [torch_ops .OpRunKernel ]]] = None ,
288288 ):
289289 self .providers = providers
290290 self .constants : Dict [str , torch .Tensor ] = {}
291- self .kernels : List [Optional [torch_ops .OpRun ]] = []
291+ self .kernels : List [Optional [torch_ops .OpRunKernel ]] = []
292292 self .functions = local_functions .copy () if local_functions else {}
293293 self .CPU = torch .tensor ([0 ]).to ("cpu" ).device
294294 self .verbose = verbose
0 commit comments