Skip to content

Commit 951257d

Browse files
committed
fix ut
1 parent 5138b69 commit 951257d

File tree

1 file changed

+13
-3
lines changed

1 file changed

+13
-3
lines changed

_unittests/ut_export/test_jit.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,11 @@
11
import unittest
22
import torch
3-
from onnx_diagnostic.ext_test_case import ExtTestCase, hide_stdout, ignore_warnings
3+
from onnx_diagnostic.ext_test_case import (
4+
ExtTestCase,
5+
hide_stdout,
6+
ignore_warnings,
7+
requires_onnxscript,
8+
)
49
from onnx_diagnostic.reference import ExtendedReferenceEvaluator
510
from onnx_diagnostic.helpers.torch_test_helper import is_torchdynamo_exporting
611

@@ -53,6 +58,7 @@ def test_dummy_loop(self):
5358

5459
@hide_stdout()
5560
@ignore_warnings(UserWarning)
61+
@requires_onnxscript("0.4")
5662
def test_export_loop_onnxscript(self):
5763
class Model(torch.nn.Module):
5864
def forward(self, images, position):
@@ -96,7 +102,9 @@ def forward(self, images, position):
96102
dynamo=True,
97103
fallback=False,
98104
)
99-
ref = ExtendedReferenceEvaluator(name2)
105+
import onnxruntime
106+
107+
ref = onnxruntime.InferenceSession(name2, providers=["CPUExecutionProvider"])
100108
feeds = dict(images=x.numpy(), position=y.numpy())
101109
got = ref.run(None, feeds)[0]
102110
self.assertEqualArray(expected, got)
@@ -123,7 +131,9 @@ def forward(self, images, position):
123131
filename=name2,
124132
dynamic_shapes={"images": {0: "batch", 1: "maxdim"}, "position": {0: "batch"}},
125133
)
126-
ref = ExtendedReferenceEvaluator(name2)
134+
import onnxruntime
135+
136+
ref = onnxruntime.InferenceSession(name2, providers=["CPUExecutionProvider"])
127137
feeds = dict(images=x.numpy(), position=y.numpy())
128138
got = ref.run(None, feeds)[0]
129139
self.assertEqualArray(expected, got)

0 commit comments

Comments
 (0)