Skip to content

Commit 21b7e23

Browse files
committed
bypass
1 parent 1e0e456 commit 21b7e23

File tree

2 files changed

+18
-24
lines changed

2 files changed

+18
-24
lines changed
Lines changed: 1 addition & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,8 @@
11
import unittest
22
import torch
3-
from onnx_diagnostic.ext_test_case import (
4-
ExtTestCase,
5-
ignore_warnings,
6-
requires_transformers,
7-
requires_python,
8-
)
3+
from onnx_diagnostic.ext_test_case import ExtTestCase, ignore_warnings, requires_transformers
94
from onnx_diagnostic.torch_models.llms import get_phi2
105
from onnx_diagnostic.helpers import string_type
11-
from onnx_diagnostic.torch_export_patches import bypass_export_some_errors
126

137

148
class TestLlmPhi(ExtTestCase):
@@ -29,23 +23,6 @@ def test_export_phi2_1(self):
2923
ep = torch.export.export(model, (), kwargs=inputs, dynamic_shapes=ds)
3024
assert ep
3125

32-
@ignore_warnings(UserWarning)
33-
@requires_python((3, 12))
34-
def test_export_phi2_2_bypassed(self):
35-
data = get_phi2(num_hidden_layers=2)
36-
model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
37-
self.assertEqual(
38-
{"attention_mask", "past_key_values", "input_ids", "position_ids"}, set(inputs)
39-
)
40-
with bypass_export_some_errors(patch_transformers=True) as modificator:
41-
inputs = modificator(inputs)
42-
ep = torch.export.export(model, (), kwargs=inputs, dynamic_shapes=ds, strict=False)
43-
assert ep
44-
with bypass_export_some_errors(patch_transformers=True) as modificator:
45-
inputs = modificator(inputs)
46-
ep = torch.export.export(model, (), kwargs=inputs, dynamic_shapes=ds, strict=False)
47-
assert ep
48-
4926

5027
if __name__ == "__main__":
5128
unittest.main(verbosity=2)

_unittests/ut_torch_models/test_tiny_llms_bypassed.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from transformers.cache_utils import DynamicCache
55
from onnx_diagnostic.ext_test_case import ExtTestCase, ignore_warnings
66
from onnx_diagnostic.torch_models.llms import get_tiny_llm
7+
from onnx_diagnostic.torch_models.llms import get_phi2
78
from onnx_diagnostic.helpers import string_type
89
from onnx_diagnostic.torch_export_patches import bypass_export_some_errors
910
from onnx_diagnostic.torch_export_patches.patches.patch_transformers import (
@@ -53,6 +54,22 @@ def debug():
5354
got = ep.module()(**inputs)
5455
self.assertEqualArrayAny(expected, got)
5556

57+
@ignore_warnings(UserWarning)
58+
def test_export_phi2_2_bypassed(self):
59+
data = get_phi2(num_hidden_layers=2)
60+
model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
61+
self.assertEqual(
62+
{"attention_mask", "past_key_values", "input_ids", "position_ids"}, set(inputs)
63+
)
64+
with bypass_export_some_errors(patch_transformers=True) as modificator:
65+
inputs = modificator(inputs)
66+
ep = torch.export.export(model, (), kwargs=inputs, dynamic_shapes=ds, strict=False)
67+
assert ep
68+
with bypass_export_some_errors(patch_transformers=True) as modificator:
69+
inputs = modificator(inputs)
70+
ep = torch.export.export(model, (), kwargs=inputs, dynamic_shapes=ds, strict=False)
71+
assert ep
72+
5673

5774
if __name__ == "__main__":
5875
unittest.main(verbosity=2)

0 commit comments

Comments
 (0)