@@ -147,7 +147,7 @@ def test_op_cast(self):
147147 ir_version = 9 ,
148148 opset_imports = [oh .make_opsetid ("" , 18 )],
149149 )
150- self ._finalize_test (model , ( torch .rand ((4 , 5 , 6 , 7 ), dtype = torch .float32 ), ))
150+ self ._finalize_test (model , torch .rand ((4 , 5 , 6 , 7 ), dtype = torch .float32 ))
151151
152152 def test_op_transpose (self ):
153153 model = oh .make_model (
@@ -180,6 +180,46 @@ def test_op_reshape(self):
180180 model , torch .rand ((4 , 5 , 6 , 7 ), dtype = torch .float32 ), torch .tensor ([7 , 4 , 6 , 5 ])
181181 )
182182
183+ def test_op_matmul (self ):
184+ model = oh .make_model (
185+ oh .make_graph (
186+ [oh .make_node ("MatMul" , ["X" , "Y" ], ["Z" ])],
187+ "dummy" ,
188+ [
189+ oh .make_tensor_value_info ("X" , TFLOAT , ["a" , "b" , "c" , "d" ]),
190+ oh .make_tensor_value_info ("Y" , TFLOAT , ["a" , "b" , "d" , "f" ]),
191+ ],
192+ [oh .make_tensor_value_info ("Z" , TFLOAT , ["a" , "b" , "c" , "f" ])],
193+ ),
194+ ir_version = 9 ,
195+ opset_imports = [oh .make_opsetid ("" , 18 )],
196+ )
197+ self ._finalize_test (
198+ model ,
199+ torch .rand ((4 , 5 , 6 , 7 ), dtype = torch .float32 ),
200+ torch .rand ((4 , 5 , 7 , 11 ), dtype = torch .float32 ),
201+ )
202+
203+ def test_op_unsqueeze (self ):
204+ model = oh .make_model (
205+ oh .make_graph (
206+ [oh .make_node ("Unsqueeze" , ["X" , "axes" ], ["Z" ])],
207+ "dummy" ,
208+ [
209+ oh .make_tensor_value_info ("X" , TFLOAT , ["a" , "b" , 1 , "d" ]),
210+ oh .make_tensor_value_info ("axes" , TINT64 , ["s" ]),
211+ ],
212+ [oh .make_tensor_value_info ("Z" , TFLOAT , ["a" , "b" , "d" ])],
213+ ),
214+ ir_version = 9 ,
215+ opset_imports = [oh .make_opsetid ("" , 18 )],
216+ )
217+ self ._finalize_test (
218+ model ,
219+ torch .rand ((4 , 5 , 1 , 7 ), dtype = torch .float32 ),
220+ torch .tensor ([2 ], dtype = torch .int64 ),
221+ )
222+
183223
184224if __name__ == "__main__" :
185225 unittest .main (verbosity = 2 )
0 commit comments