Commit 126e585

Add set of inputs for empty cache
1 parent 943e44b commit 126e585

File tree

7 files changed: +316 -33 lines changed

_unittests/ut_tasks/test_tasks.py

Lines changed: 20 additions & 0 deletions

@@ -48,6 +48,26 @@ def test_text_generation(self):
                 model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds), strict=False
             )

+    def test_text_generation_empty_cache(self):
+        mid = "arnir0/Tiny-LLM"
+        data = get_untrained_model_with_inputs(mid, add_second_input=True)
+        model, inputs = data["model"], data["inputs"]
+        self.assertIn("inputs_empty_cache", data)
+        empty_inputs = torch_deepcopy(data["inputs_empty_cache"])
+        expected = model(**empty_inputs)
+        self.assertEqual(
+            {"attention_mask", "past_key_values", "input_ids", "position_ids"}, set(inputs)
+        )
+        with torch_export_patches(patch_transformers=True, verbose=1):
+            ep = torch.export.export(
+                model,
+                (),
+                kwargs=inputs,
+                dynamic_shapes=use_dyn_not_str(data["dynamic_shapes"]),
+            )
+        got = ep.module()(**inputs)
+        self.assertEqualArrayAny(expected, got)
+
     @hide_stdout()
     def test_automatic_speech_recognition_float32(self):
         mid = "openai/whisper-tiny"

_unittests/ut_tasks/try_tasks.py

Lines changed: 44 additions & 0 deletions

@@ -4,6 +4,7 @@
 from onnx_diagnostic.helpers import string_type
 from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache, make_encoder_decoder_cache
 from onnx_diagnostic.helpers.torch_helper import steal_forward
+from onnx_diagnostic.torch_export_patches import torch_export_patches
 from onnx_diagnostic.torch_models.hghub.model_inputs import get_untrained_model_with_inputs


@@ -130,6 +131,49 @@ def test_text2text_generation_static(self):
             )
         print(tokenizer.decode(generated_ids[0], skip_special_tokens=True))

+    @never_test()
+    def test_text_generation_tiny_llm(self):
+        # clear&&NEVERTEST=1 python _unittests/ut_tasks/try_tasks.py -k tiny_llm
+        """
+        dict(cache_position:T7s21,
+             past_key_values:DynamicCache(key_cache=#0[], value_cache=#0[]),
+             input_ids:T7s1x21,
+             position_ids:T7s1x21
+             attention_mask:T1s1x21)
+        dict(cache_position:T7s1,
+             past_key_values:DynamicCache(key_cache=#32[T1s1x8x21x128,...],
+                                          value_cache=#32[T1s1x8x21x128,...]),
+             input_ids:T7s1x21,
+             position_ids:T7s1x1
+             attention_mask:T1s1x1)
+        """
+        from transformers import AutoTokenizer, AutoModelForCausalLM
+
+        tokenizer = AutoTokenizer.from_pretrained("arnir0/Tiny-LLM")
+        model = AutoModelForCausalLM.from_pretrained("microsoft/Phi-4-mini-instruct")
+
+        text = "def greet(user): print(f'hello <extra_id_0>!')"
+        input_ids = tokenizer(text, return_tensors="pt").input_ids.reshape((1, -1))
+        mask = (
+            torch.tensor([1 for i in range(input_ids.shape[1])])
+            .to(torch.int64)
+            .reshape((1, -1))
+        )
+        position_ids = torch.arange(input_ids.shape[1], dtype=torch.int64).reshape((1, -1))
+
+        # simply generate a single sequence
+        print()
+        with torch_export_patches(
+            patch_transformers=True, patch_torch=False, patch_sympy=False
+        ), steal_forward(model):
+            generated_ids = model.generate(
+                input_ids=input_ids,
+                max_length=100,
+                attention_mask=mask,
+                position_ids=position_ids,
+            )
+        print(tokenizer.decode(generated_ids[0], skip_special_tokens=True))
+
     @never_test()
     def test_text_generation_phi4_mini(self):
         # clear&&NEVERTEST=1 python _unittests/ut_tasks/try_tasks.py -k phi4_mini

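Aside (not part of the commit): the docstring of the new tiny_llm test records what steal_forward printed. In that notation, T7s1x21 is an int64 tensor of shape 1x21 (ONNX dtype code 7), T1 is float32, and DynamicCache(key_cache=#0[]) is a cache with zero layers, i.e. the empty-cache case this commit adds inputs for. Assuming string_type accepts a with_shape flag, the first entry can be reproduced roughly like this:

import torch
from onnx_diagnostic.helpers import string_type

# prints something like dict(input_ids:T7s1x21): T7 = int64, s1x21 = shape (1, 21)
print(string_type({"input_ids": torch.zeros((1, 21), dtype=torch.int64)}, with_shape=True))
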
onnx_diagnostic/helpers/torch_helper.py

Lines changed: 6 additions & 1 deletion

@@ -765,7 +765,12 @@ def to_any(value: Any, to_value: Union[torch.dtype, torch.device, str]) -> Any:


 def torch_deepcopy(value: Any) -> Any:
-    """Makes a deepcopy."""
+    """
+    Makes a deep copy.
+
+    :param value: any value
+    :return: a deep copy
+    """
     if value is None:
         return None
     if isinstance(value, (int, float, str)):

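Aside (illustrative only, not in the diff): the helper documented above is what the new unit test uses to copy the empty-cache inputs before the model mutates them. A minimal usage sketch with a hypothetical tensor dict, assuming torch_deepcopy recurses into dicts of tensors as the test relies on:

import torch
from onnx_diagnostic.helpers.torch_helper import torch_deepcopy

inputs = {"input_ids": torch.tensor([[1, 2, 3]]), "attention_mask": torch.ones(1, 3)}
copy = torch_deepcopy(inputs)
copy["input_ids"][0, 0] = 99                   # mutate the copy only
assert inputs["input_ids"][0, 0].item() == 1   # the original stays untouched
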
onnx_diagnostic/tasks/text_generation.py

Lines changed: 15 additions & 0 deletions

@@ -269,6 +269,21 @@ def get_inputs(
         add_second_input=0,
         **kwargs,
     )["inputs"]
+    res["inputs_empty_cache"] = get_inputs(
+        model=model,
+        config=config,
+        dummy_max_token_id=dummy_max_token_id,
+        num_hidden_layers=num_hidden_layers,
+        batch_size=batch_size,
+        sequence_length=0,
+        sequence_length2=sequence_length2,
+        dynamic_rope=dynamic_rope,
+        num_key_value_heads=num_key_value_heads,
+        head_dim=head_dim,
+        cls_cache=cls_cache,
+        add_second_input=0,
+        **kwargs,
+    )["inputs"]
     return res



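Aside (not taken from the diff): the only functional change above is the second get_inputs call with sequence_length=0, which is what makes inputs_empty_cache an "empty cache" variant. A minimal sketch, in transformers terms, of what such a cache looks like:

from transformers.cache_utils import DynamicCache

# A DynamicCache holding no layers reports a past length of 0, so a model call
# with it behaves like a prefill step: every position comes from input_ids alone.
cache = DynamicCache()
print(cache.get_seq_length())  # 0
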
onnx_diagnostic/torch_export_patches/patches/patch_transformers.py

Lines changed: 184 additions & 2 deletions

@@ -1,13 +1,22 @@
 import inspect
 import math
+import os
 from dataclasses import dataclass
 from functools import wraps
-from typing import Callable, List, Optional, Tuple
+from typing import Callable, List, Optional, Tuple, Union
 import packaging.version as pv
 import torch
 import transformers
 from transformers.modeling_attn_mask_utils import AttentionMaskConverter
 from transformers.cache_utils import StaticCache, Cache
+from transformers.generation.utils import (
+    GenerateDecoderOnlyOutput,
+    GenerateEncoderDecoderOutput,
+    GenerateNonBeamOutput,
+    GenerationConfig,
+    StoppingCriteriaList,
+    LogitsProcessorList,
+)

 try:
     from transformers.cache_utils import parse_processor_args  # noqa: F401

@@ -456,6 +465,11 @@ class patched_GenerationMixin:
         "_cache_dependant_input_preparation",
         "_cache_dependant_input_preparation_exporting",
         "prepare_inputs_for_generation",
+        (
+            "_sample"
+            if pv.Version(transformers.__version__) == pv.Version("4.57.0.dev0")
+            else None
+        ),
     ]
     _PATCHED_CLASS_ = transformers.generation.utils.GenerationMixin

@@ -588,7 +602,7 @@ def prepare_inputs_for_generation(
         model_inputs = {}
         # - some models don't have `Cache` support
         #   (which implies they don't expect `cache_position` in `forward`)
-        if self._supports_cache_class:
+        if getattr(self, "_supports_cache_class", False):
             model_inputs["cache_position"] = cache_position
         # - `cache_position` was not a mandatory input in
         #   `prepare_inputs_for_generation` for those models, and this

@@ -728,6 +742,174 @@ def prepare_inputs_for_generation(
         model_inputs.pop("labels", None)
         return model_inputs

+    def _sample(
+        self,
+        input_ids: torch.LongTensor,
+        logits_processor: LogitsProcessorList,
+        stopping_criteria: StoppingCriteriaList,
+        generation_config: GenerationConfig,
+        synced_gpus: bool = False,
+        streamer: Optional["BaseStreamer"] = None,  # noqa: F821
+        **model_kwargs,
+    ) -> Union[GenerateNonBeamOutput, torch.LongTensor]:
+        # init values
+        pad_token_id = generation_config._pad_token_tensor
+        output_attentions = generation_config.output_attentions
+        output_hidden_states = generation_config.output_hidden_states
+        output_scores = generation_config.output_scores
+        output_logits = generation_config.output_logits
+        return_dict_in_generate = generation_config.return_dict_in_generate
+        has_eos_stopping_criteria = any(
+            hasattr(criteria, "eos_token_id") for criteria in stopping_criteria
+        )
+        do_sample = generation_config.do_sample
+
+        # init attention / hidden states / scores tuples
+        scores = () if (return_dict_in_generate and output_scores) else None
+        raw_logits = () if (return_dict_in_generate and output_logits) else None
+        decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
+        cross_attentions = () if (return_dict_in_generate and output_attentions) else None
+        decoder_hidden_states = (
+            () if (return_dict_in_generate and output_hidden_states) else None
+        )
+
+        # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
+        if return_dict_in_generate and self.config.is_encoder_decoder:
+            encoder_attentions = (
+                model_kwargs["encoder_outputs"].get("attentions")
+                if output_attentions
+                else None
+            )
+            encoder_hidden_states = (
+                model_kwargs["encoder_outputs"].get("hidden_states")
+                if output_hidden_states
+                else None
+            )
+
+        # keep track of which sequences are already finished
+        batch_size, cur_len = input_ids.shape[:2]
+        this_peer_finished = False
+        unfinished_sequences = torch.ones(
+            batch_size, dtype=torch.long, device=input_ids.device
+        )
+        model_kwargs = self._get_initial_cache_position(
+            cur_len, input_ids.device, model_kwargs
+        )
+
+        model_forward = self.__call__
+        compile_forward = self._valid_auto_compile_criteria(model_kwargs, generation_config)
+        if compile_forward:
+            os.environ["TOKENIZERS_PARALLELISM"] = "0"
+            # If we use FA2 and a static cache, we cannot compile with fullgraph
+            model_forward = self.get_compiled_call(generation_config.compile_config)
+
+        if generation_config.prefill_chunk_size is not None:
+            model_kwargs = self._prefill_chunking(input_ids, generation_config, **model_kwargs)
+            is_prefill = False
+        else:
+            is_prefill = True
+
+        while self._has_unfinished_sequences(
+            this_peer_finished, synced_gpus, device=input_ids.device
+        ):
+            # prepare model inputs
+            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
+
+            if is_prefill:
+                outputs = self(**model_inputs, return_dict=True)
+                is_prefill = False
+            else:
+                outputs = model_forward(**model_inputs, return_dict=True)
+
+            model_kwargs = self._update_model_kwargs_for_generation(
+                outputs,
+                model_kwargs,
+                is_encoder_decoder=self.config.is_encoder_decoder,
+            )
+            if synced_gpus and this_peer_finished:
+                continue
+
+            next_token_logits = outputs.logits[:, -1, :].to(
+                copy=True, dtype=torch.float32, device=input_ids.device
+            )
+
+            # pre-process distribution
+            next_token_scores = logits_processor(input_ids, next_token_logits)
+
+            # Store scores, attentions and hidden_states when required
+            if return_dict_in_generate:
+                if output_scores:
+                    scores += (next_token_scores,)
+                if output_logits:
+                    raw_logits += (next_token_logits,)
+                if output_attentions:
+                    decoder_attentions += (
+                        (outputs.decoder_attentions,)
+                        if self.config.is_encoder_decoder
+                        else (outputs.attentions,)
+                    )
+                    if self.config.is_encoder_decoder:
+                        cross_attentions += (outputs.cross_attentions,)
+
+                if output_hidden_states:
+                    decoder_hidden_states += (
+                        (outputs.decoder_hidden_states,)
+                        if self.config.is_encoder_decoder
+                        else (outputs.hidden_states,)
+                    )
+
+            # token selection
+            if do_sample:
+                probs = torch.nn.functional.softmax(next_token_scores, dim=-1)
+                next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
+            else:
+                next_tokens = torch.argmax(next_token_scores, dim=-1)
+
+            # finished sentences should have their next token be a padding token
+            if has_eos_stopping_criteria:
+                next_tokens = next_tokens * unfinished_sequences + pad_token_id * (
+                    1 - unfinished_sequences
+                )
+
+            # update generated ids, model inputs, and length for next step
+            # PATCHED: dimension issues when calling generate method
+            input_ids = torch.cat([input_ids, next_tokens], dim=-1)
+            if streamer is not None:
+                streamer.put(next_tokens.cpu())
+
+            unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
+            this_peer_finished = unfinished_sequences.max() == 0
+            cur_len += 1
+            del outputs
+
+        if streamer is not None:
+            streamer.end()
+
+        if return_dict_in_generate:
+            if self.config.is_encoder_decoder:
+                return GenerateEncoderDecoderOutput(
+                    sequences=input_ids,
+                    scores=scores,
+                    logits=raw_logits,
+                    encoder_attentions=encoder_attentions,
+                    encoder_hidden_states=encoder_hidden_states,
+                    decoder_attentions=decoder_attentions,
+                    cross_attentions=cross_attentions,
+                    decoder_hidden_states=decoder_hidden_states,
+                    past_key_values=model_kwargs.get("past_key_values"),
+                )
+            else:
+                return GenerateDecoderOnlyOutput(
+                    sequences=input_ids,
+                    scores=scores,
+                    logits=raw_logits,
+                    attentions=decoder_attentions,
+                    hidden_states=decoder_hidden_states,
+                    past_key_values=model_kwargs.get("past_key_values"),
+                )
+        else:
+            return input_ids
+

 def patched__compute_dynamic_ntk_parameters(
     config: Optional[transformers.PretrainedConfig] = None,

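Aside (illustrative only): the patch registers its _sample reimplementation only when the installed transformers is exactly 4.57.0.dev0; the other entries in the patched list are unconditional. A minimal sketch of that version-gating pattern, with a hypothetical helper name:

import packaging.version as pv
import transformers

def patched_method_names() -> list:
    # "_sample" is swapped in only for the pinned dev build; the rest always applies.
    names = ["prepare_inputs_for_generation"]
    if pv.Version(transformers.__version__) == pv.Version("4.57.0.dev0"):
        names.append("_sample")
    return names

print(patched_method_names())
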
onnx_diagnostic/torch_models/hghub/model_inputs.py

Lines changed: 1 addition & 1 deletion

@@ -57,7 +57,7 @@ def get_untrained_model_with_inputs(
         to get a smaller model
     :param use_pretrained: download the pretrained weights as well
     :param use_preinstalled: use preinstalled configurations
-    :param add_second_input: provides a second inputs to check a model
+    :param add_second_input: provides other inputs to check that a model
         supports different shapes
     :param subfolder: subfolder to use for this model id
    :param use_only_preinstalled: use only preinstalled version
