 import inspect
 import math
+import os
 from dataclasses import dataclass
 from functools import wraps
-from typing import Callable, List, Optional, Tuple
+from typing import Callable, List, Optional, Tuple, Union
 import packaging.version as pv
 import torch
 import transformers
@@ -114,6 +115,7 @@ def patched_eager_mask(
     """manual patch for function ``transformers.masking_utils.eager_mask``."""
     # The masks for eager attention are simply boolean mask from sdpa, casted to 0 and -inf
     _ = kwargs.pop("allow_is_causal_skip", None)
+    # PATCHED: this line calls the patched version of sdpa_mask
     mask = patched_sdpa_mask_recent_torch(
         batch_size=batch_size,
         cache_position=cache_position,
@@ -126,7 +128,7 @@ def patched_eager_mask(
         **kwargs,
     )
     min_dtype = torch.finfo(dtype).min
-    # The patched line.
+    # PATCHED: the following line
     # we need 0s where the tokens should be taken into account,
     # and -inf otherwise (mask is already of boolean type)
     # mask =
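The conversion described by the patched comments can be sketched as follows; this is a minimal illustration with made-up shapes, not the library code. A boolean mask in which True marks an attended token becomes an additive mask holding 0 for attended positions and the dtype minimum elsewhere:

import torch

bool_mask = torch.tensor([[True, True, False]])  # True = token is attended
dtype = torch.float16
min_dtype = torch.finfo(dtype).min
additive_mask = torch.where(
    bool_mask,
    torch.zeros((), dtype=dtype),
    torch.full((), min_dtype, dtype=dtype),
)
# additive_mask -> [[0., 0., -65504.]]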
@@ -158,6 +160,7 @@ def patched_sdpa_mask_recent_torch(
         mask_function = and_masks(mask_function, padding_mask_function(padding_mask))
     batch_arange = torch.arange(batch_size, device=cache_position.device)
     head_arange = torch.arange(1, device=cache_position.device)
+    # PATCHED: this line calls the patched version of vmap_for_bhqkv
     causal_mask = patched__vmap_for_bhqkv(mask_function)(
         batch_arange, head_arange, cache_position, kv_arange
     )
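For reference, the expansion that `_vmap_for_bhqkv` performs can be approximated with plain broadcasting; the mask function below mirrors the `(batch_idx, head_idx, q_idx, kv_idx)` signature but is only an illustration, not the patched helper:

import torch

def causal(b, h, q_idx, kv_idx):
    # same signature as the mask functions that get vmapped above
    return kv_idx <= q_idx

batch_arange = torch.arange(2)
head_arange = torch.arange(1)
cache_position = torch.arange(4)  # query positions
kv_arange = torch.arange(4)       # key/value positions
causal_mask = causal(
    batch_arange[:, None, None, None],
    head_arange[None, :, None, None],
    cache_position[None, None, :, None],
    kv_arange[None, None, None, :],
)
# boolean lower-triangular pattern, broadcastable over (batch, head, q, kv)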
@@ -214,6 +217,7 @@ def lazy_initialization(self, key_states: torch.Tensor):
         self.dtype, self.device = key_states.dtype, key_states.device
         new_shape = list(key_states.shape)
         new_shape[-2] = 0
+        # PATCHED: used a tensor with an empty shape and not an empty list to initialize
         self.keys = torch.empty(new_shape, dtype=self.dtype, device=self.device)
         self.values = torch.empty(new_shape, dtype=self.dtype, device=self.device)
         if patch_is_initialized:
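The patched initialization keeps the cache a real tensor with a zero-length sequence dimension, so later concatenation along that dimension stays well typed. A small sketch with assumed shapes, not the patched class itself:

import torch

key_states = torch.randn(2, 8, 5, 64)  # (batch, heads, seq, head_dim)
new_shape = list(key_states.shape)
new_shape[-2] = 0                      # zero-length sequence dimension
keys = torch.empty(new_shape, dtype=key_states.dtype, device=key_states.device)
keys = torch.cat([keys, key_states], dim=-2)  # -> shape (2, 8, 5, 64)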
@@ -248,6 +252,8 @@ def _patch_make_causal_mask(
         diagonal = past_key_values_length - sliding_window - 1

         context_mask = torch.tril(torch.ones_like(mask, dtype=torch.bool), diagonal=diagonal)
+        # PATCHED: removed if is_torchdynamo_compiling(): mask = mask.clone()
+        # and used masked_fill instead of masked_fill_
         # In this case, the current implementation of torch fails (17/12/2024).
         # Try model Phi-3.5-Mini-Instruct.
         mask = mask.masked_fill(context_mask, torch.finfo(dtype).min)
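For context, `masked_fill` returns a new tensor while `masked_fill_` mutates its input in place; out-of-place ops tend to be easier to trace and export. A standalone sketch, not the patched function:

import torch

mask = torch.zeros(4, 4)
context_mask = torch.tril(torch.ones(4, 4, dtype=torch.bool), diagonal=-1)
filled = mask.masked_fill(context_mask, torch.finfo(torch.float32).min)  # new tensor
mask.masked_fill_(context_mask, torch.finfo(torch.float32).min)          # mutates mask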
@@ -453,10 +459,18 @@ class patched_GenerationMixin:
     """

     _PATCHES_ = [
-        "_cache_dependant_input_preparation",
-        "_cache_dependant_input_preparation_exporting",
-        "prepare_inputs_for_generation",
-        "_sample",
+        name
+        for name in [
+            "_cache_dependant_input_preparation",
+            "_cache_dependant_input_preparation_exporting",
+            (
+                None
+                if pv.Version(transformers.__version__) >= pv.Version("4.56")
+                else "prepare_inputs_for_generation"
+            ),
+            "_sample",
+        ]
+        if name is not None
     ]
     _PATCHED_CLASS_ = transformers.generation.utils.GenerationMixin

@@ -732,13 +746,13 @@ def prepare_inputs_for_generation(
     def _sample(
         self,
         input_ids: torch.LongTensor,
-        logits_processor: LogitsProcessorList,
-        stopping_criteria: StoppingCriteriaList,
-        generation_config: GenerationConfig,
+        logits_processor: "LogitsProcessorList",  # noqa: F821
+        stopping_criteria: "StoppingCriteriaList",  # noqa: F821
+        generation_config: "GenerationConfig",  # noqa: F821
         synced_gpus: bool = False,
-        streamer: Optional["BaseStreamer"] = None,
+        streamer: Optional["BaseStreamer"] = None,  # noqa: F821
         **model_kwargs,
-    ) -> Union[GenerateNonBeamOutput, torch.LongTensor]:
+    ) -> Union["GenerateNonBeamOutput", torch.LongTensor]:  # noqa: F821
         """
         2025/09/29: updates for Gemma3 models, fix for eager mode as well as the export.
         """
@@ -749,28 +763,42 @@ def _sample(
         output_scores = generation_config.output_scores
         output_logits = generation_config.output_logits
         return_dict_in_generate = generation_config.return_dict_in_generate
-        has_eos_stopping_criteria = any(hasattr(criteria, "eos_token_id") for criteria in stopping_criteria)
+        has_eos_stopping_criteria = any(
+            hasattr(criteria, "eos_token_id") for criteria in stopping_criteria
+        )
         do_sample = generation_config.do_sample

         # init attention / hidden states / scores tuples
         scores = () if (return_dict_in_generate and output_scores) else None
         raw_logits = () if (return_dict_in_generate and output_logits) else None
         decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
         cross_attentions = () if (return_dict_in_generate and output_attentions) else None
-        decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
+        decoder_hidden_states = (
+            () if (return_dict_in_generate and output_hidden_states) else None
+        )

         # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
         if return_dict_in_generate and self.config.is_encoder_decoder:
-            encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
+            encoder_attentions = (
+                model_kwargs["encoder_outputs"].get("attentions")
+                if output_attentions
+                else None
+            )
             encoder_hidden_states = (
-                model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
+                model_kwargs["encoder_outputs"].get("hidden_states")
+                if output_hidden_states
+                else None
             )

         # keep track of which sequences are already finished
         batch_size, cur_len = input_ids.shape[:2]
         this_peer_finished = False
-        unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
-        model_kwargs = self._get_initial_cache_position(cur_len, input_ids.device, model_kwargs)
+        unfinished_sequences = torch.ones(
+            batch_size, dtype=torch.long, device=input_ids.device
+        )
+        model_kwargs = self._get_initial_cache_position(
+            cur_len, input_ids.device, model_kwargs
+        )

         model_forward = self.__call__
         compile_forward = self._valid_auto_compile_criteria(model_kwargs, generation_config)
@@ -779,11 +807,10 @@ def _sample(
             # If we use FA2 and a static cache, we cannot compile with fullgraph
             if self.config._attn_implementation == "flash_attention_2":
                 # only raise warning if the user passed an explicit compile-config
-                if generation_config.compile_config is not None and generation_config.compile_config.fullgraph:
-                    logger.warning_once(
-                        "When using Flash Attention 2 and a static cache, you cannot use the option `CompileConfig(fullgraph=True)` as "
-                        "FA2 introduces graph breaks. We overrode the option with `fullgraph=False`."
-                    )
+                if (
+                    generation_config.compile_config is not None
+                    and generation_config.compile_config.fullgraph
+                ):
                     generation_config.compile_config.fullgraph = False
             model_forward = self.get_compiled_call(generation_config.compile_config)

@@ -793,7 +820,9 @@ def _sample(
         else:
             is_prefill = True

-        while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
+        while self._has_unfinished_sequences(
+            this_peer_finished, synced_gpus, device=input_ids.device
+        ):
             # prepare model inputs
             model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

@@ -803,7 +832,8 @@ def _sample(
             else:
                 outputs = model_forward(**model_inputs, return_dict=True)

-            # synced_gpus: don't waste resources running the code we don't need; kwargs must be updated before skipping
+            # synced_gpus: don't waste resources running the code we don't need;
+            # kwargs must be updated before skipping
             model_kwargs = self._update_model_kwargs_for_generation(
                 outputs,
                 model_kwargs,
@@ -812,9 +842,12 @@ def _sample(
             if synced_gpus and this_peer_finished:
                 continue

-            # Copy is needed to avoid keeping a hanging ref to outputs.logits which may be very large for first iteration
+            # Copy is needed to avoid keeping a hanging ref to outputs.logits
+            # which may be very large for first iteration
             # (the clone itself is always small)
-            next_token_logits = outputs.logits[:, -1, :].to(copy=True, dtype=torch.float32, device=input_ids.device)
+            next_token_logits = outputs.logits[:, -1, :].to(
+                copy=True, dtype=torch.float32, device=input_ids.device
+            )

             # pre-process distribution
             next_token_scores = logits_processor(input_ids, next_token_logits)
@@ -827,7 +860,9 @@ def _sample(
                     raw_logits += (next_token_logits,)
                 if output_attentions:
                     decoder_attentions += (
-                        (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
+                        (outputs.decoder_attentions,)
+                        if self.config.is_encoder_decoder
+                        else (outputs.attentions,)
                     )
                     if self.config.is_encoder_decoder:
                         cross_attentions += (outputs.cross_attentions,)
@@ -841,35 +876,42 @@ def _sample(

             # token selection
             if do_sample:
-                probs = nn.functional.softmax(next_token_scores, dim=-1)
-                # TODO (joao): this OP throws "skipping cudagraphs due to ['incompatible ops']", find solution
+                probs = torch.nn.functional.softmax(next_token_scores, dim=-1)
                 next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
             else:
                 next_tokens = torch.argmax(next_token_scores, dim=-1)

             # finished sentences should have their next token be a padding token
             if has_eos_stopping_criteria:
-                next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
+                next_tokens = next_tokens * unfinished_sequences + pad_token_id * (
+                    1 - unfinished_sequences
+                )

             # update generated ids, model inputs, and length for next step
-            input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
+            # PATCHED: the two following lines, next_tokens can already be 2D for this model
+            next_tokens_2d = (
+                next_tokens if len(next_tokens.shape) == 2 else next_tokens[:, None]
+            )
+            input_ids = torch.cat([input_ids, next_tokens_2d], dim=-1)
             if streamer is not None:
                 streamer.put(next_tokens.cpu())

             unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
             this_peer_finished = unfinished_sequences.max() == 0
             cur_len += 1

-            # This is needed to properly delete outputs.logits which may be very large for first iteration
-            # Otherwise a reference to outputs is kept which keeps the logits alive in the next iteration
+            # This is needed to properly delete outputs.logits which may be very large
+            # for first iteration
+            # Otherwise a reference to outputs is kept which keeps
+            # the logits alive in the next iteration
             del outputs

         if streamer is not None:
             streamer.end()

         if return_dict_in_generate:
             if self.config.is_encoder_decoder:
-                return GenerateEncoderDecoderOutput(
+                return transformers.generation.utils.GenerateEncoderDecoderOutput(
                     sequences=input_ids,
                     scores=scores,
                     logits=raw_logits,
@@ -881,7 +923,7 @@ def _sample(
                     past_key_values=model_kwargs.get("past_key_values"),
                 )
             else:
-                return GenerateDecoderOnlyOutput(
+                return transformers.generation.utils.GenerateDecoderOnlyOutput(
                     sequences=input_ids,
                     scores=scores,
                     logits=raw_logits,
@@ -955,6 +997,7 @@ def patched__compute_dynamic_ntk_parameters(
     if seq_len is None:
         seq_len = max_position_embeddings
     else:
+        # PATCHED: remove the line using max
        torch._check(isinstance(seq_len, torch.Tensor))
        seq_len = torch.maximum(
            seq_len,
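Since `seq_len` is checked to be a tensor here, `torch.maximum` keeps the comparison inside the traced graph, whereas Python's built-in `max` would require a concrete value. A minimal sketch with assumed values:

import torch

seq_len = torch.tensor(4096)
max_position_embeddings = 2048
seq_len = torch.maximum(
    seq_len,
    torch.tensor(max_position_embeddings, dtype=seq_len.dtype),
)
# seq_len -> tensor(4096)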
@@ -1060,6 +1103,7 @@ def longrope_frequency_update(self, position_ids, device):
         )
         original_inv_freq = self.original_inv_freq.to(device)

+        # PATCHED: uses torch.cond instead of a test
         cond = (seq_len > original_max_position_embeddings).item()
         inv_freq = torch.cond(
             cond,
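The `torch.cond` pattern used here takes two traceable branches over the same operands, so the branch can stay symbolic under export instead of being fixed at trace time. An illustrative sketch with made-up operands, not the patched code:

import torch

def keep_first(x, y):
    return x.clone()

def keep_second(x, y):
    return y.clone()

a, b = torch.ones(3), torch.zeros(3)
out = torch.cond(torch.tensor(True), keep_first, keep_second, (a, b))
# out -> tensor([1., 1., 1.])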
@@ -1131,6 +1175,7 @@ def dynamic_frequency_update(self, position_ids, device):

         original_inv_freq = self.original_inv_freq.to(device)
         cond = (seq_len >= self.original_max_seq_len).item()
+        # PATCHED: uses torch.cond instead of a test
         inv_freq = torch.cond(
             cond,
             (lambda x, y: x.clone()),
@@ -1166,6 +1211,7 @@ def common_eager_attention_forward(

     attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
     if attention_mask is not None:
+        # PATCHED
         # The two following lines were added.
         if attention_mask is not None and attention_mask.ndim == 4:
             attention_mask = attention_mask[:, :, :, : key.shape[-2]]
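The two added lines simply truncate a 4D additive mask to the current key length before it is combined with the attention scores. A sketch with assumed shapes:

import torch

attention_mask = torch.zeros(1, 1, 6, 10)  # (batch, 1, q_len, padded_kv_len)
key = torch.randn(1, 8, 7, 64)             # (batch, heads, kv_len, head_dim)
if attention_mask.ndim == 4:
    attention_mask = attention_mask[:, :, :, : key.shape[-2]]
# attention_mask.shape -> (1, 1, 6, 7)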
@@ -1238,6 +1284,7 @@ def patched_modeling_marian_eager_attention_forward(
 class common_RotaryEmbedding(torch.nn.Module):
     # This may cause some issues.
     # @torch.no_grad()
+    # PATCHED: the decorator
     @patched_dynamic_rope_update
     def forward(self, x, position_ids):
         inv_freq_expanded = (