
Commit 4d45b48

fixes

1 parent b687a15

File tree

3 files changed, +88 -6 lines

CHANGELOGS.rst

Lines changed: 2 additions & 1 deletion

@@ -4,7 +4,8 @@ Change Logs
 0.7.6
 +++++
 
-* :pr:`192`: add support for Gemma-3, add serialization for HybridCache
+* :pr:`192`: add support for Gemma-3, add serialization for HybridCache,
+  changes to support ``transformers>=4.54``
 
 0.7.5
 +++++

_unittests/ut_torch_models/test_tiny_llms_bypassed.py

Lines changed: 1 addition & 1 deletion

@@ -52,7 +52,7 @@ def debug():
             strict=False,
         )
         got = ep.module()(**inputs)
-        self.assertEqualArrayAny(expected, got)
+        self.assertEqualArrayAny(expected, got, atol=1e-5)
 
     @ignore_warnings(UserWarning)
     def test_export_phi2_2_bypassed(self):
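The added atol is the usual fix when the output of an exported module is compared against eager execution: small floating-point differences are expected, so exact equality is too strict. A minimal, self-contained sketch of the same pattern using plain torch.testing (the toy model below is only illustrative, not the LLM exercised by this test):

    import torch

    class Toy(torch.nn.Module):
        def forward(self, x):
            return torch.nn.functional.gelu(x @ x.t())

    model, x = Toy(), torch.randn(4, 4)
    expected = model(x)

    # Export the model and re-run the same inputs through the exported graph.
    ep = torch.export.export(model, (x,), strict=False)
    got = ep.module()(x)

    # Bitwise equality can fail after export; compare with a small tolerance instead.
    torch.testing.assert_close(got, expected, atol=1e-5, rtol=1e-5)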

onnx_diagnostic/helpers/cache_helper.py

Lines changed: 85 additions & 4 deletions

@@ -362,6 +362,7 @@ def make_hybrid_cache(
     key_value_pairs: List[Tuple[torch.Tensor, torch.Tensor]],
     max_cache_len: Optional[int] = None,
     max_batch_size: Optional[int] = None,
+    sliding_window: Optional[int] = None,
 ) -> transformers.cache_utils.HybridCache:
     """
     Creates an instance of :class:`transformers.cache_utils.HybridCache`.
@@ -392,30 +393,110 @@ def make_hybrid_cache(
             ]
         )
         print(string_type(past_key_values, with_shape=True))
+
+    This part defines how the shapes are working in one HybridCache.
+
+    .. code-block:: python
+
+        self.max_cache_len = (
+            max_cache_len if max_cache_len is not None else config.max_position_embeddings)
+
+        # Sliding layers can't be larger than the overall max cache len
+        self.sliding_window_len = min(config.sliding_window, self.max_cache_len)
+        self.max_batch_size = max_batch_size
+
+        self.head_dim = (
+            config.head_dim if hasattr(config, "head_dim")
+            else config.hidden_size // config.num_attention_heads
+        )
+
+        self._dtype = dtype
+        self.num_key_value_heads = (
+            config.num_attention_heads
+            if getattr(config, "num_key_value_heads", None) is None
+            else config.num_key_value_heads
+        )
+
+        # If the attribute does not exist in the config, fallback to a simple StaticCache
+        if hasattr(config, "layer_types"):
+            self.is_sliding = [
+                layer_type != "full_attention" for layer_type in config.layer_types]
+        else:
+            self.is_sliding = [False] * config.num_hidden_layers
+
+        self.key_cache: list[torch.Tensor] = []
+        self.value_cache: list[torch.Tensor] = []
+        global_cache_shape = (self.max_batch_size, self.num_key_value_heads,
+                              self.max_cache_len, self.head_dim)
+        sliding_cache_shape = (self.max_batch_size, self.num_key_value_heads,
+                               self.sliding_window_len, self.head_dim)
+        self.sliding_window = min(config.sliding_window, max_cache_len)
+        device = torch.device(device) if device is not None else None
+        for i in range(config.num_hidden_layers):
+            layer_device = layer_device_map[i] if layer_device_map is not None else device
+            cache_shape = sliding_cache_shape if self.is_sliding[i] else global_cache_shape
+            new_layer_key_cache = torch.zeros(
+                cache_shape, dtype=self._dtype, device=layer_device)
+            new_layer_value_cache = torch.zeros(
+                cache_shape, dtype=self._dtype, device=layer_device)
+            torch._dynamo.mark_static_address(new_layer_key_cache)
+            torch._dynamo.mark_static_address(new_layer_value_cache)
+            self.key_cache.append(new_layer_key_cache)
+            self.value_cache.append(new_layer_value_cache)
     """
+    layer_types = None
     if key_value_pairs:
         assert (
             not max_batch_size and not max_cache_len
         ), "key_value_pairs is not empty, do not specify max_cache_len and max_batch_size"
         max_batch_size = key_value_pairs[0][0].shape[0]
-        max_cache_len = key_value_pairs[0][0].shape[2]
+        sets_of_dim = set(kv[0].shape[2] for kv in key_value_pairs)
+        if len(sets_of_dim) == 1:
+            max_cache_len = sets_of_dim.pop()
+            sliding_window = max_cache_len
+        else:
+            assert (
+                len(sets_of_dim) == 2
+            ), f"Not implemented for more than 2 dimensions {sets_of_dim}"
+            max_cache_len = max(sets_of_dim)
+            sliding_window = min(sets_of_dim)
+            layer_types = [
+                "full_attention" if i == max_cache_len else "sliding_attention"
+                for i in [kv[0].shape[2] for kv in key_value_pairs]
+            ]
     else:
         assert (
             max_batch_size and max_cache_len
         ), "key_value_pairs is empty, max_batch_size and max_cache_len are required"
-    _ = max_cache_len
+    if sliding_window is None:
+        sliding_window = max_cache_len
+    _max_cache_len = max_cache_len
+    _sliding_window = sliding_window
 
     class _config:
-        max_cache_len = _
+        max_cache_len = _max_cache_len
         batch_size = max_batch_size
         num_heads = key_value_pairs[0][0].shape[1] if key_value_pairs else None
         head_dim = key_value_pairs[0][0].shape[-1] if key_value_pairs else None
         num_attention_heads = key_value_pairs[0][1].shape[1] if key_value_pairs else None
         num_hidden_layers = len(key_value_pairs)
+        sliding_window = _sliding_window
+
+    if layer_types:
+        _config.layer_types = layer_types
 
     cache = transformers.cache_utils.HybridCache(
         config=_config(), max_cache_len=max_cache_len, max_batch_size=max_batch_size
     )
     for i, (key, value) in enumerate(key_value_pairs):
-        cache.update(key, value, i)
+        cache.update(
+            key,
+            value,
+            i,
+            cache_kwargs={
+                "cache_position": torch.arange(0, key.shape[2], dtype=torch.int64).to(
+                    key.device
+                )
+            },
+        )
     return cache
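With this change, make_hybrid_cache infers the sliding window from the key/value pairs themselves: a single distinct sequence length means every layer uses full attention, while two distinct lengths mark the shorter layers as "sliding_attention" and set sliding_window to the shorter length. A minimal usage sketch under that reading; tensor shapes follow the (batch, num_key_value_heads, seq_len, head_dim) convention used in the docstring above, and the concrete sizes are only illustrative, not taken from the commit:

    import torch
    from onnx_diagnostic.helpers.cache_helper import make_hybrid_cache

    batch, heads, head_dim = 1, 8, 64
    full_len, window = 128, 32  # illustrative lengths

    # Layer 0 uses full attention, layer 1 uses a sliding window: two distinct
    # sequence lengths, so max_cache_len=128 and sliding_window=32 are inferred
    # and layer_types becomes ["full_attention", "sliding_attention"].
    past_key_values = make_hybrid_cache(
        [
            (torch.randn(batch, heads, full_len, head_dim),
             torch.randn(batch, heads, full_len, head_dim)),
            (torch.randn(batch, heads, window, head_dim),
             torch.randn(batch, heads, window, head_dim)),
        ]
    )
    print(type(past_key_values))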
