
Commit e89379d

fix some issues
1 parent 1043a16 commit e89379d

3 files changed, +108 -68 lines changed

_unittests/ut_helpers/test_rt_helper.py

Lines changed: 65 additions & 32 deletions
@@ -1,9 +1,14 @@
 import os
 import unittest
 import torch
-from onnx_diagnostic.ext_test_case import ExtTestCase, hide_stdout
+from onnx_diagnostic.ext_test_case import (
+    ExtTestCase,
+    hide_stdout,
+    requires_transformers,
+    requires_torch,
+)
 from onnx_diagnostic.helpers import max_diff, flatten_object
-from onnx_diagnostic.helpers.rt_helper import onnx_generate
+from onnx_diagnostic.helpers.rt_helper import onnx_generate, make_empty_cache
 from onnx_diagnostic.helpers.torch_helper import torch_deepcopy
 from onnx_diagnostic.helpers.ort_session import InferenceSessionForTorch
 from onnx_diagnostic.torch_models.hghub import get_untrained_model_with_inputs
@@ -21,16 +26,33 @@ def simple_generate_with_cache(
         max_new_tokens: int = 100,
     ):
         # First call: prefill
-        outputs = model(
-            input_ids,
-            use_cache=True,
-            attention_mask=torch.ones(
-                input_ids.shape, dtype=input_ids.dtype, device=input_ids.device
+        attention_mask = torch.ones(
+            input_ids.shape, dtype=input_ids.dtype, device=input_ids.device
+        )
+        feeds = {
+            **dict(zip(session.input_names[:2], [input_ids, attention_mask])),
+            **make_empty_cache(
+                input_ids.shape[0],
+                session.input_names[2:],
+                session.input_shapes[2:],
+                session.input_types[2:],
             ),
+        }
+        onnx_results = session.run(None, feeds)
+
+        outputs = model(input_ids, use_cache=True, attention_mask=attention_mask)
+
+        diff = max_diff(outputs, onnx_results)
+        assert diff["abs"] <= 0.1, (
+            f"Unexpected issue with {type(model)}\ndiff={diff}"
+            f"\ninput_ids.shape={input_ids.shape}"
+            f"\nexpected={self.string_type(outputs, with_shape=True, with_min_max=True)}"
+            f"\n got=\n"
+            f"{self.string_type(onnx_results, with_shape=True, with_min_max=True)}"
         )

         # Next calls: decode
-        for _ in range(max_new_tokens):
+        for iteration in range(max_new_tokens):
             next_token_logits = outputs.logits[:, -1, :]
             next_token_id = torch.argmax(next_token_logits, dim=-1, keepdim=True)
             if next_token_id.item() == eos_token_id:
@@ -42,11 +64,14 @@ def simple_generate_with_cache(
             feeds = dict(
                 zip(
                     session.input_names,
-                    torch_deepcopy(
-                        flatten_object(
-                            [next_token_id, attention_mask, outputs.past_key_values]
+                    [
+                        t.detach()
+                        for t in torch_deepcopy(
+                            flatten_object(
+                                [next_token_id, attention_mask, outputs.past_key_values]
+                            )
                         )
-                    ),
+                    ],
                 )
             )
             onnx_results = session.run(None, feeds)
@@ -57,9 +82,17 @@ def simple_generate_with_cache(
                 attention_mask=attention_mask,
             )
             diff = max_diff(outputs, onnx_results)
-            print("****", diff)
+            assert diff["abs"] <= 0.1, (
+                f"Unexpected issue with {type(model)}, iteration={iteration}"
+                f"\ndiff={diff}\ninput_ids.shape={input_ids.shape}"
+                f"\nexpected={self.string_type(outputs, with_shape=True, with_min_max=True)}"
+                f"\n got=\n"
+                f"{self.string_type(onnx_results, with_shape=True, with_min_max=True)}"
+            )
         return input_ids

+    @requires_transformers("4.55")
+    @requires_torch("2.9")
     @hide_stdout()
     def test_onnx_generate(self):
         mid = "arnir0/Tiny-LLM"
@@ -83,25 +116,25 @@ def test_onnx_generate(self):
                 exporter="custom",
             )

-            print("-- test_onnx_generate: generate")
-            res, session = onnx_generate(
-                model_name, input_ids[:1], 2, max_new_tokens=10, return_session=True
-            )
-            n_inputs = input_ids.shape[1]
-            self.assertEqualArray(input_ids[:1], res[:, :n_inputs])
-            self.assertEqual(res.dtype, torch.int64)
-            self.assertEqual(res.shape, (1, 13))
-            print("-- test_onnx_generate: done")
-            # expected = model.generate(input_ids[:1], max_new_tokens=10)
-            expected = self.simple_generate_with_cache(
-                model, input_ids[:1], 2, max_new_tokens=10, session=session
-            )
-            self.assertEqualArray(input_ids[:1], expected[:, :n_inputs])
-            print("******", res)
-            print("******", expected)
-            self.assertEqual(expected.dtype, torch.int64)
-            self.assertEqual(expected.shape, (1, 13))
-            self.assertEqualArray(expected, res)
+        print("-- test_onnx_generate: generate")
+        res, session = onnx_generate(
+            model_name, input_ids[:1], 2, max_new_tokens=10, return_session=True
+        )
+        n_inputs = input_ids.shape[1]
+        self.assertEqualArray(input_ids[:1], res[:, :n_inputs])
+        self.assertEqual(res.dtype, torch.int64)
+        self.assertEqual(res.shape, (1, 13))
+        print("-- test_onnx_generate: done")
+        # expected = model.generate(input_ids[:1], max_new_tokens=10)
+        expected = self.simple_generate_with_cache(
+            model, input_ids[:1], 2, max_new_tokens=10, session=session
+        )
+        self.assertEqualArray(input_ids[:1], expected[:, :n_inputs])
+        print("******", res)
+        print("******", expected)
+        self.assertEqual(expected.dtype, torch.int64)
+        self.assertEqual(expected.shape, (1, 13))
+        self.assertEqualArray(expected, res)


 if __name__ == "__main__":

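For context, the prefill step that the rewritten helper above exercises can be sketched as follows. This is a minimal illustration, not part of the commit; it assumes an eager `model`, an `input_ids` tensor and an `InferenceSessionForTorch` instance `session` already exist, and that `session.run` returns torch tensors as it does in the test.

    import torch
    from onnx_diagnostic.helpers import max_diff
    from onnx_diagnostic.helpers.rt_helper import make_empty_cache

    # Prefill feeds: the first two ONNX inputs receive input_ids and attention_mask,
    # every remaining input is a key/value cache tensor, fed empty on the first call.
    attention_mask = torch.ones(input_ids.shape, dtype=input_ids.dtype, device=input_ids.device)
    feeds = {
        **dict(zip(session.input_names[:2], [input_ids, attention_mask])),
        **make_empty_cache(
            input_ids.shape[0],
            session.input_names[2:],
            session.input_shapes[2:],
            session.input_types[2:],
        ),
    }
    onnx_results = session.run(None, feeds)

    # Run the same prefill with the eager model and compare both results.
    outputs = model(input_ids, use_cache=True, attention_mask=attention_mask)
    assert max_diff(outputs, onnx_results)["abs"] <= 0.1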
onnx_diagnostic/helpers/helper.py

Lines changed: 15 additions & 31 deletions
@@ -1057,6 +1057,20 @@ def max_diff(
             allow_unique_tensor_with_list_of_one_element=False,
             hist=hist,
         )
+
+    if expected.__class__.__name__ == "CausalLMOutputWithPast":
+        if verbose >= 6:
+            print(
+                f"[max_diff] CausalLMOutputWithPast: {string_type(expected)} "
+                f"? {string_type(got)}"
+            )
+        return max_diff(
+            [expected.logits, *flatten_object(expected.past_key_values)],
+            got,
+            debug_info=_debug(expected.__class__.__name__),
+            **_dkws,
+        )
+
     if hasattr(expected, "to_tuple"):
         if verbose >= 6:
             print(f"[max_diff] to_tuple1: {string_type(expected)} ? {string_type(got)}")
@@ -1067,36 +1081,6 @@ def max_diff(
             print(f"[max_diff] to_tuple2: {string_type(expected)} ? {string_type(got)}")
         return max_diff(expected, got.to_tuple(), debug_info=_debug("to_tuple2"), **_dkws)

-    if isinstance(got, (list, tuple)):
-        if len(got) != 1:
-            if verbose >= 6:
-                print(
-                    f"[max_diff] list,tuple,2: {string_type(expected)} "
-                    f"? {string_type(got)}"
-                )
-            if verbose > 2:
-                import torch
-
-                print(
-                    f"[max_diff] (a) inf because len(expected)={len(expected)}!=1, "
-                    f"len(got)={len(got)}, level={level}, _index={_index}"
-                )
-                for i, (a, b) in enumerate(zip(expected, got)):
-                    if isinstance(a, torch.Tensor) and isinstance(b, torch.Tensor):
-                        print(
-                            f" i={i} expected {a.dtype}:{a.shape}, "
-                            f"has {b.dtype}:{b.shape}, _index={_index}"
-                        )
-                    else:
-                        print(
-                            f" i={i} a is {type(a)}, "
-                            f"b is {type(b)}, _index={_index}"
-                        )
-            return dict(abs=np.inf, rel=np.inf, sum=np.inf, n=np.inf, dnan=np.inf)
-        if verbose >= 6:
-            print(f"[max_diff] list,tuple,1: {string_type(expected)} ? {string_type(got)}")
-        return max_diff(expected, got[0], debug_info=_debug("lt1"), **_dkws)
-
     if isinstance(expected, (tuple, list)):
         if verbose >= 6:
             print(f"[max_diff] list,tuple,0: {string_type(expected)} ? {string_type(got)}")
@@ -1485,7 +1469,7 @@ def max_diff(
         return dict(abs=np.inf, rel=np.inf, sum=np.inf, n=np.inf, dnan=np.inf)
     if verbose >= 6:
         print(
-            f"[max_diff] {expected.__class__.__name__}: "
+            f"[max_diff*] {expected.__class__.__name__}: "
             f"{string_type(expected)} ? {string_type(got)}"
         )
     expected_args, _spec = torch.utils._pytree.tree_flatten(expected)

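The new branch above changes how `max_diff` handles a `CausalLMOutputWithPast`: instead of falling into the removed generic list handling (which reported an infinite difference whenever the ONNX side returned more than one tensor), the expected output is flattened into its logits followed by its key/value cache and compared element-wise against the flat list of session outputs. A hedged sketch of the intended call, assuming `model`, `session` and `feeds` as in the test above:

    from onnx_diagnostic.helpers import max_diff

    outputs = model(input_ids, use_cache=True, attention_mask=attention_mask)  # CausalLMOutputWithPast
    onnx_results = session.run(None, feeds)  # flat list: logits first, then the cache tensors
    diff = max_diff(outputs, onnx_results)
    assert diff["abs"] <= 0.1, f"unexpected discrepancy: {diff}"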
onnx_diagnostic/helpers/rt_helper.py

Lines changed: 28 additions & 5 deletions
@@ -122,6 +122,31 @@ def rt_type_to_torch_dtype(typename: str) -> torch.dtype:
     return _DTYPES[typename]


+def make_empty_cache(
+    batch: int,
+    onnx_input_names: List[str],
+    onnx_input_shapes: List[Tuple[Union[int, str], ...]],
+    onnx_input_types: List[str],
+) -> Dict[str, torch.Tensor]:
+    """
+    Creates an empty cache. Example:
+
+    .. code-block:: python
+
+        make_empty_cache(
+            1,
+            sess.input_names[2:],
+            [i.shape for i in sess.get_inputs()[2:]],
+            [i.type for i in sess.get_inputs()[2:]],
+        )
+    """
+    feeds = {}
+    for name, shape, dtype in zip(onnx_input_names, onnx_input_shapes, onnx_input_types):
+        new_shape = tuple(_get_dim(i, s, batch=batch) for i, s in enumerate(shape))
+        feeds[name] = torch.empty(new_shape, dtype=rt_type_to_torch_dtype(dtype))
+    return feeds
+
+
 def onnx_generate(
     model_or_path: Union[onnx.ModelProto, str, InferenceSessionForTorch],
     input_ids: torch.Tensor,
@@ -166,12 +191,10 @@ def onnx_generate(
         attention_mask=torch.ones(
             input_ids.shape, dtype=input_ids.dtype, device=input_ids.device
         ),
+        **make_empty_cache(
+            input_ids.shape[0], input_names[2:], input_shapes[2:], input_types[2:]
+        ),
     )
-    for name, shape, dtype in zip(input_names[2:], input_shapes[2:], input_types[2:]):
-        new_shape = tuple(
-            _get_dim(i, s, batch=input_ids.shape[0]) for i, s in enumerate(shape)
-        )
-        feeds[name] = torch.empty(new_shape, dtype=rt_type_to_torch_dtype(dtype))

     outputs = session.run(None, feeds)

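A usage sketch for the new `make_empty_cache` helper, following its docstring and the test above; "model.onnx" is a placeholder path and the literal 2 mirrors the eos token id passed in the test:

    import torch
    from onnx_diagnostic.helpers.rt_helper import onnx_generate, make_empty_cache

    input_ids = torch.tensor([[1, 2, 3, 4]], dtype=torch.int64)
    res, session = onnx_generate(
        "model.onnx", input_ids, 2, max_new_tokens=10, return_session=True
    )

    # Empty cache feeds for a fresh prefill call: inputs 0 and 1 are input_ids and
    # attention_mask, everything after them is a key/value cache input.
    empty_cache = make_empty_cache(
        input_ids.shape[0],
        session.input_names[2:],
        session.input_shapes[2:],
        session.input_types[2:],
    )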
0 commit comments