@@ -168,11 +168,15 @@ def make_dynamic_cache(
             ]
         )
         print(string_type(past_key_values, with_shape=True))
+
+    The function fully supports ``FakeTensor`` inputs with dynamic dimensions if
+    ``transformers>=4.56``. Before that version, only ``FakeTensor`` inputs with
+    static dimensions are supported.
     """
     if (
         key_value_pairs
         and isinstance(key_value_pairs[0][0], torch._subclasses.fake_tensor.FakeTensor)
-        and pv.Version(transformers.__version__) >= pv.Version("4.55")
+        and pv.Version(transformers.__version__) >= pv.Version("4.56")
     ):
         cache = transformers.cache_utils.DynamicCache()
         cache.layers.extend(
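
The version gate above can be exercised with fake inputs. Below is a minimal sketch, assuming ``make_dynamic_cache`` is importable as in the docstring example; the ``onnx_diagnostic.helpers.cache_helper`` import path and the tensor shapes are illustrative assumptions:

.. code-block:: python

    import torch
    from torch._subclasses.fake_tensor import FakeTensorMode

    # Assumed import path, mirroring the docstring example above.
    from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache

    bsize, nheads, slen, dim = 2, 4, 3, 7

    # With transformers>=4.56, key/value pairs built from FakeTensor
    # (including dynamic dimensions) take the DynamicCache().layers branch.
    with FakeTensorMode():
        past_key_values = make_dynamic_cache(
            [
                (
                    torch.randn(bsize, nheads, slen, dim),
                    torch.randn(bsize, nheads, slen, dim),
                )
                for _ in range(2)
            ]
        )
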
@@ -516,51 +520,51 @@ def make_hybrid_cache(

     .. code-block:: python

-        self.max_cache_len = (
-            max_cache_len if max_cache_len is not None else config.max_position_embeddings)
+        self.max_cache_len = (
+            max_cache_len if max_cache_len is not None else config.max_position_embeddings)

-        # Sliding layers can't be larger than the overall max cache len
-        self.sliding_window_len = min(config.sliding_window, self.max_cache_len)
-        self.max_batch_size = max_batch_size
+        # Sliding layers can't be larger than the overall max cache len
+        self.sliding_window_len = min(config.sliding_window, self.max_cache_len)
+        self.max_batch_size = max_batch_size

-        self.head_dim = (
-            config.head_dim if hasattr(config, "head_dim")
-            else config.hidden_size // config.num_attention_heads
-        )
+        self.head_dim = (
+            config.head_dim if hasattr(config, "head_dim")
+            else config.hidden_size // config.num_attention_heads
+        )

-        self._dtype = dtype
-        self.num_key_value_heads = (
-            config.num_attention_heads
-            if getattr(config, "num_key_value_heads", None) is None
-            else config.num_key_value_heads
-        )
+        self._dtype = dtype
+        self.num_key_value_heads = (
+            config.num_attention_heads
+            if getattr(config, "num_key_value_heads", None) is None
+            else config.num_key_value_heads
+        )

-        # If the attribute does not exist in the config, fallback to a simple StaticCache
-        if hasattr(config, "layer_types"):
-            self.is_sliding = [
-                layer_type != "full_attention" for layer_type in config.layer_types]
-        else:
-            self.is_sliding = [False] * config.num_hidden_layers
-
-        self.key_cache: list[torch.Tensor] = []
-        self.value_cache: list[torch.Tensor] = []
-        global_cache_shape = (self.max_batch_size, self.num_key_value_heads,
-                              self.max_cache_len, self.head_dim)
-        sliding_cache_shape = (self.max_batch_size, self.num_key_value_heads,
-                               self.sliding_window_len, self.head_dim)
-        self.sliding_window = min(config.sliding_window, max_cache_len)
-        device = torch.device(device) if device is not None else None
-        for i in range(config.num_hidden_layers):
-            layer_device = layer_device_map[i] if layer_device_map is not None else device
-            cache_shape = sliding_cache_shape if self.is_sliding[i] else global_cache_shape
-            new_layer_key_cache = torch.zeros(
-                cache_shape, dtype=self._dtype, device=layer_device)
-            new_layer_value_cache = torch.zeros(
-                cache_shape, dtype=self._dtype, device=layer_device)
-            torch._dynamo.mark_static_address(new_layer_key_cache)
-            torch._dynamo.mark_static_address(new_layer_value_cache)
-            self.key_cache.append(new_layer_key_cache)
-            self.value_cache.append(new_layer_value_cache)
+        # If the attribute does not exist in the config, fallback to a simple StaticCache
+        if hasattr(config, "layer_types"):
+            self.is_sliding = [
+                layer_type != "full_attention" for layer_type in config.layer_types]
+        else:
+            self.is_sliding = [False] * config.num_hidden_layers
+
+        self.key_cache: list[torch.Tensor] = []
+        self.value_cache: list[torch.Tensor] = []
+        global_cache_shape = (self.max_batch_size, self.num_key_value_heads,
+                              self.max_cache_len, self.head_dim)
+        sliding_cache_shape = (self.max_batch_size, self.num_key_value_heads,
+                               self.sliding_window_len, self.head_dim)
+        self.sliding_window = min(config.sliding_window, max_cache_len)
+        device = torch.device(device) if device is not None else None
+        for i in range(config.num_hidden_layers):
+            layer_device = layer_device_map[i] if layer_device_map is not None else device
+            cache_shape = sliding_cache_shape if self.is_sliding[i] else global_cache_shape
+            new_layer_key_cache = torch.zeros(
+                cache_shape, dtype=self._dtype, device=layer_device)
+            new_layer_value_cache = torch.zeros(
+                cache_shape, dtype=self._dtype, device=layer_device)
+            torch._dynamo.mark_static_address(new_layer_key_cache)
+            torch._dynamo.mark_static_address(new_layer_value_cache)
+            self.key_cache.append(new_layer_key_cache)
+            self.value_cache.append(new_layer_value_cache)
     """
     layer_types = None
     if key_value_pairs:
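
The quoted ``HybridCache`` code mixes two cache shapes: a global one bounded by ``max_cache_len`` and a sliding one bounded by the window length. A small standalone sketch of that shape logic, using made-up configuration values, looks like this:

.. code-block:: python

    import torch

    # Illustrative values only; a real config object provides these attributes.
    num_hidden_layers = 4
    num_key_value_heads = 8
    head_dim = 64
    max_batch_size = 2
    max_cache_len = 128
    sliding_window = 32
    layer_types = ["sliding_attention", "full_attention"] * 2

    sliding_window_len = min(sliding_window, max_cache_len)
    is_sliding = [t != "full_attention" for t in layer_types]

    global_shape = (max_batch_size, num_key_value_heads, max_cache_len, head_dim)
    sliding_shape = (max_batch_size, num_key_value_heads, sliding_window_len, head_dim)

    key_cache, value_cache = [], []
    for i in range(num_hidden_layers):
        shape = sliding_shape if is_sliding[i] else global_shape
        key_cache.append(torch.zeros(shape))
        value_cache.append(torch.zeros(shape))

    # Sliding layers hold at most 32 past positions, full layers up to 128.
    print([tuple(t.shape) for t in key_cache])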