@@ -456,6 +456,7 @@ class patched_GenerationMixin:
         "_cache_dependant_input_preparation",
         "_cache_dependant_input_preparation_exporting",
         "prepare_inputs_for_generation",
+        "_sample",
     ]
     _PATCHED_CLASS_ = transformers.generation.utils.GenerationMixin
 
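How the patch takes effect: a patch class of this shape lists the method names to override and the class to override them on (`_PATCHED_CLASS_`, here `transformers.generation.utils.GenerationMixin`); adding `"_sample"` to the list is what makes the reimplementation in the next hunk visible to `generate()`. The sketch below only illustrates that idea and is not this repository's actual patching code; the list attribute name `_PATCHES_` and the `apply_patches` helper are hypothetical, since the hunk above does not show the list's own name.

    # Hypothetical illustration only: the attribute name "_PATCHES_" and this
    # "apply_patches" helper are assumptions, not the repository's real machinery.
    def apply_patches(patch_cls):
        """Copy every listed method from patch_cls onto patch_cls._PATCHED_CLASS_.

        Returns the original attributes so the caller can restore them afterwards.
        """
        target = patch_cls._PATCHED_CLASS_  # e.g. transformers.generation.utils.GenerationMixin
        originals = {}
        for name in patch_cls._PATCHES_:  # e.g. [..., "prepare_inputs_for_generation", "_sample"]
            originals[name] = getattr(target, name, None)
            setattr(target, name, getattr(patch_cls, name))
        return originals
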
@@ -728,6 +729,169 @@ def prepare_inputs_for_generation(
         model_inputs.pop("labels", None)
         return model_inputs
 
+    def _sample(
+        self,
+        input_ids: torch.LongTensor,
+        logits_processor: LogitsProcessorList,
+        stopping_criteria: StoppingCriteriaList,
+        generation_config: GenerationConfig,
+        synced_gpus: bool = False,
+        streamer: Optional["BaseStreamer"] = None,
+        **model_kwargs,
+    ) -> Union[GenerateNonBeamOutput, torch.LongTensor]:
742+ """
743+ 2025/09/29: updates for Gemma3 models, fix for eager mode as well as the export.
744+ """
+        # init values
+        pad_token_id = generation_config._pad_token_tensor
+        output_attentions = generation_config.output_attentions
+        output_hidden_states = generation_config.output_hidden_states
+        output_scores = generation_config.output_scores
+        output_logits = generation_config.output_logits
+        return_dict_in_generate = generation_config.return_dict_in_generate
+        has_eos_stopping_criteria = any(hasattr(criteria, "eos_token_id") for criteria in stopping_criteria)
+        do_sample = generation_config.do_sample
+
+        # init attention / hidden states / scores tuples
+        scores = () if (return_dict_in_generate and output_scores) else None
+        raw_logits = () if (return_dict_in_generate and output_logits) else None
+        decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
+        cross_attentions = () if (return_dict_in_generate and output_attentions) else None
+        decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
+
+        # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
+        if return_dict_in_generate and self.config.is_encoder_decoder:
+            encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
+            encoder_hidden_states = (
+                model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
+            )
+
+        # keep track of which sequences are already finished
+        batch_size, cur_len = input_ids.shape[:2]
+        this_peer_finished = False
+        unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
+        model_kwargs = self._get_initial_cache_position(cur_len, input_ids.device, model_kwargs)
+
+        model_forward = self.__call__
+        compile_forward = self._valid_auto_compile_criteria(model_kwargs, generation_config)
+        if compile_forward:
+            os.environ["TOKENIZERS_PARALLELISM"] = "0"
+            # If we use FA2 and a static cache, we cannot compile with fullgraph
+            if self.config._attn_implementation == "flash_attention_2":
+                # only raise warning if the user passed an explicit compile-config
+                if generation_config.compile_config is not None and generation_config.compile_config.fullgraph:
+                    logger.warning_once(
+                        "When using Flash Attention 2 and a static cache, you cannot use the option `CompileConfig(fullgraph=True)` as "
+                        "FA2 introduces graph breaks. We overrode the option with `fullgraph=False`."
+                    )
+                    generation_config.compile_config.fullgraph = False
+            model_forward = self.get_compiled_call(generation_config.compile_config)
+
+        if generation_config.prefill_chunk_size is not None:
+            model_kwargs = self._prefill_chunking(input_ids, generation_config, **model_kwargs)
+            is_prefill = False
+        else:
+            is_prefill = True
+
+        while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
+            # prepare model inputs
+            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
+
+            if is_prefill:
+                outputs = self(**model_inputs, return_dict=True)
+                is_prefill = False
+            else:
+                outputs = model_forward(**model_inputs, return_dict=True)
+
+            # synced_gpus: don't waste resources running the code we don't need; kwargs must be updated before skipping
+            model_kwargs = self._update_model_kwargs_for_generation(
+                outputs,
+                model_kwargs,
+                is_encoder_decoder=self.config.is_encoder_decoder,
+            )
+            if synced_gpus and this_peer_finished:
+                continue
+
+            # Copy is needed to avoid keeping a hanging ref to outputs.logits which may be very large for first iteration
+            # (the clone itself is always small)
+            next_token_logits = outputs.logits[:, -1, :].to(copy=True, dtype=torch.float32, device=input_ids.device)
+
+            # pre-process distribution
+            next_token_scores = logits_processor(input_ids, next_token_logits)
+
+            # Store scores, attentions and hidden_states when required
+            if return_dict_in_generate:
+                if output_scores:
+                    scores += (next_token_scores,)
+                if output_logits:
+                    raw_logits += (next_token_logits,)
+                if output_attentions:
+                    decoder_attentions += (
+                        (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
+                    )
+                    if self.config.is_encoder_decoder:
+                        cross_attentions += (outputs.cross_attentions,)
+
+                if output_hidden_states:
+                    decoder_hidden_states += (
+                        (outputs.decoder_hidden_states,)
+                        if self.config.is_encoder_decoder
+                        else (outputs.hidden_states,)
+                    )
+
+            # token selection
+            if do_sample:
+                probs = nn.functional.softmax(next_token_scores, dim=-1)
+                # TODO (joao): this OP throws "skipping cudagraphs due to ['incompatible ops']", find solution
+                next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
+            else:
+                next_tokens = torch.argmax(next_token_scores, dim=-1)
+
+            # finished sentences should have their next token be a padding token
+            if has_eos_stopping_criteria:
+                next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
+
+            # update generated ids, model inputs, and length for next step
+            input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
+            if streamer is not None:
+                streamer.put(next_tokens.cpu())
+
+            unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
+            this_peer_finished = unfinished_sequences.max() == 0
+            cur_len += 1
+
+            # This is needed to properly delete outputs.logits which may be very large for first iteration
+            # Otherwise a reference to outputs is kept which keeps the logits alive in the next iteration
+            del outputs
+
+        if streamer is not None:
+            streamer.end()
+
+        if return_dict_in_generate:
+            if self.config.is_encoder_decoder:
+                return GenerateEncoderDecoderOutput(
+                    sequences=input_ids,
+                    scores=scores,
+                    logits=raw_logits,
+                    encoder_attentions=encoder_attentions,
+                    encoder_hidden_states=encoder_hidden_states,
+                    decoder_attentions=decoder_attentions,
+                    cross_attentions=cross_attentions,
+                    decoder_hidden_states=decoder_hidden_states,
+                    past_key_values=model_kwargs.get("past_key_values"),
+                )
+            else:
+                return GenerateDecoderOnlyOutput(
+                    sequences=input_ids,
+                    scores=scores,
+                    logits=raw_logits,
+                    attentions=decoder_attentions,
+                    hidden_states=decoder_hidden_states,
+                    past_key_values=model_kwargs.get("past_key_values"),
+                )
+        else:
+            return input_ids
+
 
 def patched__compute_dynamic_ntk_parameters(
     config: Optional[transformers.PretrainedConfig] = None,
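For readers who want to see the token-selection step of `_sample` in isolation, the snippet below is a minimal, self-contained sketch: the tensors are made up and the variable names simply mirror those used in the loop above; it is not the patched method itself.

    import torch
    from torch import nn

    # Toy scores for a batch of 2 sequences over a vocabulary of 5 tokens.
    next_token_scores = torch.tensor([[0.1, 2.0, 0.3, 0.0, -1.0],
                                      [1.5, 0.2, 0.2, 0.1, 0.0]])
    unfinished_sequences = torch.tensor([1, 0])  # the second sequence already hit EOS
    pad_token_id = torch.tensor(0)
    do_sample = False

    if do_sample:
        # sampling: draw one token per row from the softmax distribution
        probs = nn.functional.softmax(next_token_scores, dim=-1)
        next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
    else:
        # greedy: take the highest-scoring token per row
        next_tokens = torch.argmax(next_token_scores, dim=-1)

    # finished rows keep emitting the padding token, so the batch shape stays fixed
    next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
    print(next_tokens)  # tensor([1, 0]): row 0 picks its argmax, row 1 is forced to pad

Masking finished rows with the pad token, rather than shrinking the batch, keeps the per-step tensor shapes constant, which is the property the compiled forward / static-cache path in `_sample` relies on.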