|
10 | 10 | from transformers.cache_utils import MambaCache |
11 | 11 |
|
12 | 12 |
|
class CacheKeyValue:
    """Uniform accessor for the key/value tensors held by a transformers cache.

    Newer versions of ``transformers`` store per-layer tensors on
    ``cache.layers`` (as ``layer.keys`` / ``layer.values``), while older
    versions expose ``key_cache`` / ``value_cache`` lists directly on the
    cache object. This adapter presents the legacy ``key_cache`` /
    ``value_cache`` attributes regardless of which layout *cache* uses.
    """

    def __init__(self, cache: "Cache"):  # noqa: F821
        if not hasattr(cache, "layers"):
            # Legacy layout: the cache already carries the two lists; reuse
            # them directly (no copy).
            self.key_cache = cache.key_cache
            self.value_cache = cache.value_cache
        else:
            # New layout: collect each layer's tensors, skipping slots that
            # were never populated (still ``None``).
            self.key_cache = [lyr.keys for lyr in cache.layers if lyr.keys is not None]
            self.value_cache = [lyr.values for lyr in cache.layers if lyr.values is not None]
| 24 | + |
13 | 25 | def flatten_unflatten_for_dynamic_shapes( |
14 | 26 | obj: Any, |
15 | 27 | use_dict: bool = False, |
@@ -221,19 +233,20 @@ def __init__(self): |
221 | 233 | ), |
222 | 234 | ) |
223 | 235 | cache = transformers.cache_utils.StaticCache( |
224 | | - _config(), |
| 236 | + config=_config(), |
225 | 237 | max_batch_size=key_value_pairs[0][0].shape[0], |
226 | 238 | device=key_value_pairs[0][0].device, |
227 | 239 | dtype=key_value_pairs[0][0].dtype, |
228 | 240 | max_cache_len=max_cache_len, |
229 | 241 | ) |
| 242 | + ca = CacheKeyValue(cache) |
230 | 243 | for i in range(len(key_value_pairs)): |
231 | 244 | assert ( |
232 | 245 | key_value_pairs[i][0].shape == key_value_pairs[i][1].shape |
233 | 246 | ), f"Shape mismatch {key_value_pairs[i][0].shape} != {key_value_pairs[i][1].shape}" |
234 | 247 | d = key_value_pairs[i][1].shape[2] |
235 | | - cache.key_cache[i][:, :, :d, :] = key_value_pairs[i][0] |
236 | | - cache.value_cache[i][:, :, :d, :] = key_value_pairs[i][1] |
| 248 | + ca.key_cache[i][:, :, :d, :] = key_value_pairs[i][0] |
| 249 | + ca.value_cache[i][:, :, :d, :] = key_value_pairs[i][1] |
237 | 250 | return cache |
238 | 251 |
|
239 | 252 |
|
@@ -300,23 +313,24 @@ def __init__(self): |
300 | 313 | self.sliding_window = key_value_pairs[0][0].shape[2] |
301 | 314 |
|
302 | 315 | cache = transformers.cache_utils.SlidingWindowCache( |
303 | | - _config(), |
| 316 | + config=_config(), |
304 | 317 | max_batch_size=key_value_pairs[0][0].shape[0], |
305 | 318 | max_cache_len=key_value_pairs[0][0].shape[2], # same as sliding_window |
306 | 319 | device=key_value_pairs[0][0].device, |
307 | 320 | dtype=key_value_pairs[0][0].dtype, |
308 | 321 | ) |
| 322 | + ca = CacheKeyValue(cache) |
309 | 323 | for i in range(len(key_value_pairs)): |
310 | | - assert cache.key_cache[i].shape == key_value_pairs[i][0].shape, ( |
| 324 | + assert ca.key_cache[i].shape == key_value_pairs[i][0].shape, ( |
311 | 325 | f"Shape mismatch, expected {cache.key_cache[i].shape}, " |
312 | 326 | f"got {key_value_pairs[i][0].shape}" |
313 | 327 | ) |
314 | | - cache.key_cache[i][:, :, :, :] = key_value_pairs[i][0] |
315 | | - assert cache.value_cache[i].shape == key_value_pairs[i][1].shape, ( |
| 328 | + ca.key_cache[i][:, :, :, :] = key_value_pairs[i][0] |
| 329 | + assert ca.value_cache[i].shape == key_value_pairs[i][1].shape, ( |
316 | 330 | f"Shape mismatch, expected {cache.value_cache[i].shape}, " |
317 | 331 | f"got {key_value_pairs[i][1].shape}" |
318 | 332 | ) |
319 | | - cache.value_cache[i][:, :, :, :] = key_value_pairs[i][1] |
| 333 | + ca.value_cache[i][:, :, :, :] = key_value_pairs[i][1] |
320 | 334 | return cache |
321 | 335 |
|
322 | 336 |
|
@@ -373,9 +387,10 @@ class _config: |
373 | 387 | num_heads = key_value_pairs[0][0].shape[1] if key_value_pairs else None |
374 | 388 | head_dim = key_value_pairs[0][0].shape[-1] if key_value_pairs else None |
375 | 389 | num_attention_heads = key_value_pairs[0][1].shape[1] if key_value_pairs else None |
| 390 | + num_hidden_layers = len(key_value_pairs) |
376 | 391 |
|
377 | 392 | cache = transformers.cache_utils.HybridCache( |
378 | | - _config(), max_cache_len=max_cache_len, max_batch_size=max_batch_size |
| 393 | + config=_config(), max_cache_len=max_cache_len, max_batch_size=max_batch_size |
379 | 394 | ) |
380 | 395 | for i, (key, value) in enumerate(key_value_pairs): |
381 | 396 | cache.update(key, value, i) |
|
0 commit comments