Commit ec8bcb9

seven-mile authored and benchislett committed

[Bugfix][Spec Decode] Fix wrong valid_mask for padded speculation when chunked prefill occurs (vllm-project#26231)

Signed-off-by: seven-mile <[email protected]>
Signed-off-by: Benjamin Chislett <[email protected]>
Co-authored-by: Benjamin Chislett <[email protected]>
1 parent 87ff41e commit ec8bcb9

File tree

1 file changed: +3 -7 lines


vllm/v1/spec_decode/eagle.py

Lines changed: 3 additions & 7 deletions
@@ -522,13 +522,9 @@ def prepare_next_token_ids_padded(
         )
 
         # Generate a mask for all valid tokens within those requests
-        max_gen_len = sampled_token_ids.shape[-1]
-        if max_gen_len == 1:
-            valid_mask = torch.ones_like(valid_sampled_token_ids_gpu, dtype=torch.bool)
-        else:
-            valid_mask = (valid_sampled_token_ids_gpu != -1) & (
-                valid_sampled_token_ids_gpu < gpu_input_batch.vocab_size
-            )
+        valid_mask = (valid_sampled_token_ids_gpu != -1) & (
+            valid_sampled_token_ids_gpu < gpu_input_batch.vocab_size
+        )
 
         # Count the number of valid tokens in each request
         valid_sampled_tokens_count = valid_mask.sum(dim=1)
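
The fix removes the special case that built an all-ones mask whenever max_gen_len == 1. With chunked prefill, a batch can contain requests that are still mid-prefill and carry -1 placeholder entries in the sampled token IDs even when no request produced more than one token, so the unconditional all-ones mask wrongly counted those placeholders as valid. A minimal sketch of the difference, using illustrative tensor values and a hypothetical vocab_size (the real code reads it from gpu_input_batch):

import torch

# Illustrative values (hypothetical): 3 requests, max_gen_len == 1.
# The second request is still in chunked prefill, so its slot holds a
# -1 placeholder rather than a real sampled token.
vocab_size = 32000
valid_sampled_token_ids_gpu = torch.tensor([[17], [-1], [42]])

# Old behavior: with max_gen_len == 1 the mask was unconditionally
# all-ones, so the placeholder was counted as a valid token.
old_mask = torch.ones_like(valid_sampled_token_ids_gpu, dtype=torch.bool)
print(old_mask.sum(dim=1))  # tensor([1, 1, 1]) -- request 1 miscounted

# Fixed behavior: filter out -1 placeholders and out-of-vocab padding IDs.
new_mask = (valid_sampled_token_ids_gpu != -1) & (
    valid_sampled_token_ids_gpu < vocab_size
)
print(new_mask.sum(dim=1))  # tensor([1, 0, 1]) -- prefill request excluded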
