
Commit 7979496

Implements custom generate method to validate onnx models (#276)
* generate
* add onnx_generate
* spell
* fix issue
* disable
* fix
* doc
1 parent dc74137 commit 7979496

15 files changed: +506 additions, −3 deletions

CHANGELOGS.rst

Lines changed: 1 addition & 0 deletions
@@ -4,6 +4,7 @@ Change Logs
 0.7.17
 ++++++
 
+* :pr:`276`: implements ``onnx_generate``, a ``generate`` method for an ONNX model
 * :pr:`275`: fixes function ``patched_vmap``
 
 0.7.16
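
For orientation, the example and tests added below suggest the following usage for the new helper. This is a sketch assembled from those call sites; the exact signature of ``onnx_generate`` should be treated as an assumption rather than documented API, and ``input_ids`` / ``eos_token_id`` are placeholders for the tokenized prompt and the tokenizer's end-of-sequence id.

    from onnx_diagnostic.helpers.rt_helper import onnx_generate

    # First call: load the ONNX model, generate a couple of tokens and
    # keep the created inference session for reuse (return_session=True).
    _res, session = onnx_generate(
        "model.onnx",       # path to an exported decoder-only model
        input_ids,          # prompt token ids, shape (batch, sequence)
        eos_token_id,       # third positional argument observed in the diff
        max_new_tokens=2,
        return_session=True,
    )

    # Second call: reuse the session to produce the full answer.
    output_ids = onnx_generate(
        session, input_ids, eos_token_id=eos_token_id, max_new_tokens=100
    )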

_doc/api/export/api.rst

Lines changed: 7 additions & 0 deletions
@@ -0,0 +1,7 @@

onnx_diagnostic.export.api
==========================

.. automodule:: onnx_diagnostic.export.api
    :members:
    :no-undoc-members:

_doc/api/export/index.rst

Lines changed: 1 addition & 0 deletions
@@ -5,6 +5,7 @@ onnx_diagnostic.export
     :maxdepth: 1
     :caption: modules
 
+    api
     dynamic_shapes
     shape_helper
     validate

_doc/conf.py

Lines changed: 1 addition & 0 deletions
@@ -277,6 +277,7 @@ def linkcode_resolve(domain, info):
 epkg_dictionary.update(
     {
         "arnir0/Tiny-LLM": "https://huggingface.co/arnir0/Tiny-LLM",
+        "microsoft/Phi-1.5": "https://huggingface.co/microsoft/phi-1_5",
         "microsoft/phi-2": "https://huggingface.co/microsoft/phi-2",
         "microsoft/Phi-3.5-mini-instruct": "https://huggingface.co/microsoft/Phi-3.5-mini-instruct",
         "microsoft/Phi-3.5-vision-instruct": "https://huggingface.co/microsoft/Phi-3.5-vision-instruct",

_doc/technical/plot_generate.py

Lines changed: 217 additions & 0 deletions
@@ -0,0 +1,217 @@
"""
.. _l-plot-generate:

=================================
From a LLM to processing a prompt
=================================

Method ``generate`` produces the model's answer to a given prompt.
Let's implement our own to better understand how it works and
then apply the same logic to an ONNX model.

Example with Phi 1.5
====================

:epkg:`microsoft/Phi-1.5` is a small LLM. The example below runs it on a short prompt.
"""

import os
import time
import sys
import pandas
from tqdm import tqdm
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from onnx_diagnostic.ext_test_case import unit_test_going
from onnx_diagnostic.helpers import string_type
from onnx_diagnostic.helpers.torch_helper import to_any, get_weight_type
from onnx_diagnostic.helpers.rt_helper import onnx_generate
from onnx_diagnostic.torch_export_patches import torch_export_patches
from onnx_diagnostic.torch_models.hghub import get_untrained_model_with_inputs
from onnx_diagnostic.torch_models.hghub.hub_api import get_pretrained_config, task_from_id
from onnx_diagnostic.tasks import random_input_kwargs
from onnx_diagnostic.export.api import to_onnx


device = "cuda" if torch.cuda.is_available() else "cpu"
data = []

print("-- load the model...")
if unit_test_going():
    # unit_test_going() returns True if UNITTEST_GOING is 1.
    # The example switches to a faster scenario.
    model_id = "arnir0/Tiny-LLM"
    data_export = get_untrained_model_with_inputs(model_id)
    model = data_export["model"]
    export_inputs = data_export["inputs"]
    export_shapes = data_export["dynamic_shapes"]
    tokenizer = AutoTokenizer.from_pretrained(model_id)
else:
    model_id = "microsoft/phi-1_5"
    model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype="auto")
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    config = get_pretrained_config(model_id)
    task = task_from_id(model_id)
    kwargs, fct = random_input_kwargs(config, task)
    res = fct(model, config, add_second_input=False, **kwargs)
    export_inputs = res["inputs"]
    export_shapes = res["dynamic_shapes"]
model = model.to(device)
print("-- done.")

print("-- tokenize the prompt...")
inputs = tokenizer(
    '''def print_prime(n):
   """
   Print all primes between 1 and n
   """''',
    return_tensors="pt",
    return_attention_mask=False,
).to(device)
print("-- done.")

print("-- compute the answer...")
begin = time.perf_counter()
outputs = model.generate(**inputs, max_new_tokens=100)
duration = time.perf_counter() - begin
print(f"-- done in {duration}")
data.append(dict(name="generate", duration=duration))
print("output shape:", string_type(outputs, with_shape=True, with_min_max=True))
print("-- decode the answer...")
text = tokenizer.batch_decode(outputs)[0]
print("-- done.")
print(text)


# %%
# eos_token_id?
# =============
#
# This token marks the end of the answer.

print("eos_token_id=", tokenizer.eos_token_id)

# %%
# Custom method generate
# ======================

def simple_generate_with_cache(
    model, input_ids: torch.Tensor, eos_token_id: int, max_new_tokens: int = 100
):
    # First call: prefill
    outputs = model(input_ids, use_cache=True)

    # Next calls: decode
    for _ in tqdm(list(range(max_new_tokens))):
        next_token_logits = outputs.logits[:, -1, :]
        past_key_values = outputs.past_key_values

        # The most probable next token is chosen.
        next_token_id = torch.argmax(next_token_logits, dim=-1, keepdim=True)
        # But we could also sample it from a multinomial distribution:
        # <<< probs = torch.softmax(next_token_logits / temperature, dim=-1)
        # <<< top_probs, top_indices = torch.topk(probs, top_k)
        # <<< next_token_id = top_indices[torch.multinomial(top_probs, 1)]

        if next_token_id.item() == eos_token_id:
            break
        input_ids = torch.cat([input_ids, next_token_id], dim=-1)

        # Feed only the new token, but with the cache.
        outputs = model(next_token_id, use_cache=True, past_key_values=past_key_values)

    return input_ids


print("-- compute the answer with custom generate...")
begin = time.perf_counter()
outputs = simple_generate_with_cache(
    model, inputs.input_ids, eos_token_id=tokenizer.eos_token_id, max_new_tokens=100
)
duration = time.perf_counter() - begin
print(f"-- done in {duration}")
data.append(dict(name="custom", duration=duration))

print("-- done.")
print("output shape:", string_type(outputs, with_shape=True, with_min_max=True))
print("-- decode the answer...")
text = tokenizer.batch_decode(outputs)[0]
print("-- done.")
print(text)

# %%
# Method generate for onnx models
# ===============================
#
# We first need to export the model into ONNX.
#
# ONNX Conversion
# +++++++++++++++

if "position_ids" in export_inputs:
    del export_inputs["position_ids"]
    del export_shapes["position_ids"]
dtype = get_weight_type(model)
print("-- model dtype:", dtype)
export_inputs["past_key_values"] = to_any(export_inputs["past_key_values"], dtype)
exporter = "custom" if "custom" in sys.argv else "onnx-dynamo"
model_name = f"model_{model_id.replace('/', '-')}.{exporter}.onnx"
if not os.path.exists(model_name):
    # This step is slow so let's skip it if it was already done.
    print("-- conversion to ONNX.")
    begin = time.perf_counter()
    with torch_export_patches(patch_transformers=True):
        to_onnx(
            model,
            (),
            kwargs=to_any(export_inputs, device),
            dynamic_shapes=export_shapes,
            filename=model_name,
            verbose=1,
            exporter=exporter,
        )
    duration = time.perf_counter() - begin
    print(f"-- done in {duration}")

# %%
# onnx_generate
# +++++++++++++
#
# Then we can call the generate method for just two tokens.
# This function is part of :mod:`onnx_diagnostic` but follows the implementation
# seen earlier for a torch model.
# Let's first ask the function to return the session so it is not created again on the second call.

_res, session = onnx_generate(
    model_name, inputs.input_ids, 2, max_new_tokens=2, return_session=True
)

# And now the full answer.
print("-- compute the answer with custom generate...")
begin = time.perf_counter()
outputs = onnx_generate(
    session, inputs.input_ids, eos_token_id=tokenizer.eos_token_id, max_new_tokens=100
)
duration = time.perf_counter() - begin
print(f"-- done in {duration}")
data.append(dict(name="onnx", duration=duration))

print("-- done.")
print("output shape:", string_type(outputs, with_shape=True, with_min_max=True))
print("-- decode the answer...")
text = tokenizer.batch_decode(outputs)[0]
print("-- done.")
print(text)


# %%
# Plots
# =====
df = pandas.DataFrame(data).set_index("name")
print(df)

# %%
ax = df.plot(kind="bar", title="Time (s) comparison to generate a prompt.", rot=45)
ax.figure.tight_layout()
ax.figure.savefig("plot_generate.png")
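
The example above treats ``onnx_generate`` as a black box. As a rough idea of what a generate loop over an ONNX model looks like, here is a deliberately simplified sketch using onnxruntime: it assumes the exported model exposes ``input_ids`` (and possibly ``attention_mask``) as inputs and returns the logits as its first output, and it re-feeds the whole sequence at every step instead of the ``past_key_values`` cache that both the torch loop above and the real ``onnx_generate`` rely on.

    import numpy as np
    import onnxruntime


    def greedy_onnx_generate_nocache(model_path, input_ids, eos_token_id, max_new_tokens=100):
        # Simplified sketch: no KV cache, the whole sequence is fed at every step.
        sess = onnxruntime.InferenceSession(model_path, providers=["CPUExecutionProvider"])
        input_names = {i.name for i in sess.get_inputs()}
        for _ in range(max_new_tokens):
            feeds = {"input_ids": input_ids}
            if "attention_mask" in input_names:
                feeds["attention_mask"] = np.ones_like(input_ids)
            logits = sess.run(None, feeds)[0]  # assumes logits are the first output
            next_token = logits[:, -1, :].argmax(axis=-1, keepdims=True).astype(input_ids.dtype)
            if int(next_token[0, 0]) == eos_token_id:
                break
            input_ids = np.concatenate([input_ids, next_token], axis=-1)
        return input_ids

The actual helper keeps the ``past_key_values`` returned by the model and feeds only the newly generated token back in, as ``simple_generate_with_cache`` does above, which avoids recomputing the whole prefix at every step.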

_unittests/ut_export/test_api.py

Lines changed: 34 additions & 0 deletions
@@ -0,0 +1,34 @@
import unittest
import torch
from onnx_diagnostic.ext_test_case import ExtTestCase, hide_stdout
from onnx_diagnostic.export.api import to_onnx


class TestValidate(ExtTestCase):
    @hide_stdout()
    def test_to_onnx(self):
        class Model(torch.nn.Module):
            def forward(self, x, y):
                return x + y

        x = torch.randn((5, 6))
        y = torch.randn((1, 6))
        ds = ({0: "a", 1: "b"}, {1: "b"})
        to_onnx(
            Model(),
            (x, y),
            dynamic_shapes=ds,
            exporter="custom",
            filename=self.get_dump_file("custom.onnx"),
        )
        to_onnx(
            Model(),
            (x, y),
            dynamic_shapes=ds,
            exporter="onnx-dynamo",
            filename=self.get_dump_file("onnx-dynamo.onnx"),
        )


if __name__ == "__main__":
    unittest.main(verbosity=2)
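
The test drives the same small model through two exporters. The body of the new ``onnx_diagnostic.export.api.to_onnx`` wrapper is not part of this diff; a plausible reading of the test, sketched here purely as an illustration, is a dispatcher over the ``exporter`` argument, for instance ``torch.onnx.export(..., dynamo=True)`` for ``"onnx-dynamo"`` and ``experimental_experiment.torch_interpreter.to_onnx`` for ``"custom"`` (the latter is imported directly by the next test):

    import torch


    def to_onnx_sketch(model, args, kwargs=None, dynamic_shapes=None, filename=None, exporter="onnx-dynamo"):
        # Hypothetical dispatcher, for illustration only.
        if exporter == "onnx-dynamo":
            torch.onnx.export(
                model, args, filename,
                kwargs=kwargs, dynamic_shapes=dynamic_shapes, dynamo=True,
            )
        elif exporter == "custom":
            from experimental_experiment.torch_interpreter import to_onnx as custom_to_onnx
            custom_to_onnx(model, args, kwargs=kwargs, dynamic_shapes=dynamic_shapes, filename=filename)
        else:
            raise ValueError(f"unexpected exporter={exporter!r}")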
Lines changed: 42 additions & 0 deletions
@@ -0,0 +1,42 @@
import os
import unittest
import torch
from onnx_diagnostic.ext_test_case import ExtTestCase, hide_stdout
from onnx_diagnostic.helpers.rt_helper import onnx_generate
from onnx_diagnostic.torch_models.hghub import get_untrained_model_with_inputs
from onnx_diagnostic.torch_export_patches import torch_export_patches


class TestRtSession(ExtTestCase):
    @hide_stdout()
    def test_onnx_generate(self):
        from experimental_experiment.torch_interpreter import to_onnx

        mid = "arnir0/Tiny-LLM"
        print("-- test_onnx_generate: get model")
        data = get_untrained_model_with_inputs(mid)
        model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
        del inputs["position_ids"]
        del ds["position_ids"]
        input_ids = inputs["input_ids"]
        folder = self.get_dump_folder("test_onnx_generate")
        model_name = os.path.join(folder, "model.onnx")
        print("-- test_onnx_generate: export model")
        with torch_export_patches(patch_transformers=True, patch_torch=False):
            to_onnx(
                model,
                (),
                kwargs=inputs,
                dynamic_shapes=ds,
                filename=model_name,
            )

        print("-- test_onnx_generate: generate")
        res = onnx_generate(model_name, input_ids[:1], 2, max_new_tokens=10)
        self.assertEqual(res.dtype, torch.int64)
        self.assertEqual(res.shape, (1, 13))
        print("-- test_onnx_generate: done")


if __name__ == "__main__":
    unittest.main(verbosity=2)
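
Note on the assertion: the expected shape ``(1, 13)`` is consistent with the 10 newly generated tokens being appended to the prompt, which presumably contains 3 tokens for this untrained Tiny-LLM configuration.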

_unittests/ut_helpers/test_torch_helper.py

Lines changed: 6 additions & 0 deletions
@@ -8,6 +8,7 @@
 from onnx_diagnostic.helpers import max_diff, string_type
 from onnx_diagnostic.helpers.torch_helper import (
     dummy_llm,
+    get_weight_type,
     to_numpy,
     is_torchdynamo_exporting,
     model_statistics,
@@ -415,6 +416,11 @@ def test_to_tensor(self):
         c = to_tensor(proto)
         self.assertEqualArray(a, c)
 
+    def test_get_weight_type(self):
+        model, _inputs = dummy_llm("LLM")
+        dt = get_weight_type(model)
+        self.assertEqual(torch.float32, dt)
+
 
 if __name__ == "__main__":
     unittest.main(verbosity=2)
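
The test only checks that ``get_weight_type`` reports ``torch.float32`` for the dummy LLM; the helper's implementation is not shown in this diff. A minimal sketch of what such a helper could do (an assumption, not the actual code) is to return the dtype of the first weight it finds:

    import torch


    def get_weight_type_sketch(model: torch.nn.Module) -> torch.dtype:
        # Illustration only: report the dtype of the model's first parameter.
        return next(iter(model.parameters())).dtype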

_unittests/ut_xrun_doc/test_documentation_technical.py

Lines changed: 9 additions & 1 deletion
@@ -6,7 +6,12 @@
 import time
 import torch
 from onnx_diagnostic import __file__ as onnx_diagnostic_file
-from onnx_diagnostic.ext_test_case import ExtTestCase, is_windows, ignore_errors
+from onnx_diagnostic.ext_test_case import (
+    ExtTestCase,
+    is_windows,
+    ignore_errors,
+    has_transformers,
+)
 
 
 VERBOSE = 0
@@ -80,6 +85,9 @@ def add_test_methods(cls):
         if not reason and torch.__version__.startswith("2.9.0"):
             reason = "examples are failing for on CI for 2.9.0"
 
+        if not reason and not has_transformers("4.55.0") and name in {"plot_generate.py"}:
+            reason = "transformers 4.55 is required"
+
         if reason:
 
             @unittest.skip(reason)
