@@ -517,18 +517,16 @@ def __init__(
         else:
             self.tokenizer.add_special_tokens({"pad_token": "[PAD]"})
 
-        # Perform a forward pass to determine the ordering of past key attention tensor dimensions
+        # Perform a forward pass to determine the structure of the past_key_values
         one_token = torch.tensor([[1]], device=inference_engine.get_device())
         _, past_key_values, _ = self.forward(input_ids=one_token, attention_mask=one_token)
-        pkv_tensors_per_layer = len(past_key_values[0])
-        if pkv_tensors_per_layer == 2:
+        if torch.is_tensor(past_key_values[0]):
+            self.batch_type = CombinedKVCausalLMBatch
+        else:
+            # check the ordering of the key tensor dimensions
             key_past, value_past = past_key_values[0]
             keys_head_dim_last = key_past.shape[-1] == value_past.shape[-1]
             self.batch_type = CausalLMBatch if keys_head_dim_last else KeysDimTransposedCausalLMBatch
-        elif pkv_tensors_per_layer == 1:
-            self.batch_type = CombinedKVCausalLMBatch
-        else:
-            raise ValueError("Unexpected number of elements in past_key_values cache")
 
     @property
     def batch_type(self) -> Type[CausalLMBatch]:
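
As an aside (not part of the diff): a minimal runnable sketch, with made-up shapes and a single layer, of the two past_key_values layouts the probe above distinguishes: per-layer (key, value) tuples versus a single combined tensor per layer, which is what torch.is_tensor(past_key_values[0]) detects.

import torch

# Layout A: one (key, value) tuple per layer (the common HF convention).
tuple_pkv = [(torch.zeros(1, 12, 4, 64), torch.zeros(1, 12, 4, 64))]

# Layout B: one combined tensor per layer (e.g. a fused key/value cache).
combined_pkv = [torch.zeros(1, 4, 128)]

for pkv in (tuple_pkv, combined_pkv):
    if torch.is_tensor(pkv[0]):
        print("combined cache, handled by CombinedKVCausalLMBatch")
    else:
        key_past, value_past = pkv[0]
        # Keys stored head-dim-last share their final dimension with the values.
        keys_head_dim_last = key_past.shape[-1] == value_past.shape[-1]
        print("per-layer (key, value) tuples, keys_head_dim_last =", keys_head_dim_last)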
@@ -658,7 +656,8 @@ def generate_token(
             # Trim attention mask and past kvs if we padded to multiple of 8. This is important to be able to
             # generate up to the model's token limit.
             batch.attention_mask = batch.attention_mask[:, left_pad:]
-            if len(past[0]) == 1:
+            # For a combined KV cache, past is a list of Tensors, not Tuples
+            if torch.is_tensor(past[0]):
                 for cache in past:
                     cache.data = cache.data[..., left_pad:, :]
             else:
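
Similarly, a rough sketch (again outside the diff, with made-up shapes) of the padding trim in both branches. The tuple-cache handling shown for the else branch is an assumption, since its body is not part of this hunk.

import torch

left_pad = 2

# Combined cache: one tensor per layer; drop the padded sequence positions.
combined_past = [torch.zeros(1, 8, 128)]
for cache in combined_past:
    cache.data = cache.data[..., left_pad:, :]
print(combined_past[0].shape)  # torch.Size([1, 6, 128])

# Tuple cache (assumed handling, keys stored head-dim-last): trim each key and
# value tensor along the sequence dimension instead.
tuple_past = [(torch.zeros(1, 12, 8, 64), torch.zeros(1, 12, 8, 64))]
tuple_past = [(k[..., left_pad:, :], v[..., left_pad:, :]) for k, v in tuple_past]
print(tuple_past[0][0].shape)  # torch.Size([1, 12, 6, 64])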