
Commit 8fc8151

add onnx_generate
1 parent 18e2df6 commit 8fc8151

File tree

12 files changed: 375 additions, 24 deletions

CHANGELOGS.rst

Lines changed: 1 addition & 0 deletions
@@ -4,6 +4,7 @@ Change Logs
 0.7.17
 ++++++
 
+* :pr:`276`: implements ``onnx_generate``, which provides method ``generate`` for an ONNX model
 * :pr:`275`: fixes function ``patched_vmap``
 
 0.7.16
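For orientation, a minimal sketch of the new call, mirroring the usage added in _doc/technical/plot_generate.py and in the new unit test further down; ``input_ids`` and the exported ``model.onnx`` are assumed to exist, and the positional argument after ``input_ids`` is passed exactly as in those diffs:

    from onnx_diagnostic.helpers.rt_helper import onnx_generate

    # First call: also return the InferenceSession so it is not rebuilt later.
    _res, session = onnx_generate(
        "model.onnx", input_ids, 2, max_new_tokens=2, return_session=True
    )
    # Later calls reuse the session instead of the model path.
    outputs = onnx_generate(
        session, input_ids, eos_token_id=2, max_new_tokens=100
    )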

_doc/api/export/api.rst

Lines changed: 7 additions & 0 deletions
@@ -0,0 +1,7 @@
+
+onnx_diagnostic.export.api
+==========================
+
+.. automodule:: onnx_diagnostic.export.api
+    :members:
+    :no-undoc-members:

_doc/api/export/index.rst

Lines changed: 1 addition & 0 deletions
@@ -5,6 +5,7 @@ onnx_diagnostic.export
     :maxdepth: 1
     :caption: modules
 
+    api
     dynamic_shapes
     shape_helper
     validate

_doc/technical/plot_generate.py

Lines changed: 101 additions & 23 deletions
@@ -14,28 +14,47 @@
 :epkg:`microsoft/Phi-1.5` is a small LLM. The example given
 """
 
+import os
 import time
+import sys
 import pandas
 from tqdm import tqdm
+import torch
+from transformers import AutoModelForCausalLM, AutoTokenizer
 from onnx_diagnostic.ext_test_case import unit_test_going
 from onnx_diagnostic.helpers import string_type
+from onnx_diagnostic.helpers.torch_helper import to_any, get_weight_type
+from onnx_diagnostic.helpers.rt_helper import onnx_generate
+from onnx_diagnostic.torch_export_patches import torch_export_patches
 from onnx_diagnostic.torch_models.hghub import get_untrained_model_with_inputs
-import torch
-from transformers import AutoModelForCausalLM, AutoTokenizer
+from onnx_diagnostic.torch_models.hghub.hub_api import get_pretrained_config, task_from_id
+from onnx_diagnostic.tasks import random_input_kwargs
+from onnx_diagnostic.export.api import to_onnx
+
 
-device = "cuda" if torch.cuda.is_available else "cpu"
+device = "cuda" if torch.cuda.is_available() else "cpu"
 data = []
 
 print("-- load the model...")
-# unit_test_going() returns True if UNITTEST_GOING is 1
 if unit_test_going():
+    # unit_test_going() returns True if UNITTEST_GOING is 1
+    # The example switches to a faster scenario.
     model_id = "arnir0/Tiny-LLM"
-    model = get_untrained_model_with_inputs(model_id)["model"]
+    data_export = get_untrained_model_with_inputs(model_id)
+    model = data_export["model"]
+    export_inputs = data_export["inputs"]
+    export_shapes = data_export["dynamic_shapes"]
     tokenizer = AutoTokenizer.from_pretrained(model_id)
 else:
     model_id = "microsoft/phi-1_5"
     model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype="auto")
     tokenizer = AutoTokenizer.from_pretrained(model_id)
+    config = get_pretrained_config(model_id)
+    task = task_from_id(model_id)
+    kwargs, fct = random_input_kwargs(config, task)
+    res = fct(model, config, add_second_input=False, **kwargs)
+    export_inputs = res["inputs"]
+    export_shapes = res["dynamic_shapes"]
 model = model.to(device)
 print("-- done.")
 
@@ -52,11 +71,11 @@
 
 print("-- compute the answer...")
 begin = time.perf_counter()
-outputs = model.generate(**inputs, max_length=100)
+outputs = model.generate(**inputs, max_new_tokens=100)
 duration = time.perf_counter() - begin
 print(f"-- done in {duration}")
 data.append(dict(name="generate", duration=duration))
-print("output shape:", string_type(outputs, with_shape=True))
+print("output shape:", string_type(outputs, with_shape=True, with_min_max=True))
 print("-- decode the answer...")
 text = tokenizer.batch_decode(outputs)[0]
 print("-- done.")
@@ -79,35 +98,29 @@
 def simple_generate_with_cache(
     model, input_ids: torch.Tensor, eos_token_id: int, max_new_tokens: int = 100
 ):
-    answer = []
-    # First call.
+    # First call: prefill
     outputs = model(input_ids, use_cache=True)
-    next_token_logits = outputs.logits[:, -1, :]
-    past_key_values = outputs.past_key_values
 
-    # Next calls.
+    # Next calls: decode
     for _ in tqdm(list(range(max_new_tokens))):
+        next_token_logits = outputs.logits[:, -1, :]
+        past_key_values = outputs.past_key_values
+
         # The most probable next token is chosen.
         next_token_id = torch.argmax(next_token_logits, dim=-1, keepdim=True)
         # But we could select it using a multinomial law
         # <<< probs = torch.softmax(next_token_logits / temperature, dim=-1)
         # <<< top_probs, top_indices = torch.topk(probs, top_k)
         # <<< next_token_id = top_indices[torch.multinomial(top_probs, 1)]
 
-        # Let's add the predicted token to the answer.
-        answer.append(next_token_id)
+        if next_token_id.item() == eos_token_id:
+            break
+        input_ids = torch.cat([input_ids, next_token_id], dim=-1)
 
         # Feed only the new token, but with the cache
         outputs = model(next_token_id, use_cache=True, past_key_values=past_key_values)
-        next_token_logits = outputs.logits[:, -1, :]
-        past_key_values = outputs.past_key_values
-
-        input_ids = torch.cat([input_ids, next_token_id], dim=-1)
 
-        if next_token_id.item() == eos_token_id:
-            break
-
-    return torch.cat(answer, dim=1)
+    return input_ids
 
 
 print("-- compute the answer with custom generate...")
@@ -120,12 +133,77 @@ def simple_generate_with_cache(
 data.append(dict(name="custom", duration=duration))
 
 print("-- done.")
-print("output shape:", string_type(outputs, with_shape=True))
+print("output shape:", string_type(outputs, with_shape=True, with_min_max=True))
+print("-- decode the answer...")
+text = tokenizer.batch_decode(outputs)[0]
+print("-- done.")
+print(text)
+
+# %%
+# Method generate for onnx models
+# ===============================
+#
+# We first need to export the model into ONNX.
+#
+# ONNX Conversion
+# +++++++++++++++
+
+if "position_ids" in export_inputs:
+    del export_inputs["position_ids"]
+    del export_shapes["position_ids"]
+dtype = get_weight_type(model)
+print("-- model dtype:", dtype)
+export_inputs["past_key_values"] = to_any(export_inputs["past_key_values"], dtype)
+exporter = "custom" if "custom" in sys.argv else "onnx-dynamo"
+model_name = f"model_{model_id.replace('/', '-')}.{exporter}.onnx"
+if not os.path.exists(model_name):
+    # This step is slow so let's skip it if it was already done.
+    print("-- conversion to ONNX.")
+    begin = time.perf_counter()
+    with torch_export_patches(patch_transformers=True):
+        to_onnx(
+            model,
+            (),
+            kwargs=to_any(export_inputs, device),
+            dynamic_shapes=export_shapes,
+            filename=model_name,
+            verbose=1,
+            exporter=exporter,
+        )
+    duration = time.perf_counter() - begin
+    print(f"-- done in {duration}")
+
+# %%
+# onnx_generate
+# +++++++++++++
+#
+# Then we can call method generate, first for just two tokens.
+# This function is part of :epkg:`onnx_diagnostic` but follows the implementation
+# seen earlier for a torch model.
+# Let's first ask the function to return the session to avoid creating it twice.
+
+_res, session = onnx_generate(
+    model_name, inputs.input_ids, 2, max_new_tokens=2, return_session=True
+)
+
+# And now the full answer.
+print("-- compute the answer with onnx_generate...")
+begin = time.perf_counter()
+outputs = onnx_generate(
+    session, inputs.input_ids, eos_token_id=tokenizer.eos_token_id, max_new_tokens=100
+)
+duration = time.perf_counter() - begin
+print(f"-- done in {duration}")
+data.append(dict(name="onnx", duration=duration))
+
+print("-- done.")
+print("output shape:", string_type(outputs, with_shape=True, with_min_max=True))
 print("-- decode the answer...")
 text = tokenizer.batch_decode(outputs)[0]
 print("-- done.")
 print(text)
 
+
 # %%
 # Plots
 # =====
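The hunk ends at the plotting section. As a rough sketch of what such a plot could look like (illustrative only; the collected ``data`` is a list of dicts with ``name`` and ``duration``, and ``pandas`` is already imported at the top of the script):

    df = pandas.DataFrame(data).set_index("name")
    print(df)
    ax = df["duration"].plot.bar(title="generate vs custom vs onnx", rot=0)
    ax.figure.tight_layout()
    ax.figure.savefig("plot_generate.png")  # hypothetical output name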

_unittests/ut_export/test_api.py

Lines changed: 34 additions & 0 deletions
@@ -0,0 +1,34 @@
+import unittest
+import torch
+from onnx_diagnostic.ext_test_case import ExtTestCase, hide_stdout
+from onnx_diagnostic.export.api import to_onnx
+
+
+class TestValidate(ExtTestCase):
+    @hide_stdout()
+    def test_to_onnx(self):
+        class Model(torch.nn.Module):
+            def forward(self, x, y):
+                return x + y
+
+        x = torch.randn((5, 6))
+        y = torch.randn((1, 6))
+        ds = ({0: "a", 1: "b"}, {1: "b"})
+        to_onnx(
+            Model(),
+            (x, y),
+            dynamic_shapes=ds,
+            exporter="custom",
+            filename=self.get_dump_file("custom.onnx"),
+        )
+        to_onnx(
+            Model(),
+            (x, y),
+            dynamic_shapes=ds,
+            exporter="onnx-dynamo",
+            filename=self.get_dump_file("onnx-dynamo.onnx"),
+        )
+
+
+if __name__ == "__main__":
+    unittest.main(verbosity=2)

Lines changed: 42 additions & 0 deletions
@@ -0,0 +1,42 @@
+import os
+import unittest
+import torch
+from onnx_diagnostic.ext_test_case import ExtTestCase, hide_stdout
+from onnx_diagnostic.helpers.rt_helper import onnx_generate
+from onnx_diagnostic.torch_models.hghub import get_untrained_model_with_inputs
+from onnx_diagnostic.torch_export_patches import torch_export_patches
+
+
+class TestRtSession(ExtTestCase):
+    @hide_stdout()
+    def test_onnx_generate(self):
+        from experimental_experiment.torch_interpreter import to_onnx
+
+        mid = "arnir0/Tiny-LLM"
+        print("-- test_onnx_generate: get model")
+        data = get_untrained_model_with_inputs(mid)
+        model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
+        del inputs["position_ids"]
+        del ds["position_ids"]
+        input_ids = inputs["input_ids"]
+        folder = self.get_dump_folder("test_onnx_generate")
+        model_name = os.path.join(folder, "model.onnx")
+        print("-- test_onnx_generate: export model")
+        with torch_export_patches(patch_transformers=True, patch_torch=False):
+            to_onnx(
+                model,
+                (),
+                kwargs=inputs,
+                dynamic_shapes=ds,
+                filename=model_name,
+            )
+
+        print("-- test_onnx_generate: generate")
+        res = onnx_generate(model_name, input_ids[:1], 2, max_new_tokens=10)
+        self.assertEqual(res.dtype, torch.int64)
+        self.assertEqual(res.shape, (1, 13))
+        print("-- test_onnx_generate: done")
+
+
+if __name__ == "__main__":
+    unittest.main(verbosity=2)

_unittests/ut_helpers/test_torch_helper.py

Lines changed: 6 additions & 0 deletions
@@ -8,6 +8,7 @@
 from onnx_diagnostic.helpers import max_diff, string_type
 from onnx_diagnostic.helpers.torch_helper import (
     dummy_llm,
+    get_weight_type,
     to_numpy,
     is_torchdynamo_exporting,
     model_statistics,
@@ -415,6 +416,11 @@ def test_to_tensor(self):
         c = to_tensor(proto)
         self.assertEqualArray(a, c)
 
+    def test_get_weight_type(self):
+        model, _inputs = dummy_llm("LLM")
+        dt = get_weight_type(model)
+        self.assertEqual(torch.float32, dt)
+
 
 if __name__ == "__main__":
     unittest.main(verbosity=2)

onnx_diagnostic/export/api.py

Lines changed: 56 additions & 0 deletions
@@ -0,0 +1,56 @@
+from typing import Any, Dict, List, Sequence, Optional, Tuple, Union
+import torch
+
+
+def to_onnx(
+    mod: Union["torch.nn.Module", "torch.fx.GraphModule"],  # noqa: F821
+    args: Optional[Sequence["torch.Tensor"]] = None,  # noqa: F821
+    kwargs: Optional[Dict[str, "torch.Tensor"]] = None,  # noqa: F821
+    input_names: Optional[Sequence[str]] = None,
+    target_opset: Optional[Union[int, Dict[str, int]]] = None,
+    verbose: int = 0,
+    dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
+    filename: Optional[str] = None,
+    output_names: Optional[List[str]] = None,
+    output_dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
+    exporter: str = "onnx-dynamo",
+) -> Any:
+    """Common API for exporters."""
+    if exporter == "custom":
+        from experimental_experiment.torch_interpreter import to_onnx as _to_onnx
+        from experimental_experiment.xbuilder import OptimizationOptions
+
+        return _to_onnx(
+            mod,
+            args=args,
+            kwargs=kwargs,
+            input_names=input_names,
+            output_names=output_names,
+            target_opset=target_opset,
+            verbose=verbose,
+            filename=filename,
+            dynamic_shapes=dynamic_shapes,
+            large_model=True,
+            output_dynamic_shapes=output_dynamic_shapes,
+            options=OptimizationOptions(patterns="default+onnxruntime"),
+        )
+    if exporter == "onnx-dynamo":
+        import onnxscript.rewriter.ort_fusions as ort_fusions
+
+        assert (
+            not output_dynamic_shapes
+        ), f"output_dynamic_shapes not supported for exporter={exporter!r}"
+        epo = torch.onnx.export(
+            mod,
+            args=args,
+            kwargs=kwargs,
+            input_names=input_names,
+            output_names=output_names,
+            opset_version=target_opset,
+            dynamic_shapes=dynamic_shapes,
+        )
+        ort_fusions.optimize_for_ort(epo.model)
+        epo.save(filename)
+        return epo
+
+    raise ValueError(f"Unknown exporter={exporter!r}")
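A minimal usage sketch of this common entry point, modeled on the new unit test above (the tiny ``Model`` class and the output file name are only illustrative):

    import torch
    from onnx_diagnostic.export.api import to_onnx

    class Model(torch.nn.Module):
        def forward(self, x, y):
            return x + y

    x, y = torch.randn((5, 6)), torch.randn((1, 6))
    # The same call switches backends through `exporter`: "custom" or "onnx-dynamo".
    to_onnx(
        Model(),
        (x, y),
        dynamic_shapes=({0: "a", 1: "b"}, {1: "b"}),
        exporter="onnx-dynamo",
        filename="model_add.onnx",
    )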

onnx_diagnostic/helpers/ort_session.py

Lines changed: 4 additions & 0 deletions
@@ -135,6 +135,10 @@ def __init__(
         self.sess = sess
         self.input_names = [i.name for i in sess.get_inputs()]
         self.output_names = [i.name for i in sess.get_outputs()]
+        self.input_shapes = [i.shape for i in sess.get_inputs()]
+        self.output_shapes = [i.shape for i in sess.get_outputs()]
+        self.input_types = [i.type for i in sess.get_inputs()]
+        self.output_types = [i.type for i in sess.get_outputs()]
         self.torch = torch
         self.nvtx = nvtx
         self.run_options = onnxruntime.RunOptions()
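These new attributes only cache what onnxruntime already exposes on a session; the equivalent raw calls look like this (standalone sketch, independent of the wrapper class):

    import onnxruntime

    sess = onnxruntime.InferenceSession("model.onnx", providers=["CPUExecutionProvider"])
    # Shapes may contain symbolic names for dynamic axes; types are strings like 'tensor(float)'.
    names = [i.name for i in sess.get_inputs()]
    shapes = [i.shape for i in sess.get_inputs()]
    types = [i.type for i in sess.get_inputs()]
    print(list(zip(names, shapes, types)))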
