55import onnx .numpy_helper as onh
66import torch
77from onnx_diagnostic .ext_test_case import ExtTestCase
8- from onnx_diagnostic .reference import ExtendedReferenceEvaluator , TorchEvaluator
8+ from onnx_diagnostic .reference import ExtendedReferenceEvaluator , TorchOnnxEvaluator
99from onnx_diagnostic .reference .torch_evaluator import get_kernels
1010
1111
1212TFLOAT = onnx .TensorProto .FLOAT
1313TINT64 = onnx .TensorProto .INT64
1414
1515
16- class TestTorchEvaluator (ExtTestCase ):
16+ class TestTorchOnnxEvaluator (ExtTestCase ):
1717 def test_kernels (self ):
1818 ker = get_kernels ()
1919 self .assertIsInstance (ker , dict )
@@ -36,7 +36,7 @@ def _finalize_test(self, model, *args, atol: float = 0, use_ort: bool = False):
3636 expected = sess .run (None , feeds_numpy )
3737 else :
3838 expected = ExtendedReferenceEvaluator (model ).run (None , feeds_numpy )
39- rt = TorchEvaluator (model )
39+ rt = TorchOnnxEvaluator (model )
4040 got = rt .run (None , feeds )
4141 self .assertEqualAny (expected , [g .detach ().numpy () for g in got ], atol = atol )
4242
@@ -68,7 +68,7 @@ def test_op_binary(self):
6868 )
6969 onnx .checker .check_model (model )
7070
71- rt = TorchEvaluator (model )
71+ rt = TorchOnnxEvaluator (model )
7272 self .assertEqual (5 , len (rt .kernels ))
7373 self .assertEqual (2 , len (rt .constants ))
7474
@@ -144,7 +144,7 @@ def test_op_slice_squeeze(self):
144144 expected = ExtendedReferenceEvaluator (model ).run (
145145 None , {k : v .numpy () for k , v in feeds .items ()}
146146 )
147- rt = TorchEvaluator (model )
147+ rt = TorchOnnxEvaluator (model )
148148 got = rt .run (None , feeds )
149149 self .assertEqualAny (expected , [g .detach ().numpy () for g in got ])
150150
@@ -171,7 +171,7 @@ def test_op_shape(self):
171171 expected = ExtendedReferenceEvaluator (model ).run (
172172 None , {k : v .numpy () for k , v in feeds .items ()}
173173 )
174- rt = TorchEvaluator (model )
174+ rt = TorchOnnxEvaluator (model )
175175 got = rt .run (None , feeds )
176176 self .assertEqualAny (expected , [g .detach ().numpy () for g in got ])
177177
0 commit comments