diff --git a/tests/v1/e2e/test_spec_decode.py b/tests/v1/e2e/test_spec_decode.py
index 9ed9cd7950a9..681933899d30 100644
--- a/tests/v1/e2e/test_spec_decode.py
+++ b/tests/v1/e2e/test_spec_decode.py
@@ -124,9 +124,9 @@ def test_ngram_correctness(
 
 
 @pytest.mark.parametrize(
-    ["model_setup", "mm_enabled"],
+    ["model_setup", "mm_enabled", "chunked_prefill_enabled"],
     [
-        (("eagle3", "Qwen/Qwen3-8B", "AngelSlim/Qwen3-8B_eagle3", 1), False),
+        (("eagle3", "Qwen/Qwen3-8B", "AngelSlim/Qwen3-8B_eagle3", 1), False, False),
         pytest.param(
             (
                 "eagle3",
@@ -135,6 +135,7 @@ def test_ngram_correctness(
                 1,
             ),
             False,
+            False,
             marks=pytest.mark.skip(
                 reason="Skipping due to its head_dim not being a a multiple of 32"
             ),
@@ -147,6 +148,7 @@ def test_ngram_correctness(
                 1,
             ),
             False,
+            True,
         ),
         (
             (
@@ -156,6 +158,7 @@ def test_ngram_correctness(
                 1,
             ),
             False,
+            False,
         ),
         pytest.param(
             (
@@ -165,6 +168,7 @@ def test_ngram_correctness(
                 4,
             ),
             False,
+            False,
             marks=large_gpu_mark(min_gb=80),
         ),  # works on 4x H100
         pytest.param(
@@ -175,6 +179,7 @@ def test_ngram_correctness(
                 4,
             ),
             True,
+            True,
             marks=large_gpu_mark(min_gb=80),
         ),  # works on 4x H100
         (
@@ -185,6 +190,7 @@ def test_ngram_correctness(
                 1,
             ),
             False,
+            False,
         ),
     ],
     ids=[
@@ -203,6 +209,7 @@ def test_eagle_correctness(
     sampling_config: SamplingParams,
     model_setup: tuple[str, str, str, int],
     mm_enabled: bool,
+    chunked_prefill_enabled: bool,
     attn_backend: str,
 ):
     if attn_backend == "TREE_ATTN":
@@ -239,9 +246,13 @@ def test_eagle_correctness(
             m.setenv("VLLM_ROCM_USE_AITER", "1")
 
         method, model_name, spec_model_name, tp_size = model_setup
+        max_model_len = 2048
+        max_num_batched_tokens = max_model_len
+        if chunked_prefill_enabled:
+            max_num_batched_tokens = 128
 
         ref_llm = LLM(
-            model=model_name, max_model_len=2048, tensor_parallel_size=tp_size
+            model=model_name, max_model_len=max_model_len, tensor_parallel_size=tp_size
         )
         ref_outputs = ref_llm.chat(test_prompts, sampling_config)
         del ref_llm
@@ -256,9 +267,11 @@ def test_eagle_correctness(
                 "method": method,
                 "model": spec_model_name,
                 "num_speculative_tokens": 3,
-                "max_model_len": 2048,
+                "max_model_len": max_num_batched_tokens,
            },
-            max_model_len=2048,
+            max_model_len=max_model_len,
+            max_num_batched_tokens=max_num_batched_tokens,
+            enable_chunked_prefill=chunked_prefill_enabled,
         )
         spec_outputs = spec_llm.chat(test_prompts, sampling_config)
         matches = 0