|
30 | 30 |     DynamicCache, |
31 | 31 |     EncoderDecoderCache, |
32 | 32 |     OffloadedCache, |
33 | | -    QuantizedCacheConfig, |
| 33 | +    QuantizedCache, |
34 | 34 |     StaticCache, |
| 35 | +    SlidingWindowCache, |
| 36 | +    SinkCache, |
| 37 | +    HybridCache, |
| 38 | +    HybridChunkedCache, |
35 | 39 | ) |
36 | 40 | from transformers.configuration_utils import PretrainedConfig |
37 | 41 | from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled |
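Not every one of these cache classes ships in every transformers release, so code that has to span several versions may prefer guarded imports over a hard dependency on the full list. A minimal sketch, not part of the patch, assuming the surrounding code checks for `None` before using a class:

```python
# Guarded imports for cache classes whose availability varies by transformers version.
try:
    from transformers import HybridChunkedCache
except ImportError:  # releases that predate chunked hybrid caches
    HybridChunkedCache = None

try:
    from transformers import SinkCache
except ImportError:  # releases where SinkCache is not exported
    SinkCache = None
```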
|
55 | 59 |     AssistedCandidateGeneratorDifferentTokenizers, |
56 | 60 |     CandidateGenerator, |
57 | 61 |     PromptLookupCandidateGenerator, |
58 | | -    _crop_past_key_values, |
59 | 62 |     _prepare_attention_mask, |
60 | 63 |     _prepare_token_type_ids, |
61 | 64 | ) |
62 | 65 | from transformers.generation.configuration_utils import ( |
63 | | -    NEED_SETUP_CACHE_CLASSES_MAPPING, |
64 | | -    QUANT_BACKEND_CLASSES_MAPPING, |
65 | 66 |     GenerationConfig, |
66 | 67 |     GenerationMode, |
67 | 68 | ) |
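Since `_crop_past_key_values` and the two cache mappings no longer exist upstream, an alternative to dropping the imports outright is to try them first and fall back to the local shims defined further down. A hedged sketch, not from the patch; the `_HAS_UPSTREAM_*` flags are hypothetical names:

```python
# Prefer the upstream symbols when the installed transformers still provides them.
try:
    from transformers.generation.candidate_generator import _crop_past_key_values  # noqa: F401
    _HAS_UPSTREAM_CROP = True
except ImportError:  # removed in newer transformers releases
    _HAS_UPSTREAM_CROP = False

try:
    from transformers.generation.configuration_utils import (  # noqa: F401
        NEED_SETUP_CACHE_CLASSES_MAPPING,
        QUANT_BACKEND_CLASSES_MAPPING,
    )
    _HAS_UPSTREAM_CACHE_MAPPINGS = True
except ImportError:  # removed in newer transformers releases
    _HAS_UPSTREAM_CACHE_MAPPINGS = False
```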
|
111 | 112 |
112 | 113 | logger = logging.get_logger(__name__) |
113 | 114 |
| 115 | +# Compatibility with transformers 4.57.1+ |
| 116 | +# These mappings stand in for constants removed from transformers.generation.configuration_utils |
| 117 | +NEED_SETUP_CACHE_CLASSES_MAPPING = { |
| 118 | +    "auto": Cache, |
| 119 | +    "dynamic": DynamicCache, |
| 120 | +    "static": StaticCache, |
| 121 | +    "offloaded": OffloadedCache, |
| 122 | +    "sliding_window": SlidingWindowCache, |
| 123 | +    "sink": SinkCache, |
| 124 | +    "hybrid": HybridCache, |
| 125 | +    "hybrid_chunked": HybridChunkedCache, |
| 126 | +} |
| 127 | + |
| 128 | +# Mapping for quantized cache backends |
| 129 | +QUANT_BACKEND_CLASSES_MAPPING = { |
| 130 | + "quanto": QuantizedCache, |
| 131 | + "hqq": QuantizedCache, |
| 132 | +} |
| 133 | + |
| 134 | +# Compatibility class for removed QuantizedCacheConfig |
| 135 | +class QuantizedCacheConfig: |
| 136 | +    def __init__(self, backend: str = "quanto", nbits: int = 4, |
| 137 | +                 axis_key: int = 0, axis_value: int = 0, |
| 138 | +                 q_group_size: int = 64, residual_length: int = 128): |
| 139 | +        self.backend = backend |
| 140 | +        self.nbits = nbits |
| 141 | +        self.axis_key = axis_key |
| 142 | +        self.axis_value = axis_value |
| 143 | +        self.q_group_size = q_group_size |
| 144 | +        self.residual_length = residual_length |
| 145 | + |
| 146 | +# Compatibility function for removed _crop_past_key_values |
| 147 | +def _crop_past_key_values(model, past_key_values, max_length): |
| 148 | + """ |
| 149 | + Crop past key values to a maximum length. |
| 150 | + This is a compatibility function for the removed _crop_past_key_values. |
| 151 | + """ |
| 152 | + if past_key_values is None: |
| 153 | + return past_key_values |
| 154 | + |
| 155 | + # If past_key_values is a Cache object |
| 156 | +    if hasattr(past_key_values, 'crop'): |
| 157 | +        past_key_values.crop(max_length)  # Cache.crop() trims in place and returns None |
| 158 | +        return past_key_values |
| 159 | +    # If it's a tuple of tensors (legacy format) |
| 160 | +    if isinstance(past_key_values, tuple): |
| 161 | +        cropped_past_key_values = [] |
| 162 | +        for layer_past_key_values in past_key_values: |
| 163 | +            if isinstance(layer_past_key_values, tuple) and len(layer_past_key_values) == 2: |
| 164 | +                # Standard format: (key, value) |
| 165 | +                key, value = layer_past_key_values |
| 166 | +                if key.shape[-2] > max_length: |
| 167 | +                    key = key[..., :max_length, :] |
| 168 | +                if value.shape[-2] > max_length: |
| 169 | +                    value = value[..., :max_length, :] |
| 170 | +                cropped_past_key_values.append((key, value)) |
| 171 | +            else: |
| 172 | +                # Other formats, just append as is |
| 173 | +                cropped_past_key_values.append(layer_past_key_values) |
| 174 | +        return tuple(cropped_past_key_values) |
| 175 | + |
| 176 | +    # For other cache types, return as is |
| 177 | +    return past_key_values |
| 178 | + |
114 | 179 | if is_accelerate_available(): |
115 | 180 |     from accelerate.hooks import AlignDevicesHook, add_hook_to_module |
116 | 181 |
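A short usage sketch of the shims above, not part of the patch: it assumes `torch` is installed, that the shim definitions are in scope, and that the installed `DynamicCache` still exposes a no-argument constructor, `update()`, `crop()`, and `get_seq_length()`.

```python
import torch
from transformers import DynamicCache

# Resolve cache classes by configuration name, as the removed constants allowed.
static_cls = NEED_SETUP_CACHE_CLASSES_MAPPING["static"]          # StaticCache
quant_config = QuantizedCacheConfig(backend="quanto", nbits=4)
quant_cls = QUANT_BACKEND_CLASSES_MAPPING[quant_config.backend]  # QuantizedCache

# Crop a Cache object back from 10 to 6 positions, as assisted decoding
# does after some draft tokens are rejected. The shim ignores `model`.
cache = DynamicCache()
key = torch.randn(1, 2, 10, 8)    # (batch, num_heads, seq_len, head_dim)
value = torch.randn(1, 2, 10, 8)
cache.update(key, value, layer_idx=0)
cache = _crop_past_key_values(model=None, past_key_values=cache, max_length=6)
assert cache.get_seq_length() == 6

# Legacy tuple caches are cropped tensor-by-tensor.
legacy = _crop_past_key_values(model=None, past_key_values=((key, value),), max_length=6)
assert legacy[0][0].shape[-2] == 6
```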
|
@@ -1002,7 +1067,8 @@ def _get_logits_processor( |
1002 | 1067 | device=device, |
1003 | 1068 | ) |
1004 | 1069 | ) |
1005 | | - if generation_config.forced_decoder_ids is not None: |
| 1070 | + # Compatibility with transformers 4.57.1+: forced_decoder_ids has been removed |
| 1071 | + if hasattr(generation_config, 'forced_decoder_ids') and generation_config.forced_decoder_ids is not None: |
1006 | 1072 | # TODO (sanchit): move this exception to GenerationConfig.validate() when TF & FLAX are aligned with PT |
1007 | 1073 | raise ValueError( |
1008 | 1074 | "You have explicitly specified `forced_decoder_ids`. Please remove the `forced_decoder_ids` argument " |