
Commit 6963db2

Better documentation, implements generate_and_validate (#281)
* better api
* mypy
1 parent 175a800 commit 6963db2

File tree

3 files changed (+181, -102 lines)

_unittests/ut_helpers/test_rt_helper.py

Lines changed: 20 additions & 99 deletions

@@ -7,92 +7,13 @@
     requires_transformers,
     requires_torch,
 )
-from onnx_diagnostic.helpers import max_diff, flatten_object
-from onnx_diagnostic.helpers.rt_helper import onnx_generate, make_empty_cache
-from onnx_diagnostic.helpers.torch_helper import torch_deepcopy
-from onnx_diagnostic.helpers.ort_session import InferenceSessionForTorch
+from onnx_diagnostic.helpers.rt_helper import onnx_generate, generate_and_validate
 from onnx_diagnostic.torch_models.hghub import get_untrained_model_with_inputs
 from onnx_diagnostic.torch_export_patches import torch_export_patches
 from onnx_diagnostic.export.api import to_onnx
 
 
 class TestRtSession(ExtTestCase):
-    def simple_generate_with_cache(
-        self,
-        model,
-        input_ids: torch.Tensor,
-        eos_token_id: int,
-        session: InferenceSessionForTorch,
-        max_new_tokens: int = 100,
-    ):
-        # First call: prefill
-        attention_mask = torch.ones(
-            input_ids.shape, dtype=input_ids.dtype, device=input_ids.device
-        )
-        feeds = {
-            **dict(zip(session.input_names[:2], [input_ids, attention_mask])),
-            **make_empty_cache(
-                input_ids.shape[0],
-                session.input_names[2:],
-                session.input_shapes[2:],
-                session.input_types[2:],
-            ),
-        }
-        onnx_results = session.run(None, feeds)
-
-        outputs = model(input_ids, use_cache=True, attention_mask=attention_mask)
-
-        diff = max_diff(outputs, onnx_results)
-        assert diff["abs"] <= 0.1, (
-            f"Unexpected issue with {type(model)}\ndiff={diff}"
-            f"\ninput_ids.shape={input_ids.shape}"
-            f"\nexpected={self.string_type(outputs, with_shape=True, with_min_max=True)}"
-            f"\n got=\n"
-            f"{self.string_type(onnx_results, with_shape=True, with_min_max=True)}\n"
-            f"feeds={self.string_type(feeds, with_shape=True, with_min_max=True)}"
-        )
-
-        # Next calls: decode
-        for iteration in range(max_new_tokens):
-            next_token_logits = outputs.logits[:, -1, :]
-            next_token_id = torch.argmax(next_token_logits, dim=-1, keepdim=True)
-            if next_token_id.item() == eos_token_id:
-                break
-            input_ids = torch.cat([input_ids, next_token_id], dim=-1)
-            attention_mask = torch.ones(
-                input_ids.shape, dtype=input_ids.dtype, device=input_ids.device
-            )
-            feeds = dict(
-                zip(
-                    session.input_names,
-                    [
-                        t.detach()
-                        for t in torch_deepcopy(
-                            flatten_object(
-                                [next_token_id, attention_mask, outputs.past_key_values]
-                            )
-                        )
-                    ],
-                )
-            )
-            onnx_results = session.run(None, feeds)
-            outputs = model(
-                next_token_id,
-                use_cache=True,
-                past_key_values=outputs.past_key_values,
-                attention_mask=attention_mask,
-            )
-            diff = max_diff(outputs, onnx_results)
-            assert diff["abs"] <= 0.1, (
-                f"Unexpected issue with {type(model)}, iteration={iteration}"
-                f"\ndiff={diff}\ninput_ids.shape={input_ids.shape}"
-                f"\nexpected={self.string_type(outputs, with_shape=True, with_min_max=True)}"
-                f"\n got=\n"
-                f"{self.string_type(onnx_results, with_shape=True, with_min_max=True)}\n"
-                f"feeds={self.string_type(feeds, with_shape=True, with_min_max=True)}"
-            )
-        return input_ids
-
     @requires_transformers("4.55")
     @requires_torch("2.9")
     @hide_stdout()
@@ -118,25 +39,25 @@ def test_onnx_generate(self):
             exporter="custom",
         )
 
-        print("-- test_onnx_generate: generate")
-        res, session = onnx_generate(
-            model_name, input_ids[:1], 2, max_new_tokens=10, return_session=True
-        )
-        n_inputs = input_ids.shape[1]
-        self.assertEqualArray(input_ids[:1], res[:, :n_inputs])
-        self.assertEqual(res.dtype, torch.int64)
-        self.assertEqual(res.shape, (1, 13))
-        print("-- test_onnx_generate: done")
-        # expected = model.generate(input_ids[:1], max_new_tokens=10)
-        expected = self.simple_generate_with_cache(
-            model, input_ids[:1], 2, max_new_tokens=10, session=session
-        )
-        self.assertEqualArray(input_ids[:1], expected[:, :n_inputs])
-        print("******", res)
-        print("******", expected)
-        self.assertEqual(expected.dtype, torch.int64)
-        self.assertEqual(expected.shape, (1, 13))
-        self.assertEqualArray(expected, res)
+        print("-- test_onnx_generate: generate")
+        res, session = onnx_generate(
+            model_name, input_ids[:1], 2, max_new_tokens=10, return_session=True
+        )
+        n_inputs = input_ids.shape[1]
+        self.assertEqualArray(input_ids[:1], res[:, :n_inputs])
+        self.assertEqual(res.dtype, torch.int64)
+        self.assertEqual(res.shape, (1, 13))
+        print("-- test_onnx_generate: done")
+        # expected = model.generate(input_ids[:1], max_new_tokens=10)
+        expected, _ = generate_and_validate(
+            model, input_ids[:1], 2, max_new_tokens=10, session=session
+        )
+        self.assertEqualArray(input_ids[:1], expected[:, :n_inputs])
+        print("******", res)
+        print("******", expected)
+        self.assertEqual(expected.dtype, torch.int64)
+        self.assertEqual(expected.shape, (1, 13))
+        self.assertEqualArray(expected, res)
 
 
 if __name__ == "__main__":
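The deleted helper and its replacement share the same greedy decoding core: take the logits at the last position, pick the argmax token, stop on the EOS token, otherwise append the token and continue. A minimal, self-contained sketch of that step with toy tensors (illustration only, not code from the repository):

    import torch

    # Toy greedy decode step, mirroring the loop in generate_and_validate.
    eos_token_id = 2
    input_ids = torch.tensor([[5, 7, 11]])           # (batch=1, seq=3)
    logits = torch.randn(1, input_ids.shape[1], 32)  # (batch, seq, vocab=32)

    next_token_logits = logits[:, -1, :]             # logits of the last position
    next_token_id = torch.argmax(next_token_logits, dim=-1, keepdim=True)
    if next_token_id.item() != eos_token_id:
        # append the chosen token; it is fed back on the next iteration
        input_ids = torch.cat([input_ids, next_token_id], dim=-1)
    print(input_ids)  # e.g. tensor([[ 5,  7, 11, 23]])

In the real helpers the appended token is fed back together with the past key/value cache, so each step only computes the new position.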

onnx_diagnostic/export/api.py

Lines changed: 1 addition & 1 deletion

@@ -34,7 +34,7 @@ def to_onnx(
         output_dynamic_shapes=output_dynamic_shapes,
         options=OptimizationOptions(patterns="default+onnxruntime"),
     )
-    if exporter == "onnx-dynamo":
+    if exporter in ("dynamo", "onnx-dynamo"):
         import onnxscript.rewriter.ort_fusions as ort_fusions
 
         assert (
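This one-line change makes "dynamo" an accepted alias of "onnx-dynamo" for the onnxruntime fusion step. A hedged sketch of a call relying on the alias; model, inputs, ds and model_name are assumed to be prepared as in the rt_helper example further down:

    from onnx_diagnostic.export.api import to_onnx

    # "dynamo" now takes the same branch as "onnx-dynamo" (ort fusions applied);
    # model, inputs, ds, model_name are assumed, see the runpython example below.
    to_onnx(
        model,
        (),
        kwargs=inputs,
        dynamic_shapes=ds,
        filename=model_name,
        exporter="dynamo",
    )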

onnx_diagnostic/helpers/rt_helper.py

Lines changed: 160 additions & 2 deletions

@@ -1,8 +1,9 @@
-from typing import Any, Dict, List, Tuple, Union
+from typing import Any, Dict, List, Optional, Tuple, Union
 import numpy as np
 import onnx
 import torch
-from .helper import string_type, flatten_object
+from .helper import string_type, flatten_object, max_diff
+from .torch_helper import torch_deepcopy
 from .ort_session import InferenceSessionForTorch
 
 
@@ -147,6 +148,115 @@ def make_empty_cache(
     return feeds
 
 
+def generate_and_validate(
+    model,
+    input_ids: torch.Tensor,
+    eos_token_id: int,
+    max_new_tokens: int = 100,
+    session: Optional[Union[InferenceSessionForTorch, onnx.ModelProto, str]] = None,
+    atol: float = 0.1,
+) -> Union[torch.Tensor, Tuple[torch.Tensor, List[Dict]]]:
+    """
+    Implements a simple ``generate`` method for a torch model.
+    The function does not expect any ``position_ids`` as input.
+    The function also checks that the outputs coming from an onnx model
+    are close to the outputs the torch model produces.
+
+    :param model: torch model
+    :param input_ids: input tokens
+    :param eos_token_id: token representing the end of an answer
+    :param max_new_tokens: stops after this number of generated tokens
+    :param session: the onnx model
+    :param atol: absolute tolerance used when comparing torch and onnx outputs
+    :return: input tokens concatenated with new tokens;
+        if session is not None, it also returns the maximum differences
+        at every iteration
+
+    See the example given with function :func:`onnx_generate
+    <onnx_diagnostic.helpers.rt_helper.onnx_generate>`.
+    """
+    if session is not None:
+        if not isinstance(session, InferenceSessionForTorch):
+            providers = ["CUDAExecutionProvider"] if input_ids.is_cuda else []
+            providers.append("CPUExecutionProvider")
+            session = InferenceSessionForTorch(session, providers=providers)
+
+    # First call: prefill
+    attention_mask = torch.ones(
+        input_ids.shape, dtype=input_ids.dtype, device=input_ids.device
+    )
+    if session:
+        feeds = {
+            **dict(zip(session.input_names[:2], [input_ids, attention_mask])),
+            **make_empty_cache(
+                input_ids.shape[0],
+                session.input_names[2:],
+                session.input_shapes[2:],
+                session.input_types[2:],
+            ),
+        }
+        onnx_results = session.run(None, feeds)
+
+    outputs = model(input_ids, use_cache=True, attention_mask=attention_mask)
+
+    if session:
+        diff = max_diff(outputs, onnx_results)
+        assert isinstance(diff["abs"], float) and diff["abs"] <= atol, (
+            f"Unexpected issue with {type(model)}\ndiff={diff}"
+            f"\ninput_ids.shape={input_ids.shape}"
+            f"\nexpected={string_type(outputs, with_shape=True, with_min_max=True)}"
+            f"\n got=\n"
+            f"{string_type(onnx_results, with_shape=True, with_min_max=True)}\n"
+            f"feeds={string_type(feeds, with_shape=True, with_min_max=True)}"
+        )
+        diffs = [diff]
+
+    # Next calls: decode
+    for iteration in range(max_new_tokens):
+        next_token_logits = outputs.logits[:, -1, :]
+        next_token_id = torch.argmax(next_token_logits, dim=-1, keepdim=True)
+        if next_token_id.item() == eos_token_id:
+            break
+        input_ids = torch.cat([input_ids, next_token_id], dim=-1)
+        attention_mask = torch.ones(
+            input_ids.shape, dtype=input_ids.dtype, device=input_ids.device
+        )
+        if session:
+            feeds = dict(
+                zip(
+                    session.input_names,
+                    [
+                        t.detach()
+                        for t in torch_deepcopy(
+                            flatten_object(
+                                [next_token_id, attention_mask, outputs.past_key_values]
+                            )
+                        )
+                    ],
+                )
+            )
+            onnx_results = session.run(None, feeds)
+        outputs = model(
+            next_token_id,
+            use_cache=True,
+            past_key_values=outputs.past_key_values,
+            attention_mask=attention_mask,
+        )
+        if session:
+            diff = max_diff(outputs, onnx_results)
+            assert isinstance(diff["abs"], float) and diff["abs"] <= atol, (
+                f"Unexpected issue with {type(model)}, iteration={iteration}"
+                f"\ndiff={diff}\ninput_ids.shape={input_ids.shape}"
+                f"\nexpected={string_type(outputs, with_shape=True, with_min_max=True)}"
+                f"\n got=\n"
+                f"{string_type(onnx_results, with_shape=True, with_min_max=True)}\n"
+                f"feeds={string_type(feeds, with_shape=True, with_min_max=True)}"
            )
+            diffs.append(diff)
+    if session:
+        return input_ids, diffs
+    return input_ids
+
+
 def onnx_generate(
     model_or_path: Union[onnx.ModelProto, str, InferenceSessionForTorch],
     input_ids: torch.Tensor,
@@ -167,6 +277,54 @@
         <onnx_diagnostic.helpers.ort_session.InferenceSessionForTorch>`
         created if necessary
    :return: input tokens concatenated with new tokens
+
+    .. runpython::
+        :showcode:
+
+        import os
+        from onnx_diagnostic.helpers import string_type, string_diff
+        from onnx_diagnostic.helpers.rt_helper import onnx_generate, generate_and_validate
+        from onnx_diagnostic.torch_models.hghub import get_untrained_model_with_inputs
+        from onnx_diagnostic.torch_export_patches import torch_export_patches
+        from onnx_diagnostic.export.api import to_onnx
+
+        mid = "arnir0/Tiny-LLM"
+        print(f"-- get model for {mid!r}")
+        data = get_untrained_model_with_inputs(mid)
+        model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
+        del inputs["position_ids"]
+        del ds["position_ids"]
+        input_ids = inputs["input_ids"]
+
+        print(f"-- input_ids={input_ids.shape}")
+        print(f"-- inputs: {string_type(inputs, with_shape=True)}")
+        print(f"-- dynamic_shapes: {string_type(ds)}")
+        folder = "dump_test"
+        os.makedirs(folder, exist_ok=True)
+        model_name = os.path.join(folder, "model.onnx")
+        print("-- test_onnx_generate: export model")
+        with torch_export_patches(patch_transformers=True, patch_torch=False):
+            to_onnx(
+                model,
+                (),
+                kwargs=inputs,
+                dynamic_shapes=ds,
+                filename=model_name,
+                exporter="custom",  # custom, dynamo or onnx-dynamo, modelbuilder
+            )
+
+        print("-- onnx_generate")
+        onnx_outputs = onnx_generate(model_name, input_ids[:1], 2, max_new_tokens=10)
+        print("-- onnx output", onnx_outputs)
+
+        print("-- generate")
+        torch_outputs, diffs = generate_and_validate(
+            model, input_ids[:1], 2, max_new_tokens=10, session=model_name
+        )
+        print("-- torch output", torch_outputs)
+        print("-- differences at each step:")
+        for i, d in enumerate(diffs):
+            print(f"iteration {i}: {string_diff(d)}")
     """
     if not isinstance(model_or_path, InferenceSessionForTorch):
         providers = ["CUDAExecutionProvider"] if input_ids.is_cuda else []
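For reference, the two return conventions of the new generate_and_validate, as a hedged sketch; model, input_ids and model_name are assumed to be set up as in the runpython example above:

    from onnx_diagnostic.helpers.rt_helper import generate_and_validate

    # Without a session: plain greedy generation, only the token tensor comes back.
    tokens = generate_and_validate(model, input_ids[:1], 2, max_new_tokens=10)

    # With a session (path, ModelProto or InferenceSessionForTorch): every torch
    # step is compared against the onnx run; the per-iteration maximum differences
    # are returned as well, and an AssertionError is raised if one exceeds atol.
    tokens, diffs = generate_and_validate(
        model, input_ids[:1], 2, max_new_tokens=10, session=model_name, atol=0.1
    )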
