@@ -55,6 +55,12 @@ def _get_invalid_idx_value(cls):
5555
5656
5757class QEffDynamicLayer (DynamicLayer ):
58+ def lazy_initialization (self , key_states : torch .Tensor ):
59+ self .dtype , self .device = key_states .dtype , key_states .device
60+ self .keys = torch .tensor ([], dtype = self .dtype , device = self .device )
61+ self .values = torch .tensor ([], dtype = self .dtype , device = self .device )
62+ self .is_initialized = True
63+
5864 def read_only (self , cache_kwargs ):
5965 """
6066 Reads the `key_states` and `value_states` for the layer.
@@ -151,6 +157,7 @@ def write_only(self, key_states, value_states, cache_kwargs):
151157 self .keys = key_states
152158 self .values = value_states
153159 else :
160+ # breakpoint()
154161 position_ids = cache_kwargs .get ("position_ids" )
155162 batch_index = cache_kwargs .get ("batch_index" , None ) # Check and fetch batch index value form the kwargs
156163
@@ -185,11 +192,15 @@ def update(
185192 Return:
186193 A tuple containing the updated key and value states.
187194 """
195+ # breakpoint()
188196 # Update the cache
197+ # if not self.is_initialized:
198+
189199 if self .keys is None :
190200 self .keys = key_states
191201 self .values = value_states
192202 k_out , v_out = self .keys , self .values
203+ self .is_initialized = True
193204 else :
194205 position_ids = cache_kwargs .get ("position_ids" )
195206 batch_index = cache_kwargs .get ("batch_index" , None ) # Check and fetch batch index value form the kwargs
@@ -306,15 +317,48 @@ class QEffDynamicCache(DynamicCache):
306317
307318 """
308319
309- def __init__ (self , ddp_cache_data : Optional [Iterable [tuple [torch .Tensor , torch .Tensor ]]] = None , * args , ** kwargs ):
320+ def __init__ (
321+ self ,
322+ ddp_cache_data : Optional [Iterable [tuple [torch .Tensor , torch .Tensor ]]] = None ,
323+ config = None ,
324+ offloading : bool = False ,
325+ offload_only_non_sliding : bool = False ,
326+ * args ,
327+ ** kwargs ,
328+ ):
310329 # Remove layer_classes if present to avoid duplicate argument
311- kwargs .pop ("layer_classes" , None )
330+ # breakpoint()
331+ kwargs .pop ("layers" , None )
312332 from transformers .cache_utils import Cache # Import here to avoid circular import
313333
314- Cache .__init__ (self , layer_classes = QEffDynamicLayer , * args , ** kwargs )
334+ # breakpoint()
335+ layers = []
336+ # If a config is passed, use it to infer the layer types and initialize accordingly
337+ if len (layers ) == 0 :
338+ Cache .__init__ (
339+ self ,
340+ layer_class_to_replicate = QEffDynamicLayer ,
341+ offloading = offloading ,
342+ offload_only_non_sliding = offload_only_non_sliding ,
343+ # args=args,
344+ # kwargs=kwargs,
345+ )
346+ else :
347+ Cache .__init__ (
348+ self ,
349+ layers = layers ,
350+ offloading = offloading ,
351+ offload_only_non_sliding = offload_only_non_sliding ,
352+ # args=args,
353+ # kwargs=kwargs,
354+ )
355+
315356 if ddp_cache_data is not None :
316- for key_states , value_states in ddp_cache_data :
317- self .layers .append (QEffDynamicLayer .from_tensors (key_states , value_states ))
357+ for layer_idx , (key_states , value_states ) in enumerate (ddp_cache_data ):
358+ # If the config was not passed above, initialize a DynamicLayer for each entry of the ddp_data
359+ layers .append (QEffDynamicLayer ())
360+ # Update the layer with the data
361+ _ , _ = layers [layer_idx ].update (key_states , value_states )
318362
319363 def read_only (self , layer_idx , cache_kwargs ):
320364 """
@@ -329,6 +373,7 @@ def read_only(self, layer_idx, cache_kwargs):
329373 Return:
330374 A tuple containing the updated key and value states.
331375 """
376+ # breakpoint()
332377 return self .layers [layer_idx ].read_only (cache_kwargs )
333378
334379 def read_only_blockedKV (self , start_index , end_index , layer_idx , cache_kwargs ):
@@ -394,6 +439,18 @@ def update3D(
394439 self .append_new_layers (layer_idx )
395440 return self .layers [layer_idx ].update3D (key_states , value_states , cache_kwargs )
396441
442+ # def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
443+ # """Returns the sequence length of the cached states. A layer index can be optionally passed."""
444+ # # TODO: deprecate this function in favor of `cache_position`
445+ # breakpoint()
446+ # is_empty_layer = (
447+ # len(self.key_cache) == 0 # no cache in any layer
448+ # or len(self.key_cache) <= layer_idx # skipped `layer_idx` and hasn't run a layer with cache after it
449+ # or len(self.key_cache[layer_idx]) == 0 # the layer has no cache
450+ # )
451+ # layer_seq_length = self.key_cache[layer_idx].shape[-2] if not is_empty_layer else 0
452+ # return layer_seq_length
453+
397454
398455class QEffEncoderDecoderCache (EncoderDecoderCache ):
399456 """
0 commit comments