Skip to content

Commit 22f529d

Browse files
committed
fix issues
1 parent 8cb8ec5 commit 22f529d

File tree

2 files changed

+11
-4
lines changed

2 files changed

+11
-4
lines changed

_unittests/ut_torch_models/test_tiny_llms.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import copy
12
import unittest
23
import torch
34
from transformers.cache_utils import DynamicCache
@@ -22,18 +23,21 @@ def test_get_tiny_llm(self):
2223
def test_export_tiny_llm_1(self):
2324
data = get_tiny_llm()
2425
model, inputs = data["model"], data["inputs"]
26+
expected = model(**copy.deepcopy(inputs))
2527
self.assertEqual(
2628
{"attention_mask", "past_key_values", "input_ids", "position_ids"}, set(inputs)
2729
)
2830
ep = torch.export.export(
29-
model, (), kwargs=inputs, dynamic_shapes=data["dynamic_shapes"]
31+
model, (), kwargs=copy.deepcopy(inputs), dynamic_shapes=data["dynamic_shapes"]
3032
)
31-
assert ep
33+
got = ep.module()(**inputs)
34+
self.assertEqualArrayAny(expected, got)
3235

3336
@ignore_warnings(UserWarning)
3437
def test_export_tiny_llm_2_bypassed(self):
3538
data = get_tiny_llm()
3639
model, inputs = data["model"], data["inputs"]
40+
expected = model(**copy.deepcopy(inputs))
3741
self.assertEqual(
3842
{"attention_mask", "past_key_values", "input_ids", "position_ids"}, set(inputs)
3943
)
@@ -45,7 +49,7 @@ def test_export_tiny_llm_2_bypassed(self):
4549
for k in patched_DynamicCache._PATCHES_:
4650
self.assertEqual(getattr(patched_DynamicCache, k), getattr(DynamicCache, k))
4751

48-
inputs = modificator(inputs)
52+
inputs = modificator(copy.deepcopy(inputs))
4953

5054
def debug():
5155
print("***", string_type(inputs, with_shape=True))
@@ -67,7 +71,8 @@ def debug():
6771
ep = torch.export.export(
6872
model, (), kwargs=inputs, dynamic_shapes=data["dynamic_shapes"], strict=False
6973
)
70-
assert ep
74+
got = ep.module()(**inputs)
75+
self.assertEqualArrayAny(expected, got)
7176

7277

7378
if __name__ == "__main__":

_unittests/ut_torch_models/test_tiny_llms_onnx.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,8 @@ def test_bypass_onnx_export_tiny_llm_official_full(self):
9999
dynamic_shapes=ds,
100100
dynamo=True,
101101
optimize=True,
102+
report=True,
103+
verify=True,
102104
)
103105
# There are some discrepancies with torch==2.6
104106
if not has_torch("2.7"):

0 commit comments

Comments
 (0)