11import unittest
2+ from typing import Optional
23import numpy as np
34import onnx
45import onnx .helper as oh
@@ -22,7 +23,7 @@ def test_kernels(self):
2223 kernel = ker [key ]
2324 self .assertEqual ("Add_1" , kernel .__name__ )
2425
25- def _finalize_test (self , model , * args ):
26+ def _finalize_test (self , model , * args , atol : Optional [ float ] = None ):
2627 onnx .checker .check_model (model )
2728 feeds = dict (zip ([i .name for i in model .graph .input ], args ))
2829
@@ -31,7 +32,7 @@ def _finalize_test(self, model, *args):
3132 )
3233 rt = TorchEvaluator (model )
3334 got = rt .run (None , feeds )
34- self .assertEqualAny (expected , [g .detach ().numpy () for g in got ])
35+ self .assertEqualAny (expected , [g .detach ().numpy () for g in got ], atol = atol )
3536
3637 def test_op_binary (self ):
3738 model = oh .make_model (
@@ -260,6 +261,21 @@ def test_op_gather(self):
260261 torch .tensor ([0 , 1 , 3 ], dtype = torch .int64 ),
261262 )
262263
264+ def test_op_softmax (self ):
265+ model = oh .make_model (
266+ oh .make_graph (
267+ [oh .make_node ("Softmax" , ["X" ], ["Z" ], axis = 0 )],
268+ "dummy" ,
269+ [oh .make_tensor_value_info ("X" , TFLOAT , ["a" , "b" , "c" ])],
270+ [oh .make_tensor_value_info ("Z" , TFLOAT , ["a" , "b" , "c" ])],
271+ ),
272+ ir_version = 9 ,
273+ opset_imports = [oh .make_opsetid ("" , 18 )],
274+ )
275+ self ._finalize_test (
276+ model , torch .abs (torch .rand (3 , 4 , 5 , dtype = torch .float32 )), atol = 1e-6
277+ )
278+
263279
264280if __name__ == "__main__" :
265281 unittest .main (verbosity = 2 )
0 commit comments