|
1 | | -from typing import Any, Callable, Dict, List, Optional, Tuple |
| 1 | +from typing import Any, Callable, Dict, List, Optional, Tuple, Union |
2 | 2 | import packaging.version as pv |
3 | 3 | import torch |
4 | 4 | import transformers |
@@ -152,10 +152,18 @@ def make_dynamic_shapes_kv_cache( |
152 | 152 | return [shape_of_one for _ in range(CacheKeyValue(cache).n_layers * 2)] |
153 | 153 |
|
154 | 154 |
|
| 155 | +def _preprocess_key_value_pairs( |
| 156 | + key_value_pairs: Union[List[torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]], |
| 157 | +) -> List[Tuple[torch.Tensor, torch.Tensor]]: |
| 158 | + if not key_value_pairs or isinstance(key_value_pairs[0], tuple): |
| 159 | + return key_value_pairs |
| 160 | + return list(zip(key_value_pairs[::2], key_value_pairs[1::2])) |
| 161 | + |
| 162 | + |
155 | 163 | if pv.Version(transformers.__version__) > pv.Version("4.49.99999"): |
156 | 164 |
|
157 | 165 | def make_dynamic_cache( |
158 | | - key_value_pairs: List[Tuple[torch.Tensor, torch.Tensor]], |
| 166 | + key_value_pairs: Union[List[torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]], |
159 | 167 | ) -> transformers.cache_utils.DynamicCache: |
160 | 168 | """ |
161 | 169 | Creates an instance of :class:`transformers.cache_utils.DynamicCache`. |
@@ -191,6 +199,7 @@ def make_dynamic_cache( |
191 | 199 | ``transformers>=4.56``. Before that version, only FakeTensor with static dimensions |
192 | 200 | are supported. |
193 | 201 | """ |
| 202 | + key_value_pairs = _preprocess_key_value_pairs(key_value_pairs) |
194 | 203 | if ( |
195 | 204 | key_value_pairs |
196 | 205 | and isinstance(key_value_pairs[0][0], torch._subclasses.fake_tensor.FakeTensor) |
@@ -230,7 +239,7 @@ def make_dynamic_cache( |
230 | 239 | else: |
231 | 240 |
|
232 | 241 | def make_dynamic_cache( |
233 | | - key_value_pairs: List[Tuple[torch.Tensor, torch.Tensor]], |
| 242 | + key_value_pairs: Union[List[torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]], |
234 | 243 | ) -> transformers.cache_utils.DynamicCache: |
235 | 244 | """ |
236 | 245 | Creates an instance of :class:`transformers.cache_utils.DynamicCache`. |
@@ -262,14 +271,15 @@ def make_dynamic_cache( |
262 | 271 | ) |
263 | 272 | print(string_type(past_key_values, with_shape=True)) |
264 | 273 | """ |
| 274 | + key_value_pairs = _preprocess_key_value_pairs(key_value_pairs) |
265 | 275 | cache = transformers.cache_utils.DynamicCache(len(key_value_pairs)) # type: ignore |
266 | 276 | for i, (key, value) in enumerate(key_value_pairs): |
267 | 277 | cache.update(key, value, i) |
268 | 278 | return cache |
269 | 279 |
|
270 | 280 |
|
271 | 281 | def make_static_cache( |
272 | | - key_value_pairs: List[Tuple[torch.Tensor, torch.Tensor]], |
| 282 | + key_value_pairs: Union[List[torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]], |
273 | 283 | max_cache_len: Optional[int] = None, |
274 | 284 | ) -> transformers.cache_utils.DynamicCache: |
275 | 285 | """ |
@@ -302,6 +312,7 @@ def make_static_cache( |
302 | 312 | ) |
303 | 313 | print(string_type(past_key_values, with_shape=True)) |
304 | 314 | """ |
| 315 | + key_value_pairs = _preprocess_key_value_pairs(key_value_pairs) |
305 | 316 |
|
306 | 317 | class _config: |
307 | 318 | def __init__(self): |
@@ -444,9 +455,10 @@ def get_text_config(self, *args, **kwargs): |
444 | 455 |
|
445 | 456 |
|
446 | 457 | def make_sliding_window_cache( |
447 | | - key_value_pairs: List[Tuple[torch.Tensor, torch.Tensor]], |
| 458 | + key_value_pairs: Union[List[torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]], |
448 | 459 | ) -> transformers.cache_utils.SlidingWindowCache: |
449 | 460 | "Creates a :class:`transformers.cache_utils.SlidingWindowCache`." |
| 461 | + key_value_pairs = _preprocess_key_value_pairs(key_value_pairs) |
450 | 462 |
|
451 | 463 | class _config: |
452 | 464 | def __init__(self): |
@@ -499,7 +511,7 @@ def get_text_config(self, *args, **kwargs): |
499 | 511 |
|
500 | 512 |
|
501 | 513 | def make_hybrid_cache( |
502 | | - key_value_pairs: List[Tuple[torch.Tensor, torch.Tensor]], |
| 514 | + key_value_pairs: Union[List[torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]], |
503 | 515 | max_cache_len: Optional[int] = None, |
504 | 516 | max_batch_size: Optional[int] = None, |
505 | 517 | sliding_window: Optional[int] = None, |
@@ -584,6 +596,7 @@ def make_hybrid_cache( |
584 | 596 | self.key_cache.append(new_layer_key_cache) |
585 | 597 | self.value_cache.append(new_layer_value_cache) |
586 | 598 | """ |
| 599 | + key_value_pairs = _preprocess_key_value_pairs(key_value_pairs) |
587 | 600 | layer_types = None |
588 | 601 | if key_value_pairs: |
589 | 602 | assert ( |
|
0 commit comments