
Commit 45bfd01

tjohnson31415 authored and njhill committed
fix: use explicit is_tensor check (whoops!)
len() is valid for a tensor and returns the size of the first dimension.

Signed-off-by: Travis Johnson <[email protected]>
1 parent b4bd29d commit 45bfd01
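
To illustrate the pitfall the commit message describes: a per-layer (key, value) tuple and a combined KV tensor can both report len() == 2, so only an explicit torch.is_tensor check reliably tells them apart. A minimal sketch (the tensor shapes are illustrative, not taken from any particular model):

    import torch

    # Per-layer cache as a (key, value) tuple: len() counts the tuple's elements.
    kv_tuple = (torch.zeros(1, 8, 4, 64), torch.zeros(1, 8, 4, 64))
    print(len(kv_tuple))              # 2
    print(torch.is_tensor(kv_tuple))  # False

    # Combined per-layer cache as a single tensor: len() does not raise here;
    # it returns the size of the first dimension, which can collide with 2.
    kv_combined = torch.zeros(2, 8, 4, 64)
    print(len(kv_combined))              # 2 as well, so a len()-based check misclassifies it
    print(torch.is_tensor(kv_combined))  # True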

File tree: 1 file changed (+7, -8 lines)

server/text_generation_server/models/causal_lm.py

Lines changed: 7 additions & 8 deletions
@@ -517,18 +517,16 @@ def __init__(
         else:
             self.tokenizer.add_special_tokens({"pad_token": "[PAD]"})

-        # Perform a forward pass to determine the ordering of past key attention tensor dimensions
+        # Perform a forward pass to determine the structure of the past_key_values
         one_token = torch.tensor([[1]], device=inference_engine.get_device())
         _, past_key_values, _ = self.forward(input_ids=one_token, attention_mask=one_token)
-        pkv_tensors_per_layer = len(past_key_values[0])
-        if pkv_tensors_per_layer == 2:
+        if torch.is_tensor(past_key_values[0]):
+            self.batch_type = CombinedKVCausalLMBatch
+        else:
+            # check the ordering of the key tensor dimensions
             key_past, value_past = past_key_values[0]
             keys_head_dim_last = key_past.shape[-1] == value_past.shape[-1]
             self.batch_type = CausalLMBatch if keys_head_dim_last else KeysDimTransposedCausalLMBatch
-        elif pkv_tensors_per_layer == 1:
-            self.batch_type = CombinedKVCausalLMBatch
-        else:
-            raise ValueError("Unexpected number of elements in past_key_values cache")

     @property
     def batch_type(self) -> Type[CausalLMBatch]:

@@ -658,7 +656,8 @@ def generate_token(
         # Trim attention mask and past kvs if we padded to multiple of 8. This is important to be able to
         # generate up to the model's token limit.
         batch.attention_mask = batch.attention_mask[:, left_pad:]
-        if len(past[0]) == 1:
+        # For a combined KV cache, past is a list of Tensors, not Tuples
+        if torch.is_tensor(past[0]):
             for cache in past:
                 cache.data = cache.data[..., left_pad:, :]
         else:
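
A self-contained sketch of the corrected dispatch from the first hunk. The helper name select_batch_type and the cache shapes are hypothetical, and it returns the class name as a string so it runs standalone; the branch logic and class names mirror the diff:

    import torch

    def select_batch_type(past_key_values):
        # Combined KV caches store one tensor per layer; standard caches
        # store a (key, value) tuple per layer.
        first = past_key_values[0]
        if torch.is_tensor(first):
            return "CombinedKVCausalLMBatch"
        key_past, value_past = first
        # If key and value share their last dimension, keys are head-dim-last.
        keys_head_dim_last = key_past.shape[-1] == value_past.shape[-1]
        return "CausalLMBatch" if keys_head_dim_last else "KeysDimTransposedCausalLMBatch"

    # Combined layout (shape illustrative): one stacked tensor per layer.
    print(select_batch_type([torch.zeros(2, 8, 4, 64)]))  # CombinedKVCausalLMBatch

    # Standard layout: a (key, value) tuple per layer.
    kv = (torch.zeros(1, 8, 4, 64), torch.zeros(1, 8, 4, 64))
    print(select_batch_type([kv]))                         # CausalLMBatch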

0 comments