@@ -1,13 +1,22 @@
 import inspect
 import math
+import os
 from dataclasses import dataclass
 from functools import wraps
-from typing import Callable, List, Optional, Tuple
+from typing import Callable, List, Optional, Tuple, Union
 import packaging.version as pv
 import torch
 import transformers
 from transformers.modeling_attn_mask_utils import AttentionMaskConverter
 from transformers.cache_utils import StaticCache, Cache
+from transformers.generation.utils import (
+    GenerateDecoderOnlyOutput,
+    GenerateEncoderDecoderOutput,
+    GenerateNonBeamOutput,
+    GenerationConfig,
+    StoppingCriteriaList,
+    LogitsProcessorList,
+)
 
 try:
     from transformers.cache_utils import parse_processor_args  # noqa: F401
@@ -456,6 +465,11 @@ class patched_GenerationMixin:
         "_cache_dependant_input_preparation",
         "_cache_dependant_input_preparation_exporting",
         "prepare_inputs_for_generation",
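+        # ``_sample`` is only replaced for this specific development release of
+        # transformers; for any other version the entry below evaluates to ``None``
+        # and the stock sampling loop is presumably left untouched.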
+        (
+            "_sample"
+            if pv.Version(transformers.__version__) == pv.Version("4.57.0.dev0")
+            else None
+        ),
     ]
     _PATCHED_CLASS_ = transformers.generation.utils.GenerationMixin
 
@@ -588,7 +602,7 @@ def prepare_inputs_for_generation(
         model_inputs = {}
         # - some models don't have `Cache` support
         #   (which implies they don't expect `cache_position` in `forward`)
-        if self._supports_cache_class:
+        if getattr(self, "_supports_cache_class", False):
             model_inputs["cache_position"] = cache_position
         # - `cache_position` was not a mandatory input in
         #   `prepare_inputs_for_generation` for those models, and this
@@ -728,6 +742,174 @@ def prepare_inputs_for_generation(
         model_inputs.pop("labels", None)
         return model_inputs
 
+    def _sample(
+        self,
+        input_ids: torch.LongTensor,
+        logits_processor: LogitsProcessorList,
+        stopping_criteria: StoppingCriteriaList,
+        generation_config: GenerationConfig,
+        synced_gpus: bool = False,
+        streamer: Optional["BaseStreamer"] = None,  # noqa: F821
+        **model_kwargs,
+    ) -> Union[GenerateNonBeamOutput, torch.LongTensor]:
+        # init values
+        pad_token_id = generation_config._pad_token_tensor
+        output_attentions = generation_config.output_attentions
+        output_hidden_states = generation_config.output_hidden_states
+        output_scores = generation_config.output_scores
+        output_logits = generation_config.output_logits
+        return_dict_in_generate = generation_config.return_dict_in_generate
+        has_eos_stopping_criteria = any(
+            hasattr(criteria, "eos_token_id") for criteria in stopping_criteria
+        )
+        do_sample = generation_config.do_sample
+
+        # init attention / hidden states / scores tuples
+        scores = () if (return_dict_in_generate and output_scores) else None
+        raw_logits = () if (return_dict_in_generate and output_logits) else None
+        decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
+        cross_attentions = () if (return_dict_in_generate and output_attentions) else None
+        decoder_hidden_states = (
+            () if (return_dict_in_generate and output_hidden_states) else None
+        )
+
+        # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
+        if return_dict_in_generate and self.config.is_encoder_decoder:
+            encoder_attentions = (
+                model_kwargs["encoder_outputs"].get("attentions")
+                if output_attentions
+                else None
+            )
+            encoder_hidden_states = (
+                model_kwargs["encoder_outputs"].get("hidden_states")
+                if output_hidden_states
+                else None
+            )
+
+        # keep track of which sequences are already finished
+        batch_size, cur_len = input_ids.shape[:2]
+        this_peer_finished = False
+        unfinished_sequences = torch.ones(
+            batch_size, dtype=torch.long, device=input_ids.device
+        )
+        model_kwargs = self._get_initial_cache_position(
+            cur_len, input_ids.device, model_kwargs
+        )
+
+        model_forward = self.__call__
+        compile_forward = self._valid_auto_compile_criteria(model_kwargs, generation_config)
+        if compile_forward:
+            os.environ["TOKENIZERS_PARALLELISM"] = "0"
+            # If we use FA2 and a static cache, we cannot compile with fullgraph
+            model_forward = self.get_compiled_call(generation_config.compile_config)
+
+        if generation_config.prefill_chunk_size is not None:
+            model_kwargs = self._prefill_chunking(input_ids, generation_config, **model_kwargs)
+            is_prefill = False
+        else:
+            is_prefill = True
+
+        while self._has_unfinished_sequences(
+            this_peer_finished, synced_gpus, device=input_ids.device
+        ):
+            # prepare model inputs
+            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
+
+            if is_prefill:
+                outputs = self(**model_inputs, return_dict=True)
+                is_prefill = False
+            else:
+                outputs = model_forward(**model_inputs, return_dict=True)
+
+            model_kwargs = self._update_model_kwargs_for_generation(
+                outputs,
+                model_kwargs,
+                is_encoder_decoder=self.config.is_encoder_decoder,
+            )
+            if synced_gpus and this_peer_finished:
+                continue
+
+            next_token_logits = outputs.logits[:, -1, :].to(
+                copy=True, dtype=torch.float32, device=input_ids.device
+            )
+
+            # pre-process distribution
+            next_token_scores = logits_processor(input_ids, next_token_logits)
+
+            # Store scores, attentions and hidden_states when required
+            if return_dict_in_generate:
+                if output_scores:
+                    scores += (next_token_scores,)
+                if output_logits:
+                    raw_logits += (next_token_logits,)
+                if output_attentions:
+                    decoder_attentions += (
+                        (outputs.decoder_attentions,)
+                        if self.config.is_encoder_decoder
+                        else (outputs.attentions,)
+                    )
+                    if self.config.is_encoder_decoder:
+                        cross_attentions += (outputs.cross_attentions,)
+
+                if output_hidden_states:
+                    decoder_hidden_states += (
+                        (outputs.decoder_hidden_states,)
+                        if self.config.is_encoder_decoder
+                        else (outputs.hidden_states,)
+                    )
+
+            # token selection
+            if do_sample:
+                probs = torch.nn.functional.softmax(next_token_scores, dim=-1)
+                next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
+            else:
+                next_tokens = torch.argmax(next_token_scores, dim=-1)
+
+            # finished sentences should have their next token be a padding token
+            if has_eos_stopping_criteria:
+                next_tokens = next_tokens * unfinished_sequences + pad_token_id * (
+                    1 - unfinished_sequences
+                )
+
+            # update generated ids, model inputs, and length for next step
+            # PATCHED: dimension issues when calling generate method
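+            # (for reference: the stock ``GenerationMixin._sample`` concatenates
+            # ``next_tokens[:, None]`` at this point; this copy keeps ``next_tokens``
+            # as produced, which is the dimension-related change flagged above)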
+            input_ids = torch.cat([input_ids, next_tokens], dim=-1)
+            if streamer is not None:
+                streamer.put(next_tokens.cpu())
+
+            unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
+            this_peer_finished = unfinished_sequences.max() == 0
+            cur_len += 1
+            del outputs
+
+        if streamer is not None:
+            streamer.end()
+
+        if return_dict_in_generate:
+            if self.config.is_encoder_decoder:
+                return GenerateEncoderDecoderOutput(
+                    sequences=input_ids,
+                    scores=scores,
+                    logits=raw_logits,
+                    encoder_attentions=encoder_attentions,
+                    encoder_hidden_states=encoder_hidden_states,
+                    decoder_attentions=decoder_attentions,
+                    cross_attentions=cross_attentions,
+                    decoder_hidden_states=decoder_hidden_states,
+                    past_key_values=model_kwargs.get("past_key_values"),
+                )
+            else:
+                return GenerateDecoderOnlyOutput(
+                    sequences=input_ids,
+                    scores=scores,
+                    logits=raw_logits,
+                    attentions=decoder_attentions,
+                    hidden_states=decoder_hidden_states,
+                    past_key_values=model_kwargs.get("past_key_values"),
+                )
+        else:
+            return input_ids
+
 
 def patched__compute_dynamic_ntk_parameters(
     config: Optional[transformers.PretrainedConfig] = None,