Commit 1043a16

disable two examples
1 parent 38fe3d6 commit 1043a16

2 files changed: +47 -17 lines

_unittests/ut_helpers/test_rt_helper.py

Lines changed: 34 additions & 9 deletions
@@ -2,41 +2,62 @@
 import unittest
 import torch
 from onnx_diagnostic.ext_test_case import ExtTestCase, hide_stdout
+from onnx_diagnostic.helpers import max_diff, flatten_object
 from onnx_diagnostic.helpers.rt_helper import onnx_generate
+from onnx_diagnostic.helpers.torch_helper import torch_deepcopy
+from onnx_diagnostic.helpers.ort_session import InferenceSessionForTorch
 from onnx_diagnostic.torch_models.hghub import get_untrained_model_with_inputs
 from onnx_diagnostic.torch_export_patches import torch_export_patches
 from onnx_diagnostic.export.api import to_onnx
 
 
 class TestRtSession(ExtTestCase):
     def simple_generate_with_cache(
-        self, model, input_ids: torch.Tensor, eos_token_id: int, max_new_tokens: int = 100
+        self,
+        model,
+        input_ids: torch.Tensor,
+        eos_token_id: int,
+        session: InferenceSessionForTorch,
+        max_new_tokens: int = 100,
     ):
         # First call: prefill
         outputs = model(
             input_ids,
+            use_cache=True,
             attention_mask=torch.ones(
                 input_ids.shape, dtype=input_ids.dtype, device=input_ids.device
             ),
-            use_cache=True,
         )
 
         # Next calls: decode
         for _ in range(max_new_tokens):
             next_token_logits = outputs.logits[:, -1, :]
-            past_key_values = outputs.past_key_values
             next_token_id = torch.argmax(next_token_logits, dim=-1, keepdim=True)
             if next_token_id.item() == eos_token_id:
                 break
             input_ids = torch.cat([input_ids, next_token_id], dim=-1)
+            attention_mask = torch.ones(
+                input_ids.shape, dtype=input_ids.dtype, device=input_ids.device
+            )
+            feeds = dict(
+                zip(
+                    session.input_names,
+                    torch_deepcopy(
+                        flatten_object(
+                            [next_token_id, attention_mask, outputs.past_key_values]
+                        )
+                    ),
+                )
+            )
+            onnx_results = session.run(None, feeds)
             outputs = model(
                 next_token_id,
                 use_cache=True,
-                past_key_values=past_key_values,
-                attention_mask=torch.ones(
-                    input_ids.shape, dtype=input_ids.dtype, device=input_ids.device
-                ),
+                past_key_values=outputs.past_key_values,
+                attention_mask=attention_mask,
             )
+            diff = max_diff(outputs, onnx_results)
+            print("****", diff)
         return input_ids
 
     @hide_stdout()
@@ -63,14 +84,18 @@ def test_onnx_generate(self):
         )
 
         print("-- test_onnx_generate: generate")
-        res = onnx_generate(model_name, input_ids[:1], 2, max_new_tokens=10)
+        res, session = onnx_generate(
+            model_name, input_ids[:1], 2, max_new_tokens=10, return_session=True
+        )
         n_inputs = input_ids.shape[1]
         self.assertEqualArray(input_ids[:1], res[:, :n_inputs])
         self.assertEqual(res.dtype, torch.int64)
         self.assertEqual(res.shape, (1, 13))
         print("-- test_onnx_generate: done")
         # expected = model.generate(input_ids[:1], max_new_tokens=10)
-        expected = self.simple_generate_with_cache(model, input_ids[:1], 2, max_new_tokens=10)
+        expected = self.simple_generate_with_cache(
+            model, input_ids[:1], 2, max_new_tokens=10, session=session
+        )
         self.assertEqualArray(input_ids[:1], expected[:, :n_inputs])
         print("******", res)
         print("******", expected)

_unittests/ut_xrun_doc/test_documentation_examples.py

Lines changed: 13 additions & 8 deletions
@@ -84,6 +84,8 @@ def add_test_methods(cls):
         if not reason and not has_dot and name in {"plot_dump_intermediate_results.py"}:
             reason = "dot not installed"
 
+        # transformers
+
         if (
             not reason
             and name in {"plot_export_tiny_llm.py"}
@@ -98,13 +100,23 @@ def add_test_methods(cls):
         ):
             reason = "transformers<4.52"
 
+        if (
+            not reason
+            and name in {"plot_export_with_dynamic_cache.py", "plot_export_tiny_phi2.py"}
+            and not has_transformers("4.55")
+        ):
+            reason = "transformers<4.55"
+
+        # pytorch
+
         if (
             not reason
             and name
             in {
+                "plot_export_hub_codellama.py",
                 "plot_export_locate_issue.py",
                 "plot_export_with_auto.py",
-                "plot_export_hub_codellama.py",
+                "plot_export_tiny_llm.py",
             }
             and not has_torch("2.8")
         ):
@@ -117,13 +129,6 @@ def add_test_methods(cls):
         ):
             reason = "unstable, let's wait for the next version"
 
-        if (
-            not reason
-            and name in {"plot_export_tiny_phi2.py"}
-            and not has_transformers("4.55")
-        ):
-            reason = "unstable, let's wait for the next version"
-
         if not reason and name in {
             "plot_export_tiny_llm_dim01.py",
             "plot_export_tiny_llm_dim01_onnx.py",
