Skip to content

Commit 1b3052e

Browse files
committed
fix
1 parent 159a9fc commit 1b3052e

File tree

1 file changed

+12
-4
lines changed

1 file changed

+12
-4
lines changed

_unittests/ut_helpers/test_ort_session_tinyllm.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -73,16 +73,24 @@ def test_ort_value_more(self):
7373
@ignore_warnings((UserWarning, DeprecationWarning, FutureWarning))
7474
@hide_stdout()
7575
def test_check_allruntimes_on_tiny_llm(self):
76+
try:
77+
from experimental_experiment.torch_interpreter import to_onnx
78+
except ImportError:
79+
to_onnx = None
80+
7681
data = get_tiny_llm()
7782
model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
7883
expected = model(**copy.deepcopy(inputs))
7984

8085
with torch_export_patches(patch_transformers=True):
81-
ep = torch.onnx.export(
82-
model, (), kwargs=copy.deepcopy(inputs), dynamic_shapes=ds, dynamo=True
83-
)
86+
if to_onnx:
87+
proto = to_onnx(model, (), kwargs=copy.deepcopy(inputs), dynamic_shapes=ds)
88+
else:
89+
stop
90+
proto = torch.onnx.export(
91+
model, (), kwargs=copy.deepcopy(inputs), dynamic_shapes=ds, dynamo=True
92+
).model_proto
8493

85-
proto = ep.model_proto
8694
self.dump_onnx("test_check_allruntimes_on_tiny_llm.onnx", proto)
8795
feeds = make_feeds(proto, inputs, use_numpy=True, copy=True)
8896
sess = onnxruntime.InferenceSession(

0 commit comments

Comments
 (0)