@@ -22,13 +22,20 @@ def test_kernels(self):
2222 kernel = ker [key ]
2323 self .assertEqual ("Add_1" , kernel .__name__ )
2424
25- def _finalize_test (self , model , * args , atol : float = 0 ):
25+ def _finalize_test (self , model , * args , atol : float = 0 , use_ort : bool = False ):
2626 onnx .checker .check_model (model )
2727 feeds = dict (zip ([i .name for i in model .graph .input ], args ))
28+ feeds_numpy = {k : v .numpy () for k , v in feeds .items ()}
2829
29- expected = ExtendedReferenceEvaluator (model ).run (
30- None , {k : v .numpy () for k , v in feeds .items ()}
31- )
30+ if use_ort :
31+ import onnxruntime
32+
33+ sess = onnxruntime .InferenceSession (
34+ model .SerializeToString (), providers = ["CPUExecutionProvider" ]
35+ )
36+ expected = sess .run (None , feeds_numpy )
37+ else :
38+ expected = ExtendedReferenceEvaluator (model ).run (None , feeds_numpy )
3239 rt = TorchEvaluator (model )
3340 got = rt .run (None , feeds )
3441 self .assertEqualAny (expected , [g .detach ().numpy () for g in got ], atol = atol )
@@ -453,6 +460,52 @@ def test_op_reduce_sum(self):
453460 atol = 1e-5 ,
454461 )
455462
463+ def test_op_where (self ):
464+ model = oh .make_model (
465+ oh .make_graph (
466+ [
467+ oh .make_node ("Greater" , ["X" , "Y" ], ["cond" ]),
468+ oh .make_node ("Where" , ["cond" , "X" , "Y" ], ["Z" ]),
469+ ],
470+ "dummy" ,
471+ [
472+ oh .make_tensor_value_info ("X" , TFLOAT , ["a" , "b" , "c" ]),
473+ oh .make_tensor_value_info ("Y" , TFLOAT , ["a" , "b" , "c" ]),
474+ ],
475+ [oh .make_tensor_value_info ("Z" , TFLOAT , ["a" , "b" , "c" ])],
476+ ),
477+ ir_version = 9 ,
478+ opset_imports = [oh .make_opsetid ("" , 18 )],
479+ )
480+ self ._finalize_test (
481+ model ,
482+ torch .rand (3 , 4 , 5 , dtype = torch .float32 ),
483+ torch .rand (3 , 4 , 5 , dtype = torch .float32 ),
484+ )
485+
486+ def test_op_layer_normalization (self ):
487+ model = oh .make_model (
488+ oh .make_graph (
489+ [oh .make_node ("LayerNormalization" , ["X" , "W" , "B" ], ["Z" ], axis = - 1 )],
490+ "dummy" ,
491+ [
492+ oh .make_tensor_value_info ("X" , TFLOAT , ["a" , "b" , "c" ]),
493+ oh .make_tensor_value_info ("W" , TFLOAT , []),
494+ oh .make_tensor_value_info ("B" , TFLOAT , []),
495+ ],
496+ [oh .make_tensor_value_info ("Z" , TFLOAT , ["a" , "b" , "c" ])],
497+ ),
498+ ir_version = 9 ,
499+ opset_imports = [oh .make_opsetid ("" , 18 )],
500+ )
501+ self ._finalize_test (
502+ model ,
503+ torch .rand (3 , 4 , 5 , dtype = torch .float32 ),
504+ torch .abs (torch .rand (5 , dtype = torch .float32 )),
505+ torch .rand (5 , dtype = torch .float32 ),
506+ use_ort = True ,
507+ )
508+
456509
457510if __name__ == "__main__" :
458511 unittest .main (verbosity = 2 )
0 commit comments