Skip to content
Open
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 18 additions & 5 deletions tests/v1/e2e/test_spec_decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,9 +124,9 @@ def test_ngram_correctness(


@pytest.mark.parametrize(
["model_setup", "mm_enabled"],
["model_setup", "mm_enabled", "chunked_prefill_enabled"],
[
(("eagle3", "Qwen/Qwen3-8B", "AngelSlim/Qwen3-8B_eagle3", 1), False),
(("eagle3", "Qwen/Qwen3-8B", "AngelSlim/Qwen3-8B_eagle3", 1), False, False),
pytest.param(
(
"eagle3",
Expand All @@ -135,6 +135,7 @@ def test_ngram_correctness(
1,
),
False,
False,
marks=pytest.mark.skip(
reason="Skipping due to its head_dim not being a a multiple of 32"
),
Expand All @@ -147,6 +148,7 @@ def test_ngram_correctness(
1,
),
False,
True,
),
(
(
Expand All @@ -156,6 +158,7 @@ def test_ngram_correctness(
1,
),
False,
False,
),
pytest.param(
(
Expand All @@ -165,6 +168,7 @@ def test_ngram_correctness(
4,
),
False,
False,
marks=large_gpu_mark(min_gb=80),
), # works on 4x H100
pytest.param(
Expand All @@ -175,6 +179,7 @@ def test_ngram_correctness(
4,
),
True,
True,
marks=large_gpu_mark(min_gb=80),
), # works on 4x H100
(
Expand All @@ -185,6 +190,7 @@ def test_ngram_correctness(
1,
),
False,
False,
),
],
ids=[
Expand All @@ -203,6 +209,7 @@ def test_eagle_correctness(
sampling_config: SamplingParams,
model_setup: tuple[str, str, str, int],
mm_enabled: bool,
chunked_prefill_enabled: bool,
attn_backend: str,
):
if attn_backend == "TREE_ATTN":
Expand Down Expand Up @@ -239,9 +246,13 @@ def test_eagle_correctness(
m.setenv("VLLM_ROCM_USE_AITER", "1")

method, model_name, spec_model_name, tp_size = model_setup
max_model_len = 2048
max_num_batched_tokens = max_model_len
if chunked_prefill_enabled:
max_num_batched_tokens = 128

ref_llm = LLM(
model=model_name, max_model_len=2048, tensor_parallel_size=tp_size
model=model_name, max_model_len=max_model_len, tensor_parallel_size=tp_size
)
ref_outputs = ref_llm.chat(test_prompts, sampling_config)
del ref_llm
Expand All @@ -256,9 +267,11 @@ def test_eagle_correctness(
"method": method,
"model": spec_model_name,
"num_speculative_tokens": 3,
"max_model_len": 2048,
"max_model_len": max_num_batched_tokens,
},
max_model_len=2048,
max_model_len=max_model_len,
max_num_batched_tokens=max_num_batched_tokens,
enable_chunked_prefill=chunked_prefill_enabled,
)
spec_outputs = spec_llm.chat(test_prompts, sampling_config)
matches = 0
Expand Down