
Commit 143deef

Adds prompt to test speedup with onnx_generate (#287)
* Adds prompt to test speedup with onnx_generate
* prompt
* fix
* try
* doc
1 parent 6ca5f73 commit 143deef

8 files changed (+145, -65 lines)

CHANGELOGS.rst

Lines changed: 1 addition & 0 deletions
@@ -4,6 +4,7 @@ Change Logs
 0.8.1
 +++++
 
+* :pr:`287`: adds input ``'inputs_prompt'`` to test a LLM, meant to be used during validation
 * :pr:`288`: add .contiguous in torch.cond branch (attention patch for sdpa implementation)
 * :pr:`286`: adds variable to track random nodes in models

_doc/technical/plot_generate.py

Lines changed: 1 addition & 1 deletion
@@ -186,7 +186,7 @@ def simple_generate_with_cache(
 # seen earlier for a torch model.
 # Let's ask first the function to return the session to avoid creating on the second call.
 
-_res, session = onnx_generate(
+_res, session, _feeds = onnx_generate(
     model_name, inputs.input_ids, 2, max_new_tokens=2, return_session=True
 )
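Since onnx_generate now also returns the feeds prepared for the next iteration, the call above unpacks three values. A minimal sketch (not part of this diff, reusing the variable names from the documentation example) of how the returned session could then be passed back as the first argument so the ONNX model is only loaded once:

    _res, session, _feeds = onnx_generate(
        model_name, inputs.input_ids, 2, max_new_tokens=2, return_session=True
    )
    # second call: reuse the InferenceSessionForTorch created above instead of the model path
    res = onnx_generate(session, inputs.input_ids, 2, max_new_tokens=2)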

_unittests/ut_helpers/test_rt_helper.py

Lines changed: 1 addition & 1 deletion
@@ -48,7 +48,7 @@ def test_onnx_generate(self):
         )
 
         print("-- test_onnx_generate: generate")
-        res, session = onnx_generate(
+        res, session, _feeds = onnx_generate(
             model_name, input_ids[:1], 2, max_new_tokens=10, return_session=True
         )
         n_inputs = input_ids.shape[1]

_unittests/ut_tasks/test_tasks_text_generation.py

Lines changed: 22 additions & 0 deletions
@@ -5,12 +5,14 @@
     hide_stdout,
     requires_transformers,
     requires_torch,
+    ignore_warnings,
 )
 from onnx_diagnostic.helpers.torch_helper import torch_deepcopy
 from onnx_diagnostic.torch_models.hghub.model_inputs import get_untrained_model_with_inputs
 from onnx_diagnostic.torch_export_patches import torch_export_patches
 from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str
 from onnx_diagnostic.export.shape_helper import make_fake_with_dynamic_dimensions
+from onnx_diagnostic.helpers.rt_helper import onnx_generate, generate_and_validate
 
 
 class TestTasksTextGeneration(ExtTestCase):
@@ -75,6 +77,26 @@ def test_text_generation_tiny_llm(self):
         self.assertEqualAny(expected.past_key_values, got.past_key_values)
         self.assertEqualArray(expected.logits, got.logits)
 
+    @hide_stdout()
+    @requires_transformers("4.53")
+    @requires_torch("2.8.99")  # check_guards not supported
+    @ignore_warnings(FutureWarning)
+    def test_text_generation_tiny_llm_prompt_validation(self):
+        from experimental_experiment.torch_interpreter import to_onnx
+
+        mid = "arnir0/Tiny-LLM"
+        data = get_untrained_model_with_inputs(mid, verbose=1, add_second_input=True)
+        prompt = data["inputs_prompt"]["input_ids"]
+        self.assertEqual(data["task"], "text-generation")
+        model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
+        with torch_export_patches(patch_transformers=True, verbose=1, patch_torch=False):
+            onx = to_onnx(model, inputs, dynamic_shapes=ds)
+
+        self.dump_onnx("test_text_generation_tiny_llm_prompt_validation.onnx", onx)
+        onnx_sequence = onnx_generate(onx, prompt, max_new_tokens=3)
+        torch_sequence = generate_and_validate(model, prompt, max_new_tokens=3)
+        self.assertEqualArray(torch_sequence, onnx_sequence)
+
 
 if __name__ == "__main__":
     unittest.main(verbosity=2)
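The new test only checks that the ONNX and torch generation loops produce the same tokens. Since the commit title mentions measuring speedup, here is a minimal sketch (not part of this commit) of how the same prompt could be used to time both paths, reusing onx, model and prompt from the test above:

    import time

    t0 = time.perf_counter()
    onnx_sequence = onnx_generate(onx, prompt, max_new_tokens=3)
    t_onnx = time.perf_counter() - t0

    t0 = time.perf_counter()
    torch_sequence = generate_and_validate(model, prompt, max_new_tokens=3)
    t_torch = time.perf_counter() - t0

    # rough comparison only; a real benchmark would repeat and average the runs
    print(f"onnx={t_onnx:.3f}s torch={t_torch:.3f}s speedup={t_torch / t_onnx:.2f}x")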

onnx_diagnostic/helpers/rt_helper.py

Lines changed: 29 additions & 8 deletions
@@ -149,7 +149,7 @@ def make_empty_cache(
 def generate_and_validate(
     model,
     input_ids: torch.Tensor,
-    eos_token_id: int,
+    eos_token_id: int = 2,
     max_new_tokens: int = 100,
     session: Optional[Union[InferenceSessionForTorch, onnx.ModelProto, str]] = None,
     atol: float = 0.1,
@@ -258,10 +258,10 @@ def generate_and_validate(
 def onnx_generate(
     model_or_path: Union[onnx.ModelProto, str, InferenceSessionForTorch],
     input_ids: torch.Tensor,
-    eos_token_id: int,
+    eos_token_id: int = 2,
     max_new_tokens=100,
     return_session: bool = False,
-) -> Union[torch.Tensor, Tuple[torch.Tensor, InferenceSessionForTorch]]:
+) -> Union[torch.Tensor, Tuple[torch.Tensor, InferenceSessionForTorch, Dict[str, Any]]]:
     """
     Implements a simple method ``generate`` for an ONNX model.
     The function does not expect any ``position_ids`` as input.
@@ -273,7 +273,7 @@ def onnx_generate(
     :param return_session: returns the instance of class
         :class:`InferenceSessionForTorch
         <onnx_diagnostic.helpers.ort_session.InferenceSessionForTorch>`
-        created if necessary
+        created if necessary, the function returns the feeds for the next iteration
     :return: input tokens concatenated with new tokens
 
     .. runpython::
@@ -349,12 +349,19 @@ def onnx_generate(
     input_shapes = session.input_shapes
     input_names = session.input_names
     input_types = session.input_types
+    has_position_ids = "position_ids" in session.input_names
 
     assert (
         len(input_names) > 2
         and input_names[:2] == ["input_ids", "attention_mask"]
-        and input_names[2].startswith("past_key_values")
-    ), f"Only text generation is supported but input_names == {input_names}"
+        and input_names[3 if has_position_ids else 2].startswith("past_key_values")
+    ), (
+        f"Only text generation is supported but input_names == {input_names}, "
+        f"has_position_ids={has_position_ids}"
+    )
+    assert (
+        not has_position_ids or input_names[2] == "position_ids"
+    ), f"position_ids must the third input but input_names={input_names}"
 
     # First call: prefill
     feeds = dict(
@@ -366,6 +373,10 @@ def onnx_generate(
             input_ids.shape[0], input_names[2:], input_shapes[2:], input_types[2:]
         ),
     )
+    if has_position_ids:
+        feeds["position_ids"] = torch.unsqueeze(
+            torch.arange(input_ids.shape[1], dtype=torch.int64, device=input_ids.device), 0
+        )
 
     outputs = session.run(None, feeds)
 
@@ -389,11 +400,21 @@ def onnx_generate(
                 input_ids.shape, dtype=input_ids.dtype, device=input_ids.device
             ),
         )
-        feeds.update(dict(zip(input_names[2:], outputs[1:])))
+        if has_position_ids:
+            feeds["position_ids"] = torch.unsqueeze(
+                torch.arange(
+                    input_ids.shape[1],
+                    input_ids.shape[1] + 1,
+                    dtype=torch.int64,
+                    device=input_ids.device,
+                ),
+                0,
+            )
+        feeds.update(dict(zip(input_names[3 if has_position_ids else 2 :], outputs[1:])))
        outputs = session.run(None, feeds)
 
     if return_session:
-        return input_ids, session
+        return input_ids, session, feeds
     return input_ids
 
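With these changes, onnx_generate defaults eos_token_id to 2, detects an optional position_ids input, and, when return_session is True, returns the feeds built for the last decoding step alongside the session. A minimal usage sketch (the path "model.onnx" is hypothetical; any exported text-generation model satisfying the input checks above would do):

    import torch
    from onnx_diagnostic.helpers.rt_helper import onnx_generate

    prompt = torch.randint(1000, 30000, (1, 11))
    tokens, session, feeds = onnx_generate(
        "model.onnx", prompt, max_new_tokens=5, return_session=True
    )
    # tokens is the prompt concatenated with the generated tokens;
    # feeds holds the inputs of the last decoding step (including position_ids
    # when the model declares that input) and the session can be reused directly.
    print(tokens.shape, sorted(feeds))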

onnx_diagnostic/tasks/text_generation.py

Lines changed: 84 additions & 54 deletions
@@ -56,6 +56,74 @@ def reduce_model_config(config: Any) -> Dict[str, Any]:
     return kwargs
 
 
+def _get_input_falcon_mamba(
+    model: torch.nn.Module,
+    config: Optional[Any],
+    dummy_max_token_id: int,
+    num_hidden_layers: int,
+    batch_size: int = 2,
+    sequence_length: int = 30,
+    sequence_length2: int = 3,
+    dynamic_rope: bool = False,
+    num_key_value_heads: Optional[int] = None,
+    head_dim: Optional[int] = None,
+    cls_cache: Optional[Union[type, str]] = None,
+    **kwargs,  # unused
+):
+    try:
+        from transformers.models.mamba.modeling_mamba import MambaCache
+    except ImportError:
+        from transformers.cache_utils import MambaCache
+
+    assert cls_cache in (
+        "MambaCache",
+        MambaCache,
+    ), f"Unexpected value for cls_cache={cls_cache} and config={config}"
+
+    batch = "batch"
+    seq_length_multiple = 8
+    sequence_length = (
+        (sequence_length + seq_length_multiple) // seq_length_multiple * seq_length_multiple
+    )
+    # sequence_inc = seq_length_multiple
+    sequence_length2 = seq_length_multiple
+
+    shapes = {
+        "input_ids": {0: batch, 1: "sequence_length"},
+        "attention_mask": {
+            0: batch,
+            1: "cache+seq",  # cache_length + seq_length
+        },
+        "cache_position": {
+            0: batch,
+            1: "cache+seq",  # cache_length + seq_length
+        },
+        "cache_params": [{0: batch} for _ in range(num_hidden_layers * 2)],
+    }
+    inputs = dict(
+        input_ids=torch.randint(
+            0, dummy_max_token_id, (batch_size, sequence_length + sequence_length2)
+        ).to(torch.int64),
+        attention_mask=torch.ones((batch_size, sequence_length + sequence_length2)).to(
+            torch.int64
+        ),
+        cache_position=torch.arange(0, kwargs["conv_kernel"]).to(torch.int64),
+        # .expand((batch_size, -1))
+        cache_params=make_mamba_cache(
+            [
+                (
+                    torch.randn(
+                        batch_size, kwargs["intermediate_size"], kwargs["conv_kernel"]
+                    ),
+                    torch.randn(batch_size, kwargs["intermediate_size"], kwargs["state_size"]),
+                )
+                for i in range(num_hidden_layers)
+            ]
+        ),
+    )
+    return dict(inputs=inputs, dynamic_shapes=shapes)
+
+
 def get_inputs(
     model: torch.nn.Module,
     config: Optional[Any],
@@ -68,7 +136,7 @@ def get_inputs(
     num_key_value_heads: Optional[int] = None,
     head_dim: Optional[int] = None,
     cls_cache: Optional[Union[type, str]] = None,
-    add_second_input: int = 1,
+    add_second_input: Optional[int] = None,
     **kwargs,  # unused
 ):
     """
@@ -84,67 +152,28 @@ def get_inputs(
     :param dynamic_rope: use dynamic rope (see :class:`transformers.LlamaConfig`)
     :param cls_cache: cache class, by default it is
         :class:`transformers.cache_utils.DynamicCache`
+    :param add_second_input: adds other kinds of inputs
     :return: dictionary
     """
     batch = "batch"
     seq_length = "seq_length"  # torch.export.Dim("seq_length", min=1, max=4096)
     cache_length = "cache_length"  # torch.export.Dim("cache_length", min=1, max=4096)
 
     if config is not None and config.__class__.__name__ == "FalconMambaConfig":
-        try:
-            from transformers.models.mamba.modeling_mamba import MambaCache
-        except ImportError:
-            from transformers.cache_utils import MambaCache
-
-        assert cls_cache in (
-            "MambaCache",
-            MambaCache,
-        ), f"Unexpected value for cls_cache={cls_cache} and config={config}"
-        seq_length_multiple = 8
-        sequence_length = (
-            (sequence_length + seq_length_multiple)
-            // seq_length_multiple
-            * seq_length_multiple
-        )
-        # sequence_inc = seq_length_multiple
-        sequence_length2 = seq_length_multiple
-
-        shapes = {
-            "input_ids": {0: batch, 1: "sequence_length"},
-            "attention_mask": {
-                0: batch,
-                1: "cache+seq",  # cache_length + seq_length
-            },
-            "cache_position": {
-                0: batch,
-                1: "cache+seq",  # cache_length + seq_length
-            },
-            "cache_params": [{0: batch} for _ in range(num_hidden_layers * 2)],
-        }
-        inputs = dict(
-            input_ids=torch.randint(
-                0, dummy_max_token_id, (batch_size, sequence_length + sequence_length2)
-            ).to(torch.int64),
-            attention_mask=torch.ones((batch_size, sequence_length + sequence_length2)).to(
-                torch.int64
-            ),
-            cache_position=torch.arange(0, kwargs["conv_kernel"]).to(torch.int64),
-            # .expand((batch_size, -1))
-            cache_params=make_mamba_cache(
-                [
-                    (
-                        torch.randn(
-                            batch_size, kwargs["intermediate_size"], kwargs["conv_kernel"]
-                        ),
-                        torch.randn(
-                            batch_size, kwargs["intermediate_size"], kwargs["state_size"]
-                        ),
-                    )
-                    for i in range(num_hidden_layers)
-                ]
-            ),
+        res = _get_input_falcon_mamba(
+            model=model,
+            config=config,
+            dummy_max_token_id=dummy_max_token_id,
+            num_hidden_layers=num_hidden_layers,
+            batch_size=batch_size,
+            sequence_length=sequence_length,
+            sequence_length2=sequence_length2,
+            dynamic_rope=dynamic_rope,
+            num_key_value_heads=num_key_value_heads,
+            head_dim=head_dim,
+            cls_cache=cls_cache,
+            **kwargs,  # unused
         )
-        res = dict(inputs=inputs, dynamic_shapes=shapes)
     else:
         if head_dim is None:
             assert config, "head_dim is None, the value cannot be set without a configuration"
@@ -244,6 +273,7 @@ def get_inputs(
         )
         res = dict(inputs=inputs, dynamic_shapes=shapes)
     if add_second_input:
+        res["inputs_prompt"] = dict(input_ids=torch.randint(1000, 30000, (1, 11)))
         res["inputs2"] = get_inputs(
             model=model,
             config=config,
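When add_second_input is set, the dictionary returned by get_inputs now also carries an inputs_prompt entry: a single batch of 11 random token ids intended for generation-style validation rather than a plain forward call. A minimal sketch of how it is meant to be consumed, reusing the Tiny-LLM identifier from the unit test above (variable names are illustrative):

    from onnx_diagnostic.torch_models.hghub.model_inputs import get_untrained_model_with_inputs
    from onnx_diagnostic.helpers.rt_helper import generate_and_validate

    data = get_untrained_model_with_inputs("arnir0/Tiny-LLM", add_second_input=True)
    prompt = data["inputs_prompt"]["input_ids"]  # shape (1, 11), token ids in [1000, 30000)
    sequence = generate_and_validate(data["model"], prompt, max_new_tokens=3)
    print(sequence)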

onnx_diagnostic/torch_export_patches/patches/patch_torch.py

Lines changed: 4 additions & 1 deletion
@@ -195,9 +195,12 @@ def _check_frozen(
         if self.frozen:
             self.counter["ignored_backward_guard"] += 1
             # PATCHED: raised an exception instead of logging.
+            import transformers
+
             raise AssertionError(
                 f"[patched_ShapeEnv] Ignored guard {expr} == {concrete_val}, "
-                f"this could result in accuracy problems"
+                f"this could result in accuracy problems, transformers.__version__="
+                f"{transformers.__version__!r}"
             )
 
     def _set_replacement(

onnx_diagnostic/torch_models/validate.py

Lines changed: 3 additions & 0 deletions
@@ -1463,6 +1463,9 @@ def _mk(key, flavour=flavour):
     if verbose:
         print(f"[validate_onnx_model] -- keys={keys}")
     for k_input, k_expected, suffix in keys:
+        if k_input == "inputs_prompt":
+            # this must used onnx_generate
+            continue
         # make_feeds
         assert k_input in data, f"Unable to find {k_input!r} in {sorted(data)}"
         assert k_expected in data, f"Unable to find {k_expected!r} in {sorted(data)}"
