43 changes: 40 additions & 3 deletions tests/v1/e2e/test_async_scheduling.py
@@ -73,7 +73,7 @@ def test_without_spec_decoding(
    run_tests(monkeypatch, MODEL, test_configs, test_sampling_params)


def test_with_spec_decoding(monkeypatch: pytest.MonkeyPatch):
def test_with_eagle3_spec_decoding(monkeypatch: pytest.MonkeyPatch):
"""Test consistency and acceptance rates with some different combos of
preemption, executor, async scheduling, prefill chunking,
spec decoding model length.
@@ -106,6 +106,42 @@ def test_with_spec_decoding(monkeypatch: pytest.MonkeyPatch):
    run_tests(monkeypatch, MTP_MODEL, test_configs, [{}])


def test_with_ngram_gpu_spec_decoding(monkeypatch: pytest.MonkeyPatch):
    """Test ngram_gpu speculative decoding with different configurations.

    This test specifically validates ngram_gpu behavior across:
    - preemption enabled and disabled
    - "mp" and "uni" executors
    - async scheduling enabled and disabled
    - prefill chunking enabled and disabled
    """

    # ngram_gpu speculation config: propose up to 3 draft tokens from
    # prompt-lookup n-gram matches of length 2-3
    ngram_gpu_config = {
        "method": "ngram_gpu",
        "num_speculative_tokens": 3,
        "prompt_lookup_max": 3,
        "prompt_lookup_min": 2,
    }

    # Test configurations covering various scenarios:
    # (test_preemption, executor, async_scheduling,
    #  spec_config, test_prefill_chunking)
    test_configs = [
        (False, "mp", False, None, False),
        (False, "mp", False, ngram_gpu_config, False),
        (True, "mp", False, ngram_gpu_config, True),
        (False, "mp", True, ngram_gpu_config, False),
        (True, "mp", True, ngram_gpu_config, False),
        (True, "uni", True, ngram_gpu_config, False),
        (True, "mp", True, ngram_gpu_config, True),
    ]

    # Use MODEL (Qwen) for ngram_gpu tests: it is lighter weight and
    # ngram_gpu doesn't require a separate draft model
    run_tests(monkeypatch, MODEL, test_configs, [{}])


@dynamo_config.patch(cache_size_limit=16)
def run_tests(
    monkeypatch: pytest.MonkeyPatch,
@@ -217,18 +253,19 @@ def run_test(
        else dict(gpu_memory_utilization=0.9)
    )
    spec_mml = (spec_config or {}).get("max_model_len")
    spec_method = (spec_config or {}).get("method", "none")
    test_config = (
        f"executor={executor}, preemption={test_preemption}, "
        f"async_sched={async_scheduling}, "
        f"chunk_prefill={test_prefill_chunking}, "
        f"spec_decoding={spec_decoding}, spec_mml={spec_mml}"
        f"spec_decoding={spec_decoding}, spec_method={spec_method}, spec_mml={spec_mml}"
    )
    print("-" * 80)
    print(f"---- TESTING {test_str}: {test_config}")
    print("-" * 80)
    with VllmRunner(
        model,
        max_model_len=512,
        max_model_len=4096,
        enable_chunked_prefill=test_prefill_chunking,
        # Force prefill chunking
        max_num_batched_tokens=48 if test_prefill_chunking else None,
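Outside the test harness, the combination this PR enables can be tried directly. The following is an illustrative sketch, not part of the diff: it assumes vLLM's offline `LLM` entry point with a `speculative_config` dict and the `async_scheduling` engine argument, and the model name is a placeholder.

```python
from vllm import LLM, SamplingParams

# Same knobs as ngram_gpu_config in the test above.
llm = LLM(
    model="Qwen/Qwen2.5-1.5B-Instruct",  # placeholder target model
    speculative_config={
        "method": "ngram_gpu",        # method added by this PR
        "num_speculative_tokens": 3,  # draft up to 3 tokens per step
        "prompt_lookup_max": 3,       # longest n-gram matched against context
        "prompt_lookup_min": 2,       # shortest n-gram matched
    },
    async_scheduling=True,  # the pairing this PR newly allows
)

outputs = llm.generate(
    ["The quick brown fox jumps over the lazy dog. The quick brown"],
    SamplingParams(temperature=0.0, max_tokens=32),
)
print(outputs[0].outputs[0].text)
```

Greedy sampling makes the output directly comparable to a non-speculative baseline, which is essentially the consistency check `run_tests` automates.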
9 changes: 7 additions & 2 deletions vllm/config/speculative.py
@@ -40,13 +40,15 @@
"pangu_ultra_moe_mtp",
]
EagleModelTypes = Literal["eagle", "eagle3", MTPModelTypes]
NgramGPUTypes = Literal["ngram_gpu"]
SpeculativeMethod = Literal[
"ngram",
"medusa",
"mlp_speculator",
"draft_model",
"suffix",
EagleModelTypes,
NgramGPUTypes,
]


@@ -260,6 +262,8 @@ def __post_init__(self):
            self.quantization = self.target_model_config.quantization
        elif self.method in ("ngram", "[ngram]"):
            self.model = "ngram"
        elif self.method == "ngram_gpu":
            self.model = "ngram_gpu"
        elif self.method == "suffix":
            self.model = "suffix"
        else:
@@ -274,9 +278,10 @@
        ):
            self.method = "ngram"

        if self.method in ("ngram", "[ngram]"):
        if self.method in ("ngram", "[ngram]", "ngram_gpu"):
            # Unified to "ngram" internally
            self.method = "ngram"
            # Set default values if not provided
            if self.prompt_lookup_min is None and self.prompt_lookup_max is None:
                # TODO(woosuk): Tune these values. They are arbitrarily chosen.
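For context on the method itself: prompt-lookup (n-gram) speculation drafts tokens by matching the tail of the sequence against earlier occurrences in the same sequence. Below is a pure-Python illustration of what the `prompt_lookup_min`/`prompt_lookup_max` defaults and `num_speculative_tokens` control; this is a sketch of the idea, not vLLM's implementation (which is vectorized, and runs on the GPU for ngram_gpu).

```python
def propose_draft(
    tokens: list[int], lookup_min: int, lookup_max: int, num_spec: int
) -> list[int]:
    """Propose draft tokens by n-gram prompt lookup (illustrative only)."""
    for n in range(lookup_max, lookup_min - 1, -1):  # prefer longer matches
        suffix = tokens[-n:]
        # Most recent earlier occurrence of the current n-token suffix.
        for start in range(len(tokens) - n - 1, -1, -1):
            if tokens[start : start + n] == suffix:
                draft = tokens[start + n : start + n + num_spec]
                if draft:
                    return draft  # target model verifies these in one pass
    return []  # no match: fall back to normal one-token decoding


# The suffix (1, 2) last occurred at the start, continuing with 3, 4, 5.
print(propose_draft([1, 2, 3, 4, 5, 9, 1, 2], 2, 3, 3))  # -> [3, 4, 5]
```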
8 changes: 5 additions & 3 deletions vllm/config/vllm.py
@@ -21,7 +21,7 @@
from pydantic.dataclasses import dataclass

import vllm.envs as envs
from vllm.config.speculative import EagleModelTypes
from vllm.config.speculative import EagleModelTypes, NgramGPUTypes
from vllm.logger import enable_trace_function_call, init_logger
from vllm.transformers_utils.runai_utils import is_runai_obj_uri
from vllm.utils import random_uuid
@@ -378,10 +378,12 @@ def __post_init__(self):
        # Currently, async scheduling only supports EAGLE/MTP and
        # ngram_gpu speculative decoding.
        if self.speculative_config is not None:
            if self.speculative_config.method not in get_args(EagleModelTypes):
            if self.speculative_config.method not in get_args(
                EagleModelTypes
            ) and self.speculative_config.method not in get_args(NgramGPUTypes):
                raise ValueError(
                    "Currently, async scheduling is only supported "
                    "with EAGLE/MTP kind of speculative decoding"
                    "with EAGLE/MTP or ngram_gpu speculative decoding"
                )
            if self.speculative_config.disable_padded_drafter_batch:
                raise ValueError(
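The membership checks above work because `typing.get_args` sees nested `Literal` aliases flattened into one tuple of strings. A short standard-library demo (alias contents abbreviated; the real `MTPModelTypes` lists more model types):

```python
from typing import Literal, get_args

MTPModelTypes = Literal["deepseek_mtp", "pangu_ultra_moe_mtp"]  # abbreviated
EagleModelTypes = Literal["eagle", "eagle3", MTPModelTypes]
NgramGPUTypes = Literal["ngram_gpu"]

# Nested Literals flatten at construction, so get_args yields every variant.
print(get_args(EagleModelTypes))
# ('eagle', 'eagle3', 'deepseek_mtp', 'pangu_ultra_moe_mtp')
print("ngram_gpu" in get_args(NgramGPUTypes))  # True
print("ngram" in get_args(EagleModelTypes) + get_args(NgramGPUTypes))  # False
```

Plain `ngram` still fails the check, which is why the error message now names EAGLE/MTP and ngram_gpu specifically.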