@@ -362,6 +362,7 @@ def make_hybrid_cache(
     key_value_pairs: List[Tuple[torch.Tensor, torch.Tensor]],
     max_cache_len: Optional[int] = None,
     max_batch_size: Optional[int] = None,
+    sliding_window: Optional[int] = None,
 ) -> transformers.cache_utils.HybridCache:
     """
     Creates an instance of :class:`transformers.cache_utils.HybridCache`.
@@ -392,30 +393,110 @@ def make_hybrid_cache(
             ]
         )
         print(string_type(past_key_values, with_shape=True))
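+
+    When the key/value pairs mix two different sequence lengths, the shorter one is
+    interpreted as the sliding window. A minimal sketch of such a call (the tensor
+    shapes below are only illustrative, not taken from a real model):
+
+    .. code-block:: python
+
+        past_key_values = make_hybrid_cache(
+            [
+                (torch.randn(2, 8, 32, 64), torch.randn(2, 8, 32, 64)),  # full-attention layer
+                (torch.randn(2, 8, 8, 64), torch.randn(2, 8, 8, 64)),  # sliding-window layer
+            ]
+        )
+        # inferred: max_cache_len=32, sliding_window=8,
+        # layer_types=["full_attention", "sliding_attention"]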
+
+    The following excerpt from the ``HybridCache`` constructor shows how the cache
+    shapes are determined:
+
+    .. code-block:: python
+
+        self.max_cache_len = (
+            max_cache_len if max_cache_len is not None else config.max_position_embeddings
+        )
+
+        # Sliding layers can't be larger than the overall max cache len
+        self.sliding_window_len = min(config.sliding_window, self.max_cache_len)
+        self.max_batch_size = max_batch_size
+
+        self.head_dim = (
+            config.head_dim if hasattr(config, "head_dim")
+            else config.hidden_size // config.num_attention_heads
+        )
+
+        self._dtype = dtype
+        self.num_key_value_heads = (
+            config.num_attention_heads
+            if getattr(config, "num_key_value_heads", None) is None
+            else config.num_key_value_heads
+        )
+
+        # If the attribute does not exist in the config, fallback to a simple StaticCache
+        if hasattr(config, "layer_types"):
+            self.is_sliding = [
+                layer_type != "full_attention" for layer_type in config.layer_types
+            ]
+        else:
+            self.is_sliding = [False] * config.num_hidden_layers
+
+        self.key_cache: list[torch.Tensor] = []
+        self.value_cache: list[torch.Tensor] = []
+        global_cache_shape = (
+            self.max_batch_size, self.num_key_value_heads, self.max_cache_len, self.head_dim
+        )
+        sliding_cache_shape = (
+            self.max_batch_size, self.num_key_value_heads, self.sliding_window_len, self.head_dim
+        )
+        self.sliding_window = min(config.sliding_window, max_cache_len)
+        device = torch.device(device) if device is not None else None
+        for i in range(config.num_hidden_layers):
+            layer_device = layer_device_map[i] if layer_device_map is not None else device
+            cache_shape = sliding_cache_shape if self.is_sliding[i] else global_cache_shape
+            new_layer_key_cache = torch.zeros(cache_shape, dtype=self._dtype, device=layer_device)
+            new_layer_value_cache = torch.zeros(cache_shape, dtype=self._dtype, device=layer_device)
+            torch._dynamo.mark_static_address(new_layer_key_cache)
+            torch._dynamo.mark_static_address(new_layer_value_cache)
+            self.key_cache.append(new_layer_key_cache)
+            self.value_cache.append(new_layer_value_cache)
     """
+    layer_types = None
     if key_value_pairs:
         assert (
             not max_batch_size and not max_cache_len
         ), "key_value_pairs is not empty, do not specify max_cache_len and max_batch_size"
         max_batch_size = key_value_pairs[0][0].shape[0]
-        max_cache_len = key_value_pairs[0][0].shape[2]
+        # Infer the cache lengths from the keys: a single unique length means every
+        # layer uses full attention; two lengths mean the longer one is the
+        # full-attention cache and the shorter one is the sliding window.
+        sets_of_dim = set(kv[0].shape[2] for kv in key_value_pairs)
+        if len(sets_of_dim) == 1:
+            max_cache_len = sets_of_dim.pop()
+            sliding_window = max_cache_len
+        else:
+            assert (
+                len(sets_of_dim) == 2
+            ), f"Not implemented for more than 2 dimensions {sets_of_dim}"
+            max_cache_len = max(sets_of_dim)
+            sliding_window = min(sets_of_dim)
+            layer_types = [
+                "full_attention" if i == max_cache_len else "sliding_attention"
+                for i in [kv[0].shape[2] for kv in key_value_pairs]
+            ]
     else:
         assert (
             max_batch_size and max_cache_len
         ), "key_value_pairs is empty, max_batch_size and max_cache_len are required"
-        _ = max_cache_len
+        if sliding_window is None:
+            sliding_window = max_cache_len
+    _max_cache_len = max_cache_len
+    _sliding_window = sliding_window
 
     class _config:
-        max_cache_len = _
+        max_cache_len = _max_cache_len
         batch_size = max_batch_size
         num_heads = key_value_pairs[0][0].shape[1] if key_value_pairs else None
         head_dim = key_value_pairs[0][0].shape[-1] if key_value_pairs else None
         num_attention_heads = key_value_pairs[0][1].shape[1] if key_value_pairs else None
         num_hidden_layers = len(key_value_pairs)
+        sliding_window = _sliding_window
+
+    if layer_types:
+        _config.layer_types = layer_types
 
     cache = transformers.cache_utils.HybridCache(
         config=_config(), max_cache_len=max_cache_len, max_batch_size=max_batch_size
     )
     for i, (key, value) in enumerate(key_value_pairs):
-        cache.update(key, value, i)
+        # HybridCache.update expects cache_position in cache_kwargs so it knows which
+        # slots of the (full or sliding) layer cache to fill.
+        cache.update(
+            key,
+            value,
+            i,
+            cache_kwargs={
+                "cache_position": torch.arange(0, key.shape[2], dtype=torch.int64).to(
+                    key.device
+                )
+            },
+        )
     return cache