23 changes: 18 additions & 5 deletions tests/v1/e2e/test_spec_decode.py
@@ -124,9 +124,9 @@ def test_ngram_correctness(


@pytest.mark.parametrize(
["model_setup", "mm_enabled"],
["model_setup", "mm_enabled", "chunked_prefill_enabled"],
[
(("eagle3", "Qwen/Qwen3-8B", "AngelSlim/Qwen3-8B_eagle3", 1), False),
(("eagle3", "Qwen/Qwen3-8B", "AngelSlim/Qwen3-8B_eagle3", 1), False, False),
pytest.param(
(
"eagle3",
@@ -135,6 +135,7 @@ def test_ngram_correctness(
1,
),
False,
False,
marks=pytest.mark.skip(
reason="Skipping due to its head_dim not being a a multiple of 32"
),
@@ -147,6 +148,7 @@ def test_ngram_correctness(
1,
),
False,
True,
),
(
(
@@ -156,6 +158,7 @@ def test_ngram_correctness(
1,
),
False,
False,
),
pytest.param(
(
@@ -165,6 +168,7 @@ def test_ngram_correctness(
4,
),
False,
False,
marks=large_gpu_mark(min_gb=80),
), # works on 4x H100
pytest.param(
@@ -175,6 +179,7 @@ def test_ngram_correctness(
4,
),
True,
True,
marks=large_gpu_mark(min_gb=80),
), # works on 4x H100
(
@@ -185,6 +190,7 @@ def test_ngram_correctness(
1,
),
False,
False,
),
],
ids=[
@@ -203,6 +209,7 @@ def test_eagle_correctness(
sampling_config: SamplingParams,
model_setup: tuple[str, str, str, int],
mm_enabled: bool,
chunked_prefill_enabled: bool,
attn_backend: str,
):
if attn_backend == "TREE_ATTN":
@@ -239,9 +246,13 @@ def test_eagle_correctness(
m.setenv("VLLM_ROCM_USE_AITER", "1")

method, model_name, spec_model_name, tp_size = model_setup
max_model_len = 2048
max_num_batched_tokens = max_model_len
if chunked_prefill_enabled:
max_num_batched_tokens = 128

ref_llm = LLM(
model=model_name, max_model_len=2048, tensor_parallel_size=tp_size
model=model_name, max_model_len=max_model_len, tensor_parallel_size=tp_size
)
ref_outputs = ref_llm.chat(test_prompts, sampling_config)
del ref_llm
@@ -256,9 +267,11 @@ def test_eagle_correctness(
"method": method,
"model": spec_model_name,
"num_speculative_tokens": 3,
"max_model_len": 2048,
"max_model_len": max_num_batched_tokens,
},
max_model_len=2048,
max_model_len=max_model_len,
max_num_batched_tokens=max_num_batched_tokens,
enable_chunked_prefill=chunked_prefill_enabled,
)
spec_outputs = spec_llm.chat(test_prompts, sampling_config)
matches = 0
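
For context, here is a minimal sketch (not the test itself) of what the new `chunked_prefill_enabled` parameter exercises: when it is set, the per-step token budget `max_num_batched_tokens` is capped at 128, well below `max_model_len`, so prompts are prefilled across several scheduler steps. The model names, prompt, and sampling values below are illustrative placeholders; the constructor arguments mirror the ones used in the diff above.

```python
# Hypothetical, minimal reproduction of the chunked-prefill path added to the
# test above. Model names and the prompt are placeholders; the keyword
# arguments mirror those in the diff.
from vllm import LLM, SamplingParams

chunked_prefill_enabled = True
max_model_len = 2048
# With chunked prefill, the scheduler's per-step token budget is far smaller
# than the context length, so long prompts span multiple prefill steps.
max_num_batched_tokens = 128 if chunked_prefill_enabled else max_model_len

spec_llm = LLM(
    model="Qwen/Qwen3-8B",
    speculative_config={
        "method": "eagle3",
        "model": "AngelSlim/Qwen3-8B_eagle3",
        "num_speculative_tokens": 3,
        # As in the diff, the draft model's max_model_len is tied to the
        # per-step budget.
        "max_model_len": max_num_batched_tokens,
    },
    max_model_len=max_model_len,
    max_num_batched_tokens=max_num_batched_tokens,
    enable_chunked_prefill=chunked_prefill_enabled,
)
outputs = spec_llm.chat(
    [[{"role": "user", "content": "Explain speculative decoding briefly."}]],
    SamplingParams(temperature=0, max_tokens=64),
)
print(outputs[0].outputs[0].text)
```

The test then compares these outputs against a reference LLM built without speculative decoding and counts matching generations.
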
10 changes: 3 additions & 7 deletions vllm/v1/spec_decode/eagle.py
@@ -522,13 +522,9 @@ def prepare_next_token_ids_padded(
)

# Generate a mask for all valid tokens within those requests
max_gen_len = sampled_token_ids.shape[-1]
if max_gen_len == 1:
valid_mask = torch.ones_like(valid_sampled_token_ids_gpu, dtype=torch.bool)
else:
valid_mask = (valid_sampled_token_ids_gpu != -1) & (
valid_sampled_token_ids_gpu < gpu_input_batch.vocab_size
)
valid_mask = (valid_sampled_token_ids_gpu != -1) & (
valid_sampled_token_ids_gpu < gpu_input_batch.vocab_size
)
Comment on lines +525 to +527 (reviewer comment, severity: high):

This simplification is a great improvement. By removing the special case for max_gen_len == 1, the code is now more robust. The previous logic didn't account for discarded requests when max_gen_len == 1, which could lead to using an invalid token ID of -1. This unified approach correctly handles all cases.


# Count the number of valid tokens in each request
valid_sampled_tokens_count = valid_mask.sum(dim=1)
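
To make the reviewer's point concrete, here is a small, self-contained sketch of the unified mask, using assumed shapes, an assumed vocab size, and made-up token ids, applied to a batch where one request was discarded even though max_gen_len == 1:

```python
# Illustration only: assumed vocab size and token ids, shaped [num_reqs, max_gen_len].
import torch

vocab_size = 32000

# Three requests, one sampled token each (max_gen_len == 1). The second request
# was discarded, so its slot holds the -1 placeholder.
valid_sampled_token_ids_gpu = torch.tensor([[123], [-1], [456]])

# Unified mask: filters both -1 placeholders and out-of-vocab placeholder ids.
valid_mask = (valid_sampled_token_ids_gpu != -1) & (
    valid_sampled_token_ids_gpu < vocab_size
)
valid_sampled_tokens_count = valid_mask.sum(dim=1)

print(valid_mask.squeeze(-1).tolist())      # [True, False, True]
print(valid_sampled_tokens_count.tolist())  # [1, 0, 1]
```

With the removed special case, a max_gen_len of 1 would have produced an all-ones mask, so the discarded request's -1 placeholder would have been counted as a valid token; the unified mask counts zero valid tokens for it instead.
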