@@ -59,6 +59,43 @@ def forward(self, x):
5959 got = orte .run (None , {name : x .numpy ()})[0 ]
6060 self .assertEqualArray (expected , got )
6161
62+ def test_ort_eval_cond (self ):
63+ import torch
64+
65+ class TwoInputs (torch .nn .Module ):
66+ def forward (self , x , y ):
67+ def true_fn (x , y ):
68+ return torch .sin (x ), torch .cos (x ) + y
69+
70+ def false_fn (x , y ):
71+ return torch .cos (x ), torch .sin (x ) + y
72+
73+ return torch .cond (x .sum () > 0 , true_fn , false_fn , [x , y ])
74+
75+ x , y = torch .rand (5 , 3 ), torch .rand (5 , 3 )
76+ model = TwoInputs ()
77+ onx = to_onnx (model , (x , y ), inline = False )
78+ self .assertEqual (len (onx .functions ), 2 )
79+
80+ # ExtendedReferenceEvaluator
81+ ref = ExtendedReferenceEvaluator (onx )
82+ for _x in (x , - x ):
83+ expected = model (_x , y )
84+ got = ref .run (None , {"x" : _x .detach ().numpy (), "y" : y .detach ().numpy ()})
85+ self .assertEqual (len (expected ), len (got ))
86+ for e , g in zip (expected , got ):
87+ self .assertEqualArray (e , g , atol = 1e-5 )
88+
89+ # OnnxruntimeEvaluator
90+ ref = OnnxruntimeEvaluator (onx )
91+
92+ for _x in (x , - x ):
93+ expected = model (_x , y )
94+ got = ref .run (None , {"x" : _x .detach ().numpy (), "y" : y .detach ().numpy ()})
95+ self .assertEqual (len (expected ), len (got ))
96+ for e , g in zip (expected , got ):
97+ self .assertEqualArray (e , g , atol = 1e-5 )
98+
6299
63100if __name__ == "__main__" :
64101 unittest .main (verbosity = 2 )
0 commit comments