Skip to content

Commit fb1844b

Browse files
committed
fix ut
1 parent e1a8e1c commit fb1844b

File tree

2 files changed

+19
-4
lines changed

2 files changed

+19
-4
lines changed

_doc/examples/plot_export_tiny_llm.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
from onnx_diagnostic.helpers import string_type
3434
from onnx_diagnostic.helpers.torch_helper import steal_forward
3535
from onnx_diagnostic.torch_models.llms import get_tiny_llm
36+
from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str
3637

3738

3839
MODEL_NAME = "arnir0/Tiny-LLM"
@@ -131,7 +132,11 @@ def _forward_(*args, _f=None, **kwargs):
131132

132133
try:
133134
ep = torch.export.export(
134-
untrained_model, (), kwargs=cloned_inputs, dynamic_shapes=dynamic_shapes, strict=False
135+
untrained_model,
136+
(),
137+
kwargs=cloned_inputs,
138+
dynamic_shapes=use_dyn_not_str(dynamic_shapes),
139+
strict=False,
135140
)
136141
print("It worked:")
137142
print(ep)
@@ -166,7 +171,11 @@ def _forward_(*args, _f=None, **kwargs):
166171

167172
try:
168173
ep = torch.export.export(
169-
model, (), kwargs=cloned_inputs, dynamic_shapes=dynamic_shapes, strict=False
174+
model,
175+
(),
176+
kwargs=cloned_inputs,
177+
dynamic_shapes=use_dyn_not_str(dynamic_shapes),
178+
strict=False,
170179
)
171180
print("It worked:")
172181
print(ep)

_unittests/ut_torch_models/test_tiny_llms_bypassed.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,19 +2,21 @@
22
import unittest
33
import torch
44
from transformers.cache_utils import DynamicCache
5-
from onnx_diagnostic.ext_test_case import ExtTestCase, ignore_warnings
5+
from onnx_diagnostic.ext_test_case import ExtTestCase, ignore_warnings, hide_stdout
66
from onnx_diagnostic.torch_models.llms import get_tiny_llm
77
from onnx_diagnostic.torch_models.llms import get_phi2
88
from onnx_diagnostic.helpers import string_type
99
from onnx_diagnostic.helpers.torch_helper import torch_deepcopy
1010
from onnx_diagnostic.torch_export_patches import torch_export_patches
11+
from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str
1112
from onnx_diagnostic.torch_export_patches.patches.patch_transformers import (
1213
patched_DynamicCache,
1314
)
1415

1516

1617
class TestTinyLlmBypassed(ExtTestCase):
1718
@ignore_warnings(UserWarning)
19+
@hide_stdout()
1820
def test_export_tiny_llm_2_bypassed(self):
1921
data = get_tiny_llm()
2022
model, inputs = data["model"], data["inputs"]
@@ -50,7 +52,11 @@ def debug():
5052
debug()
5153

5254
ep = torch.export.export(
53-
model, (), kwargs=inputs, dynamic_shapes=data["dynamic_shapes"], strict=False
55+
model,
56+
(),
57+
kwargs=inputs,
58+
dynamic_shapes=use_dyn_not_str(data["dynamic_shapes"]),
59+
strict=False,
5460
)
5561
got = ep.module()(**inputs)
5662
self.assertEqualArrayAny(expected, got)

0 commit comments

Comments
 (0)