
Commit 4a4b722

another step for gemma
1 parent db651ea commit 4a4b722

File tree

2 files changed (+90, -38 lines)

_unittests/ut_tasks/try_tasks.py

Lines changed: 8 additions & 3 deletions
@@ -6,6 +6,7 @@
 from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache, make_encoder_decoder_cache
 from onnx_diagnostic.helpers.torch_helper import steal_forward
 from onnx_diagnostic.torch_models.hghub.model_inputs import get_untrained_model_with_inputs
+from onnx_diagnostic.torch_export_patches import torch_export_patches


 class TestHuggingFaceHubModel(ExtTestCase):
@@ -873,15 +874,19 @@ def test_imagetext2text_generation_gemma3_4b_it(self):
         # use_cache:bool,logits_to_keep:None,return_dict:bool)

         print()
-        # steal forward creates a bug...
-        with steal_forward(
+        with torch_export_patches(
+            patch_torch=False, patch_sympy=False, patch_transformers=True
+        ), steal_forward(
             model,
             dump_file=self.get_dump_file("test_imagetext2text_generation_gemma3_4b_it.onnx"),
             dump_drop={"attention_mask", "past_key_values", "pixel_values"},
             save_as_external_data=False,
         ):
             generated_ids = model.generate(
-                **inputs, max_new_tokens=282, do_sample=False, cache_implementation="static"
+                **inputs,
+                max_new_tokens=282,
+                do_sample=False,
+                cache_implementation="static",
             )
             output_text = processor.decode(
                 generated_ids[0][inputs["input_ids"].shape[1] :], skip_special_tokens=False
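
Note: the test now wraps generation in the transformers patches instead of relying on the old "steal forward creates a bug..." workaround. A minimal sketch of the combined pattern follows; it is hedged: model and inputs are assumed to be prepared as in the test above (they are not defined here), and the dump path is illustrative rather than the test's helper.

from onnx_diagnostic.helpers.torch_helper import steal_forward
from onnx_diagnostic.torch_export_patches import torch_export_patches

# model, inputs: assumed to come from the test setup (untrained Gemma-3 model + processor inputs).
with torch_export_patches(
    patch_torch=False, patch_sympy=False, patch_transformers=True  # only the transformers patches
), steal_forward(
    model,
    dump_file="gemma3_generation.onnx",  # illustrative path; the test uses self.get_dump_file(...)
    save_as_external_data=False,
):
    generated_ids = model.generate(
        **inputs, max_new_tokens=282, do_sample=False, cache_implementation="static"
    )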

onnx_diagnostic/torch_export_patches/patches/patch_transformers.py

Lines changed: 82 additions & 35 deletions
@@ -1,8 +1,9 @@
 import inspect
 import math
+import os
 from dataclasses import dataclass
 from functools import wraps
-from typing import Callable, List, Optional, Tuple
+from typing import Callable, List, Optional, Tuple, Union
 import packaging.version as pv
 import torch
 import transformers
@@ -114,6 +115,7 @@ def patched_eager_mask(
     """manual patch for function ``transformers.masking_utils.eager_mask``."""
     # The masks for eager attention are simply boolean mask from sdpa, casted to 0 and -inf
     _ = kwargs.pop("allow_is_causal_skip", None)
+    # PATCHED: this line calls the patched version of sdpa_mask
     mask = patched_sdpa_mask_recent_torch(
         batch_size=batch_size,
         cache_position=cache_position,
@@ -126,7 +128,7 @@ def patched_eager_mask(
         **kwargs,
     )
     min_dtype = torch.finfo(dtype).min
-    # The patched line.
+    # PATCHED: the following line
     # we need 0s where the tokens should be taken into account,
     # and -inf otherwise (mask is already of boolean type)
     # mask =
@@ -158,6 +160,7 @@ def patched_sdpa_mask_recent_torch(
     mask_function = and_masks(mask_function, padding_mask_function(padding_mask))
     batch_arange = torch.arange(batch_size, device=cache_position.device)
     head_arange = torch.arange(1, device=cache_position.device)
+    # PATCHED: this line calls the patched version of vmap_for_bhqkv
     causal_mask = patched__vmap_for_bhqkv(mask_function)(
         batch_arange, head_arange, cache_position, kv_arange
     )
@@ -214,6 +217,7 @@ def lazy_initialization(self, key_states: torch.Tensor):
         self.dtype, self.device = key_states.dtype, key_states.device
         new_shape = list(key_states.shape)
         new_shape[-2] = 0
+        # PATCHED: used a tensor with an empty shape and not an empty list to initialize
         self.keys = torch.empty(new_shape, dtype=self.dtype, device=self.device)
         self.values = torch.empty(new_shape, dtype=self.dtype, device=self.device)
         if patch_is_initialized:
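
Note: a self-contained sketch of the idea behind the patched initialization above: the cache starts as a tensor whose sequence axis has length 0 rather than an empty Python list, so later concatenations (and the export) see the right rank, dtype and device from the first step. Shapes below are illustrative.

import torch

key_states = torch.randn(2, 8, 5, 64)  # illustrative (batch, heads, seq, head_dim)
new_shape = list(key_states.shape)
new_shape[-2] = 0                      # no cached tokens yet
keys = torch.empty(new_shape, dtype=key_states.dtype, device=key_states.device)
print(keys.shape)                                   # torch.Size([2, 8, 0, 64])
print(torch.cat([keys, key_states], dim=-2).shape)  # torch.Size([2, 8, 5, 64])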
@@ -248,6 +252,8 @@ def _patch_make_causal_mask(
         diagonal = past_key_values_length - sliding_window - 1

         context_mask = torch.tril(torch.ones_like(mask, dtype=torch.bool), diagonal=diagonal)
+        # PATCHED: removed "if is_torchdynamo_compiling(): mask = mask.clone()"
+        # and used masked_fill instead of masked_fill_
         # In this case, the current implementation of torch fails (17/12/2024).
         # Try model Phi-3.5-Mini-Instruct.
         mask = mask.masked_fill(context_mask, torch.finfo(dtype).min)
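
Note: the diff swaps the in-place masked_fill_ for the out-of-place masked_fill. A small self-contained sketch of the behavioral difference that matters here: masked_fill returns a new tensor and leaves its input untouched. Values below are illustrative.

import torch

dtype = torch.float32
mask = torch.zeros(1, 4, dtype=dtype)
context_mask = torch.tensor([[True, False, True, False]])
masked = mask.masked_fill(context_mask, torch.finfo(dtype).min)  # new tensor; `mask` is unchanged
print(masked)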
@@ -453,10 +459,18 @@ class patched_GenerationMixin:
     """

     _PATCHES_ = [
-        "_cache_dependant_input_preparation",
-        "_cache_dependant_input_preparation_exporting",
-        "prepare_inputs_for_generation",
-        "_sample",
+        name
+        for name in [
+            "_cache_dependant_input_preparation",
+            "_cache_dependant_input_preparation_exporting",
+            (
+                None
+                if pv.Version(transformers.__version__) >= pv.Version("4.56")
+                else "prepare_inputs_for_generation"
+            ),
+            "_sample",
+        ]
+        if name is not None
     ]
     _PATCHED_CLASS_ = transformers.generation.utils.GenerationMixin
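
Note: a self-contained sketch of the version gating introduced in _PATCHES_ above: prepare_inputs_for_generation is only kept as a patch target when transformers is older than 4.56, and the comprehension drops the None placeholder. The helper name is hypothetical; the class above evaluates the same expression inline against transformers.__version__.

import packaging.version as pv

def select_patches(transformers_version: str) -> list:
    names = [
        "_cache_dependant_input_preparation",
        "_cache_dependant_input_preparation_exporting",
        (
            None
            if pv.Version(transformers_version) >= pv.Version("4.56")
            else "prepare_inputs_for_generation"
        ),
        "_sample",
    ]
    # Drop the placeholder so the patch list only contains real method names.
    return [name for name in names if name is not None]

print(select_patches("4.55.0"))  # includes prepare_inputs_for_generation
print(select_patches("4.56.1"))  # excludes it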

@@ -732,13 +746,13 @@ def prepare_inputs_for_generation(
     def _sample(
         self,
         input_ids: torch.LongTensor,
-        logits_processor: LogitsProcessorList,
-        stopping_criteria: StoppingCriteriaList,
-        generation_config: GenerationConfig,
+        logits_processor: "LogitsProcessorList",  # noqa: F821
+        stopping_criteria: "StoppingCriteriaList",  # noqa: F821
+        generation_config: "GenerationConfig",  # noqa: F821
         synced_gpus: bool = False,
-        streamer: Optional["BaseStreamer"] = None,
+        streamer: Optional["BaseStreamer"] = None,  # noqa: F821
         **model_kwargs,
-    ) -> Union[GenerateNonBeamOutput, torch.LongTensor]:
+    ) -> Union["GenerateNonBeamOutput", torch.LongTensor]:  # noqa: F821
         """
         2025/09/29: updates for Gemma3 models, fix for eager mode as well as the export.
         """
@@ -749,28 +763,42 @@ def _sample(
         output_scores = generation_config.output_scores
         output_logits = generation_config.output_logits
         return_dict_in_generate = generation_config.return_dict_in_generate
-        has_eos_stopping_criteria = any(hasattr(criteria, "eos_token_id") for criteria in stopping_criteria)
+        has_eos_stopping_criteria = any(
+            hasattr(criteria, "eos_token_id") for criteria in stopping_criteria
+        )
         do_sample = generation_config.do_sample

         # init attention / hidden states / scores tuples
         scores = () if (return_dict_in_generate and output_scores) else None
         raw_logits = () if (return_dict_in_generate and output_logits) else None
         decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
         cross_attentions = () if (return_dict_in_generate and output_attentions) else None
-        decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
+        decoder_hidden_states = (
+            () if (return_dict_in_generate and output_hidden_states) else None
+        )

         # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
         if return_dict_in_generate and self.config.is_encoder_decoder:
-            encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
+            encoder_attentions = (
+                model_kwargs["encoder_outputs"].get("attentions")
+                if output_attentions
+                else None
+            )
             encoder_hidden_states = (
-                model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
+                model_kwargs["encoder_outputs"].get("hidden_states")
+                if output_hidden_states
+                else None
             )

         # keep track of which sequences are already finished
         batch_size, cur_len = input_ids.shape[:2]
         this_peer_finished = False
-        unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
-        model_kwargs = self._get_initial_cache_position(cur_len, input_ids.device, model_kwargs)
+        unfinished_sequences = torch.ones(
+            batch_size, dtype=torch.long, device=input_ids.device
+        )
+        model_kwargs = self._get_initial_cache_position(
+            cur_len, input_ids.device, model_kwargs
+        )

         model_forward = self.__call__
         compile_forward = self._valid_auto_compile_criteria(model_kwargs, generation_config)
@@ -779,11 +807,10 @@ def _sample(
             # If we use FA2 and a static cache, we cannot compile with fullgraph
             if self.config._attn_implementation == "flash_attention_2":
                 # only raise warning if the user passed an explicit compile-config
-                if generation_config.compile_config is not None and generation_config.compile_config.fullgraph:
-                    logger.warning_once(
-                        "When using Flash Attention 2 and a static cache, you cannot use the option `CompileConfig(fullgraph=True)` as "
-                        "FA2 introduces graph breaks. We overrode the option with `fullgraph=False`."
-                    )
+                if (
+                    generation_config.compile_config is not None
+                    and generation_config.compile_config.fullgraph
+                ):
                     generation_config.compile_config.fullgraph = False
             model_forward = self.get_compiled_call(generation_config.compile_config)

@@ -793,7 +820,9 @@ def _sample(
         else:
             is_prefill = True

-        while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
+        while self._has_unfinished_sequences(
+            this_peer_finished, synced_gpus, device=input_ids.device
+        ):
             # prepare model inputs
             model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

@@ -803,7 +832,8 @@ def _sample(
             else:
                 outputs = model_forward(**model_inputs, return_dict=True)

-            # synced_gpus: don't waste resources running the code we don't need; kwargs must be updated before skipping
+            # synced_gpus: don't waste resources running the code we don't need;
+            # kwargs must be updated before skipping
             model_kwargs = self._update_model_kwargs_for_generation(
                 outputs,
                 model_kwargs,
@@ -812,9 +842,12 @@ def _sample(
            if synced_gpus and this_peer_finished:
                continue

-            # Copy is needed to avoid keeping a hanging ref to outputs.logits which may be very large for first iteration
+            # Copy is needed to avoid keeping a hanging ref to outputs.logits
+            # which may be very large for first iteration
             # (the clone itself is always small)
-            next_token_logits = outputs.logits[:, -1, :].to(copy=True, dtype=torch.float32, device=input_ids.device)
+            next_token_logits = outputs.logits[:, -1, :].to(
+                copy=True, dtype=torch.float32, device=input_ids.device
+            )

             # pre-process distribution
             next_token_scores = logits_processor(input_ids, next_token_logits)
@@ -827,7 +860,9 @@ def _sample(
                     raw_logits += (next_token_logits,)
                 if output_attentions:
                     decoder_attentions += (
-                        (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
+                        (outputs.decoder_attentions,)
+                        if self.config.is_encoder_decoder
+                        else (outputs.attentions,)
                     )
                     if self.config.is_encoder_decoder:
                         cross_attentions += (outputs.cross_attentions,)
@@ -841,35 +876,42 @@ def _sample(

             # token selection
             if do_sample:
-                probs = nn.functional.softmax(next_token_scores, dim=-1)
-                # TODO (joao): this OP throws "skipping cudagraphs due to ['incompatible ops']", find solution
+                probs = torch.nn.functional.softmax(next_token_scores, dim=-1)
                 next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
             else:
                 next_tokens = torch.argmax(next_token_scores, dim=-1)

             # finished sentences should have their next token be a padding token
             if has_eos_stopping_criteria:
-                next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
+                next_tokens = next_tokens * unfinished_sequences + pad_token_id * (
+                    1 - unfinished_sequences
+                )

             # update generated ids, model inputs, and length for next step
-            input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
+            # PATCHED: the two following lines, next_tokens can already be 2D for this model
+            next_tokens_2d = (
+                next_tokens if len(next_tokens.shape) == 2 else next_tokens[:, None]
+            )
+            input_ids = torch.cat([input_ids, next_tokens_2d], dim=-1)
             if streamer is not None:
                 streamer.put(next_tokens.cpu())

             unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
             this_peer_finished = unfinished_sequences.max() == 0
             cur_len += 1

-            # This is needed to properly delete outputs.logits which may be very large for first iteration
-            # Otherwise a reference to outputs is kept which keeps the logits alive in the next iteration
+            # This is needed to properly delete outputs.logits which may be very large
+            # for first iteration
+            # Otherwise a reference to outputs is kept which keeps
+            # the logits alive in the next iteration
             del outputs

         if streamer is not None:
             streamer.end()

         if return_dict_in_generate:
             if self.config.is_encoder_decoder:
-                return GenerateEncoderDecoderOutput(
+                return transformers.generation.utils.GenerateEncoderDecoderOutput(
                     sequences=input_ids,
                     scores=scores,
                     logits=raw_logits,
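
Note: a self-contained sketch of the shape normalization the patch adds above: next_tokens is usually 1D of shape (batch,), but for this model it can already be 2D of shape (batch, 1), so the trailing axis is only added when it is missing before concatenating to input_ids. The helper name is illustrative.

import torch

def append_tokens(input_ids: torch.Tensor, next_tokens: torch.Tensor) -> torch.Tensor:
    # Accept (batch,) or (batch, 1) without introducing an extra dimension.
    next_tokens_2d = next_tokens if len(next_tokens.shape) == 2 else next_tokens[:, None]
    return torch.cat([input_ids, next_tokens_2d], dim=-1)

ids = torch.zeros(2, 3, dtype=torch.long)
print(append_tokens(ids, torch.tensor([7, 8])).shape)      # torch.Size([2, 4])
print(append_tokens(ids, torch.tensor([[7], [8]])).shape)  # torch.Size([2, 4])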
@@ -881,7 +923,7 @@ def _sample(
                     past_key_values=model_kwargs.get("past_key_values"),
                 )
             else:
-                return GenerateDecoderOnlyOutput(
+                return transformers.generation.utils.GenerateDecoderOnlyOutput(
                     sequences=input_ids,
                     scores=scores,
                     logits=raw_logits,
@@ -955,6 +997,7 @@ def patched__compute_dynamic_ntk_parameters(
     if seq_len is None:
         seq_len = max_position_embeddings
     else:
+        # PATCHED: remove the line using max
         torch._check(isinstance(seq_len, torch.Tensor))
         seq_len = torch.maximum(
             seq_len,
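
Note: a short sketch of the tensorized clamp kept above: once seq_len is checked to be a tensor, torch.maximum stays traceable under torch.export, whereas Python's built-in max would force a data-dependent comparison on the tensor. The threshold below is illustrative.

import torch

seq_len = torch.tensor(3000)
max_position_embeddings = 4096  # illustrative value
seq_len = torch.maximum(seq_len, torch.tensor(max_position_embeddings, dtype=seq_len.dtype))
print(seq_len)  # tensor(4096)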
@@ -1060,6 +1103,7 @@ def longrope_frequency_update(self, position_ids, device):
         )
         original_inv_freq = self.original_inv_freq.to(device)

+        # PATCHED: uses torch.cond instead of a test
         cond = (seq_len > original_max_position_embeddings).item()
         inv_freq = torch.cond(
             cond,
@@ -1131,6 +1175,7 @@ def dynamic_frequency_update(self, position_ids, device):

         original_inv_freq = self.original_inv_freq.to(device)
         cond = (seq_len >= self.original_max_seq_len).item()
+        # PATCHED: uses torch.cond instead of a test
         inv_freq = torch.cond(
             cond,
             (lambda x, y: x.clone()),
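
Note: a self-contained sketch of the torch.cond pattern used in the two rotary-embedding updates above: both branches take the same operands and return tensors of the same shape, so the conditional is represented structurally instead of specializing on a data-dependent Python if. Names and the threshold are illustrative; which operand each branch clones follows the surrounding patch code, not this sketch.

import torch

inv_freq = torch.rand(8)
original_inv_freq = torch.rand(8)
seq_len = torch.tensor(4096)
original_max_seq_len = 2048  # illustrative threshold

cond = (seq_len >= original_max_seq_len).item()
picked = torch.cond(
    cond,
    (lambda x, y: x.clone()),  # branch selected when cond is True
    (lambda x, y: y.clone()),  # branch selected otherwise
    [inv_freq, original_inv_freq],
)
print(picked.shape)  # torch.Size([8])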
@@ -1166,6 +1211,7 @@ def common_eager_attention_forward(

    attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
    if attention_mask is not None:
+        # PATCHED
         # The two following lines were added.
         if attention_mask is not None and attention_mask.ndim == 4:
             attention_mask = attention_mask[:, :, :, : key.shape[-2]]
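
Note: a small self-contained sketch of the mask trimming added above: when a 4D attention mask is longer than the key sequence (for example a mask allocated for a static cache), it is sliced down to key.shape[-2] so it lines up with the attention scores. Shapes are illustrative.

import torch

key = torch.randn(1, 2, 5, 16)             # (batch, heads, kv_len, head_dim)
attention_mask = torch.zeros(1, 1, 7, 12)  # last axis longer than kv_len
if attention_mask is not None and attention_mask.ndim == 4:
    attention_mask = attention_mask[:, :, :, : key.shape[-2]]
print(attention_mask.shape)  # torch.Size([1, 1, 7, 5])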
@@ -1238,6 +1284,7 @@
 class common_RotaryEmbedding(torch.nn.Module):
     # This may cause some issues.
     # @torch.no_grad()
+    # PATCHED: the decorator
     @patched_dynamic_rope_update
     def forward(self, x, position_ids):
         inv_freq_expanded = (
