11import unittest
22import numpy as np
3+ import onnx
4+ import onnx .helper as oh
5+ import torch
36from onnx_diagnostic .ext_test_case import ExtTestCase , skipif_ci_windows
4- from onnx_diagnostic .helpers import string_type , string_sig
7+ from onnx_diagnostic .helpers import string_type , string_sig , pretty_onnx , get_onnx_signature
8+
9+ TFLOAT = onnx .TensorProto .FLOAT
510
611
712class TestHelpers (ExtTestCase ):
@@ -19,8 +24,6 @@ def test_string_dict(self):
1924 self .assertEqual (s , "dict(a:A1r1,b:dict(r:float),c:{int})" )
2025
2126 def test_string_type_array (self ):
22- import torch
23-
2427 a = np .array ([1 ], dtype = np .float32 )
2528 t = torch .tensor ([1 ])
2629 obj = {"a" : a , "b" : t }
@@ -30,22 +33,64 @@ def test_string_type_array(self):
3033 self .assertEqual (s , "dict(a:A1s1,b:T7s1)" )
3134
3235 def test_string_sig_f (self ):
33-
3436 def f (a , b = 3 , c = 4 , e = 5 ):
3537 pass
3638
3739 ssig = string_sig (f , {"a" : 1 , "c" : 8 , "b" : 3 })
3840 self .assertEqual (ssig , "f(a=1, c=8)" )
3941
4042 def test_string_sig_cls (self ):
41-
4243 class A :
4344 def __init__ (self , a , b = 3 , c = 4 , e = 5 ):
4445 self .a , self .b , self .c , self .e = a , b , c , e
4546
4647 ssig = string_sig (A (1 , c = 8 ))
4748 self .assertEqual (ssig , "A(a=1, c=8)" )
4849
50+ def test_pretty_onnx (self ):
51+ proto = oh .make_model (
52+ oh .make_graph (
53+ [
54+ oh .make_node ("Sigmoid" , ["Y" ], ["sy" ]),
55+ oh .make_node ("Mul" , ["Y" , "sy" ], ["ysy" ]),
56+ oh .make_node ("Mul" , ["X" , "ysy" ], ["final" ]),
57+ ],
58+ "nd" ,
59+ [
60+ oh .make_tensor_value_info ("X" , TFLOAT , [1 , "b" , "c" ]),
61+ oh .make_tensor_value_info ("Y" , TFLOAT , ["a" , "b" , "c" ]),
62+ ],
63+ [oh .make_tensor_value_info ("final" , TFLOAT , ["a" , "b" , "c" ])],
64+ ),
65+ opset_imports = [oh .make_opsetid ("" , 18 )],
66+ ir_version = 9 ,
67+ )
68+ pretty_onnx (proto , shape_inference = True )
69+ pretty_onnx (proto .graph .input [0 ])
70+ pretty_onnx (proto .graph )
71+ pretty_onnx (proto .graph .node [0 ])
72+
73+ def test_get_onnx_signature (self ):
74+ proto = oh .make_model (
75+ oh .make_graph (
76+ [
77+ oh .make_node ("Sigmoid" , ["Y" ], ["sy" ]),
78+ oh .make_node ("Mul" , ["Y" , "sy" ], ["ysy" ]),
79+ oh .make_node ("Mul" , ["X" , "ysy" ], ["final" ]),
80+ ],
81+ "nd" ,
82+ [
83+ oh .make_tensor_value_info ("X" , TFLOAT , [1 , "b" , "c" ]),
84+ oh .make_tensor_value_info ("Y" , TFLOAT , ["a" , "b" , "c" ]),
85+ ],
86+ [oh .make_tensor_value_info ("final" , TFLOAT , ["a" , "b" , "c" ])],
87+ ),
88+ opset_imports = [oh .make_opsetid ("" , 18 )],
89+ ir_version = 9 ,
90+ )
91+ sig = get_onnx_signature (proto )
92+ self .assertEqual (sig , (("X" , 1 , (1 , "b" , "c" )), ("Y" , 1 , ("a" , "b" , "c" ))))
93+
4994
5095if __name__ == "__main__" :
5196 unittest .main (verbosity = 2 )
0 commit comments