
Commit 8592cd6

More unit tests (#171)
* more unit test
* Fix static cache
* fix issues
* fix issue
1 parent 6eb85b7 commit 8592cd6

File tree

6 files changed: +59 -4 lines
Lines changed: 33 additions & 0 deletions
@@ -0,0 +1,33 @@
+import unittest
+import torch
+from onnx_diagnostic.ext_test_case import ExtTestCase
+from onnx_diagnostic.reference import ExtendedReferenceEvaluator
+
+
+class TestDynamicShapes(ExtTestCase):
+    def test_getitem_index_put1(self):
+        class Model(torch.nn.Module):
+            def forward(self, x, value):
+                x = x.clone()
+                x[:, :, :, : value.shape[-1]] = value
+                return x
+
+        inputs = (torch.randn(2, 2, 3, 4), torch.randn(2, 2, 3, 3))
+        model = Model()
+        expected = model(*inputs)
+
+        onx = self.to_onnx(model, inputs, dynamic_shapes=({3: "M"}, {3: "N"}))
+        self.dump_onnx("test_getitem_index_put1.onnx", onx)
+        feeds = dict(zip(["x", "value"], [x.detach().cpu().numpy() for x in inputs]))
+        ref = ExtendedReferenceEvaluator(onx, verbose=0)
+        got = ref.run(None, feeds)[0]
+        self.assertEqualArray(expected, got, atol=1e-5)
+        sess = self.ort().InferenceSession(
+            onx.SerializeToString(), providers=["CPUExecutionProvider"]
+        )
+        got = sess.run(None, feeds)[0]
+        self.assertEqualArray(expected, got, atol=1e-5)
+
+
+if __name__ == "__main__":
+    unittest.main(verbosity=2)
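The new test exports a model whose forward uses slice assignment, which torch.export decomposes into aten.index_put, then checks that both ExtendedReferenceEvaluator and onnxruntime reproduce PyTorch's output when the last axis of each input is dynamic. The self.to_onnx and self.ort helpers come from the ext_test_case.py change below. A minimal sketch of the same export without the test harness, assuming a recent torch (2.6+) where string names are accepted in dynamic_shapes; it is an illustration, not part of the commit:

import torch

class Model(torch.nn.Module):
    def forward(self, x, value):
        # Slice assignment decomposes to aten.index_put during export.
        x = x.clone()
        x[:, :, :, : value.shape[-1]] = value
        return x

ep = torch.export.export(
    Model(),
    (torch.randn(2, 2, 3, 4), torch.randn(2, 2, 3, 3)),
    # The last axis of each input is dynamic, named "M" and "N".
    dynamic_shapes=({3: "M"}, {3: "N"}),
)
print(ep.graph)  # the graph should contain an index_put node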

_unittests/ut_torch_models/test_tiny_llms.py

Lines changed: 1 addition & 1 deletion
@@ -56,7 +56,7 @@ def test_tiny_llm_export_static(self):
         self.assertEqual(
             {"attention_mask", "past_key_values", "input_ids", "cache_position"}, set(inputs)
         )
-        with torch_export_patches(patch_transformers=True, stop_if_static=1):
+        with torch_export_patches(patch_transformers=True, stop_if_static=0):
             ep = torch.export.export(
                 model,
                 (),
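stop_if_static appears to control a guard in torch_export_patches that raises as soon as a dimension declared dynamic gets specialized to a constant during export: 1 raises, 0 lets the export proceed. Presumably the switch to 0 is needed because the static cache now carries a genuinely fixed cache length (see the cache_helper.py change below), so the specialization is expected rather than a bug. A toy sketch of the call, with a hypothetical model standing in for the tiny LLM of the test:

import torch
from onnx_diagnostic.torch_export_patches import torch_export_patches

class Tiny(torch.nn.Module):  # hypothetical stand-in model
    def forward(self, x):
        return x * 2

# stop_if_static=0 disables the "dynamic dimension became static" check;
# with a preallocated cache, that specialization happens by design.
with torch_export_patches(patch_transformers=True, stop_if_static=0):
    ep = torch.export.export(Tiny(), (torch.randn(2, 3),), dynamic_shapes=({0: "batch"},))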

_unittests/ut_xrun_doc/test_documentation_examples.py

Lines changed: 7 additions & 0 deletions
@@ -109,6 +109,13 @@ def add_test_methods(cls):
         ):
             reason = "torch<2.8"

+        if (
+            not reason
+            and name in {"plot_dump_intermediate_results.py"}
+            and not has_torch("2.9.1")
+        ):
+            reason = "unstable, let's wait for the next version"
+
         if reason:

             @unittest.skip(reason)

onnx_diagnostic/ext_test_case.py

Lines changed: 12 additions & 0 deletions
@@ -756,6 +756,18 @@ def todo(cls, f: Callable, msg: str):
         "Adds a todo printed when all test are run."
         cls._todos.append((f, msg))

+    @classmethod
+    def ort(cls):
+        import onnxruntime
+
+        return onnxruntime
+
+    @classmethod
+    def to_onnx(self, *args, **kwargs):
+        from experimental_experiment.torch_interpreter import to_onnx
+
+        return to_onnx(*args, **kwargs)
+
     def print_model(self, model: "ModelProto"):  # noqa: F821
         "Prints a ModelProto"
         from onnx_diagnostic.helpers.onnx_helper import pretty_onnx
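Both helpers defer an import until a test first needs it, so loading ext_test_case stays cheap and tests can still be collected when onnxruntime or experimental_experiment is not installed. Note that to_onnx is declared with self as its first parameter even though it is a classmethod; Python binds the class there either way, so it works, but cls would be the conventional spelling. The pattern in isolation, with a hypothetical class name:

class LazyImports:
    @classmethod
    def ort(cls):
        # The import runs on first call, not when this module is imported.
        import onnxruntime

        return onnxruntime

# onnxruntime is only required once a test actually calls the helper.
print(LazyImports.ort().get_available_providers())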

onnx_diagnostic/helpers/cache_helper.py

Lines changed: 2 additions & 1 deletion
@@ -181,7 +181,8 @@ def make_static_cache(
             torch.randn(bsize, nheads, slen, dim),
         )
         for i in range(n_layers)
-    ]
+    ],
+    max_cache_len=10,
 )
 print(string_type(past_key_values, with_shape=True))
 """

onnx_diagnostic/tasks/text_generation.py

Lines changed: 4 additions & 2 deletions
@@ -176,8 +176,10 @@ def get_inputs(
         "attention_mask": {0: batch, 2: "seq"},
         "cache_position": {0: "seq"},
         "past_key_values": [
-            [{0: batch, 2: cache_length} for _ in range(num_hidden_layers)],
-            [{0: batch, 2: cache_length} for _ in range(num_hidden_layers)],
+            # [{0: batch, 2: cache_length} for _ in range(num_hidden_layers)],
+            # [{0: batch, 2: cache_length} for _ in range(num_hidden_layers)],
+            [{0: batch} for _ in range(num_hidden_layers)],
+            [{0: batch} for _ in range(num_hidden_layers)],
         ],
     }
     inputs = dict(
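This mirrors the static-cache change: axis 2 of every past key/value tensor (the cache length) is now preallocated, so only the batch axis is still exported as dynamic; the old per-axis specification is kept commented out. Roughly, the resulting structure, with placeholder values for the names taken from the surrounding function:

num_hidden_layers = 2  # placeholder
batch = "batch"        # dynamic dimension name

past_key_values_dynamic_shapes = [
    # before: {0: batch, 2: cache_length} also made the cache length dynamic
    [{0: batch} for _ in range(num_hidden_layers)],  # keys
    [{0: batch} for _ in range(num_hidden_layers)],  # values
]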
