
Commit db651ea

Commit message: patch
1 parent 0113c67 commit db651ea

1 file changed: +164 −0 lines changed

onnx_diagnostic/torch_export_patches/patches/patch_transformers.py

Lines changed: 164 additions & 0 deletions
@@ -456,6 +456,7 @@ class patched_GenerationMixin:
         "_cache_dependant_input_preparation",
         "_cache_dependant_input_preparation_exporting",
         "prepare_inputs_for_generation",
+        "_sample",
     ]
     _PATCHED_CLASS_ = transformers.generation.utils.GenerationMixin

@@ -728,6 +729,169 @@ def prepare_inputs_for_generation(
         model_inputs.pop("labels", None)
         return model_inputs

+    def _sample(
+        self,
+        input_ids: torch.LongTensor,
+        logits_processor: LogitsProcessorList,
+        stopping_criteria: StoppingCriteriaList,
+        generation_config: GenerationConfig,
+        synced_gpus: bool = False,
+        streamer: Optional["BaseStreamer"] = None,
+        **model_kwargs,
+    ) -> Union[GenerateNonBeamOutput, torch.LongTensor]:
+        """
+        2025/09/29: updates for Gemma3 models, fix for eager mode as well as the export.
+        """
+        # init values
+        pad_token_id = generation_config._pad_token_tensor
+        output_attentions = generation_config.output_attentions
+        output_hidden_states = generation_config.output_hidden_states
+        output_scores = generation_config.output_scores
+        output_logits = generation_config.output_logits
+        return_dict_in_generate = generation_config.return_dict_in_generate
+        has_eos_stopping_criteria = any(hasattr(criteria, "eos_token_id") for criteria in stopping_criteria)
+        do_sample = generation_config.do_sample
+
+        # init attention / hidden states / scores tuples
+        scores = () if (return_dict_in_generate and output_scores) else None
+        raw_logits = () if (return_dict_in_generate and output_logits) else None
+        decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
+        cross_attentions = () if (return_dict_in_generate and output_attentions) else None
+        decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
+
+        # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
+        if return_dict_in_generate and self.config.is_encoder_decoder:
+            encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
+            encoder_hidden_states = (
+                model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
+            )
+
+        # keep track of which sequences are already finished
+        batch_size, cur_len = input_ids.shape[:2]
+        this_peer_finished = False
+        unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
+        model_kwargs = self._get_initial_cache_position(cur_len, input_ids.device, model_kwargs)
+
+        model_forward = self.__call__
+        compile_forward = self._valid_auto_compile_criteria(model_kwargs, generation_config)
+        if compile_forward:
+            os.environ["TOKENIZERS_PARALLELISM"] = "0"
+            # If we use FA2 and a static cache, we cannot compile with fullgraph
+            if self.config._attn_implementation == "flash_attention_2":
+                # only raise warning if the user passed an explicit compile-config
+                if generation_config.compile_config is not None and generation_config.compile_config.fullgraph:
+                    logger.warning_once(
+                        "When using Flash Attention 2 and a static cache, you cannot use the option `CompileConfig(fullgraph=True)` as "
+                        "FA2 introduces graph breaks. We overrode the option with `fullgraph=False`."
+                    )
+                    generation_config.compile_config.fullgraph = False
+            model_forward = self.get_compiled_call(generation_config.compile_config)
+
+        if generation_config.prefill_chunk_size is not None:
+            model_kwargs = self._prefill_chunking(input_ids, generation_config, **model_kwargs)
+            is_prefill = False
+        else:
+            is_prefill = True
+
+        while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
+            # prepare model inputs
+            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
+
+            if is_prefill:
+                outputs = self(**model_inputs, return_dict=True)
+                is_prefill = False
+            else:
+                outputs = model_forward(**model_inputs, return_dict=True)
+
+            # synced_gpus: don't waste resources running the code we don't need; kwargs must be updated before skipping
+            model_kwargs = self._update_model_kwargs_for_generation(
+                outputs,
+                model_kwargs,
+                is_encoder_decoder=self.config.is_encoder_decoder,
+            )
+            if synced_gpus and this_peer_finished:
+                continue
+
+            # Copy is needed to avoid keeping a hanging ref to outputs.logits which may be very large for first iteration
+            # (the clone itself is always small)
+            next_token_logits = outputs.logits[:, -1, :].to(copy=True, dtype=torch.float32, device=input_ids.device)
+
+            # pre-process distribution
+            next_token_scores = logits_processor(input_ids, next_token_logits)
+
+            # Store scores, attentions and hidden_states when required
+            if return_dict_in_generate:
+                if output_scores:
+                    scores += (next_token_scores,)
+                if output_logits:
+                    raw_logits += (next_token_logits,)
+                if output_attentions:
+                    decoder_attentions += (
+                        (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
+                    )
+                    if self.config.is_encoder_decoder:
+                        cross_attentions += (outputs.cross_attentions,)
+
+                if output_hidden_states:
+                    decoder_hidden_states += (
+                        (outputs.decoder_hidden_states,)
+                        if self.config.is_encoder_decoder
+                        else (outputs.hidden_states,)
+                    )
+
+            # token selection
+            if do_sample:
+                probs = nn.functional.softmax(next_token_scores, dim=-1)
+                # TODO (joao): this OP throws "skipping cudagraphs due to ['incompatible ops']", find solution
+                next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
+            else:
+                next_tokens = torch.argmax(next_token_scores, dim=-1)
+
+            # finished sentences should have their next token be a padding token
+            if has_eos_stopping_criteria:
+                next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
+
+            # update generated ids, model inputs, and length for next step
+            input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
+            if streamer is not None:
+                streamer.put(next_tokens.cpu())
+
+            unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
+            this_peer_finished = unfinished_sequences.max() == 0
+            cur_len += 1
+
+            # This is needed to properly delete outputs.logits which may be very large for first iteration
+            # Otherwise a reference to outputs is kept which keeps the logits alive in the next iteration
+            del outputs
+
+        if streamer is not None:
+            streamer.end()
+
+        if return_dict_in_generate:
+            if self.config.is_encoder_decoder:
+                return GenerateEncoderDecoderOutput(
+                    sequences=input_ids,
+                    scores=scores,
+                    logits=raw_logits,
+                    encoder_attentions=encoder_attentions,
+                    encoder_hidden_states=encoder_hidden_states,
+                    decoder_attentions=decoder_attentions,
+                    cross_attentions=cross_attentions,
+                    decoder_hidden_states=decoder_hidden_states,
+                    past_key_values=model_kwargs.get("past_key_values"),
+                )
+            else:
+                return GenerateDecoderOnlyOutput(
+                    sequences=input_ids,
+                    scores=scores,
+                    logits=raw_logits,
+                    attentions=decoder_attentions,
+                    hidden_states=decoder_hidden_states,
+                    past_key_values=model_kwargs.get("past_key_values"),
+                )
+        else:
+            return input_ids
+

 def patched__compute_dynamic_ntk_parameters(
     config: Optional[transformers.PretrainedConfig] = None,
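
Note on how the patch is consumed: the first hunk adds "_sample" to the list of method names that the patched class overrides, next to _PATCHED_CLASS_ = transformers.generation.utils.GenerationMixin. A minimal sketch of that pattern is shown below; the apply_generation_patch helper and the usage comments are illustrative assumptions, not the package's actual patching machinery.

import contextlib
import transformers


@contextlib.contextmanager
def apply_generation_patch(patched_cls, method_names):
    """Hypothetical helper: temporarily replace selected GenerationMixin methods
    with their patched counterparts, then restore the originals on exit.
    This only illustrates the _PATCHED_CLASS_ / method-list idea visible in the
    first hunk; onnx_diagnostic's own apply/restore logic may differ."""
    target = transformers.generation.utils.GenerationMixin
    # remember the original attributes so they can be restored
    originals = {name: getattr(target, name) for name in method_names}
    try:
        for name in method_names:
            setattr(target, name, getattr(patched_cls, name))
        yield target
    finally:
        for name, fn in originals.items():
            setattr(target, name, fn)


# Usage sketch (assumed): while the context is active, model.generate() routes
# through the patched _sample above, for eager runs as well as torch.export.
# with apply_generation_patch(patched_GenerationMixin, ["_sample"]):
#     model.generate(**inputs)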
