11import unittest
22import torch
3- from onnx_diagnostic .ext_test_case import ExtTestCase , hide_stdout , ignore_warnings
3+ from onnx_diagnostic .ext_test_case import (
4+ ExtTestCase ,
5+ hide_stdout ,
6+ ignore_warnings ,
7+ requires_onnxscript ,
8+ )
49from onnx_diagnostic .reference import ExtendedReferenceEvaluator
510from onnx_diagnostic .helpers .torch_test_helper import is_torchdynamo_exporting
611
@@ -53,6 +58,7 @@ def test_dummy_loop(self):
5358
5459 @hide_stdout ()
5560 @ignore_warnings (UserWarning )
61+ @requires_onnxscript ("0.4" )
5662 def test_export_loop_onnxscript (self ):
5763 class Model (torch .nn .Module ):
5864 def forward (self , images , position ):
@@ -96,7 +102,9 @@ def forward(self, images, position):
96102 dynamo = True ,
97103 fallback = False ,
98104 )
99- ref = ExtendedReferenceEvaluator (name2 )
105+ import onnxruntime
106+
107+ ref = onnxruntime .InferenceSession (name2 , providers = ["CPUExecutionProvider" ])
100108 feeds = dict (images = x .numpy (), position = y .numpy ())
101109 got = ref .run (None , feeds )[0 ]
102110 self .assertEqualArray (expected , got )
@@ -123,7 +131,9 @@ def forward(self, images, position):
123131 filename = name2 ,
124132 dynamic_shapes = {"images" : {0 : "batch" , 1 : "maxdim" }, "position" : {0 : "batch" }},
125133 )
126- ref = ExtendedReferenceEvaluator (name2 )
134+ import onnxruntime
135+
136+ ref = onnxruntime .InferenceSession (name2 , providers = ["CPUExecutionProvider" ])
127137 feeds = dict (images = x .numpy (), position = y .numpy ())
128138 got = ref .run (None , feeds )[0 ]
129139 self .assertEqualArray (expected , got )
0 commit comments