Skip to content

Commit 5262b20

Browse files
committed
add test for test
1 parent a5ca63f commit 5262b20

File tree

1 file changed

+37
-0
lines changed

1 file changed

+37
-0
lines changed

_unittests/ut_reference/test_onnxruntime_evaluator.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,43 @@ def forward(self, x):
5959
got = orte.run(None, {name: x.numpy()})[0]
6060
self.assertEqualArray(expected, got)
6161

62+
def test_ort_eval_cond(self):
63+
import torch
64+
65+
class TwoInputs(torch.nn.Module):
66+
def forward(self, x, y):
67+
def true_fn(x, y):
68+
return torch.sin(x), torch.cos(x) + y
69+
70+
def false_fn(x, y):
71+
return torch.cos(x), torch.sin(x) + y
72+
73+
return torch.cond(x.sum() > 0, true_fn, false_fn, [x, y])
74+
75+
x, y = torch.rand(5, 3), torch.rand(5, 3)
76+
model = TwoInputs()
77+
onx = to_onnx(model, (x, y), inline=False)
78+
self.assertEqual(len(onx.functions), 2)
79+
80+
# ExtendedReferenceEvaluator
81+
ref = ExtendedReferenceEvaluator(onx)
82+
for _x in (x, -x):
83+
expected = model(_x, y)
84+
got = ref.run(None, {"x": _x.detach().numpy(), "y": y.detach().numpy()})
85+
self.assertEqual(len(expected), len(got))
86+
for e, g in zip(expected, got):
87+
self.assertEqualArray(e, g, atol=1e-5)
88+
89+
# OnnxruntimeEvaluator
90+
ref = OnnxruntimeEvaluator(onx)
91+
92+
for _x in (x, -x):
93+
expected = model(_x, y)
94+
got = ref.run(None, {"x": _x.detach().numpy(), "y": y.detach().numpy()})
95+
self.assertEqual(len(expected), len(got))
96+
for e, g in zip(expected, got):
97+
self.assertEqualArray(e, g, atol=1e-5)
98+
6299

63100
if __name__ == "__main__":
64101
unittest.main(verbosity=2)

0 commit comments

Comments
 (0)