diff --git a/tests/v1/e2e/test_async_scheduling.py b/tests/v1/e2e/test_async_scheduling.py index 00d93e1ba0b5..33e562a31096 100644 --- a/tests/v1/e2e/test_async_scheduling.py +++ b/tests/v1/e2e/test_async_scheduling.py @@ -73,7 +73,7 @@ def test_without_spec_decoding( run_tests(monkeypatch, MODEL, test_configs, test_sampling_params) -def test_with_spec_decoding(monkeypatch: pytest.MonkeyPatch): +def test_with_eagle3_spec_decoding(monkeypatch: pytest.MonkeyPatch): """Test consistency and acceptance rates with some different combos of preemption, executor, async scheduling, prefill chunking, spec decoding model length. @@ -106,6 +106,42 @@ def test_with_spec_decoding(monkeypatch: pytest.MonkeyPatch): run_tests(monkeypatch, MTP_MODEL, test_configs, [{}]) +def test_with_ngram_gpu_spec_decoding(monkeypatch: pytest.MonkeyPatch): + """Test ngram_gpu speculative decoding with different configurations. + + This test specifically validates ngram_gpu behavior with various: + - Number of speculative tokens (2-6) + - Prompt lookup window sizes (min/max) + - Async scheduling enabled (as in production) + - Different executors and chunking settings + """ + + # Variant with larger speculation window + ngram_gpu_config = { + "method": "ngram_gpu", + "num_speculative_tokens": 3, + "prompt_lookup_max": 3, + "prompt_lookup_min": 2, + } + + # Test configurations covering various scenarios + # test_preemption, executor, async_scheduling, + # spec_config, test_prefill_chunking + test_configs = [ + (False, "mp", False, None, False), + (False, "mp", False, ngram_gpu_config, False), + (True, "mp", False, ngram_gpu_config, True), + (False, "mp", True, ngram_gpu_config, False), + (True, "mp", True, ngram_gpu_config, False), + (True, "uni", True, ngram_gpu_config, False), + (True, "mp", True, ngram_gpu_config, True), + ] + + # Use MODEL (Qwen) for ngram_gpu tests as it's lighter weight + # and ngram_gpu doesn't require a specific draft model + run_tests(monkeypatch, MODEL, test_configs, [{}]) + + @dynamo_config.patch(cache_size_limit=16) def run_tests( monkeypatch: pytest.MonkeyPatch, @@ -217,18 +253,19 @@ def run_test( else dict(gpu_memory_utilization=0.9) ) spec_mml = (spec_config or {}).get("max_model_len") + spec_method = (spec_config or {}).get("method", "none") test_config = ( f"executor={executor}, preemption={test_preemption}, " f"async_sched={async_scheduling}, " f"chunk_prefill={test_prefill_chunking}, " - f"spec_decoding={spec_decoding}, spec_mml={spec_mml}" + f"spec_decoding={spec_decoding}, spec_method={spec_method}, spec_mml={spec_mml}" ) print("-" * 80) print(f"---- TESTING {test_str}: {test_config}") print("-" * 80) with VllmRunner( model, - max_model_len=512, + max_model_len=4096, enable_chunked_prefill=test_prefill_chunking, # Force prefill chunking max_num_batched_tokens=48 if test_prefill_chunking else None, diff --git a/vllm/config/speculative.py b/vllm/config/speculative.py index a0c65b6049e1..d924e9f1c991 100644 --- a/vllm/config/speculative.py +++ b/vllm/config/speculative.py @@ -40,6 +40,7 @@ "pangu_ultra_moe_mtp", ] EagleModelTypes = Literal["eagle", "eagle3", MTPModelTypes] +NgramGPUTypes = Literal["ngram_gpu"] SpeculativeMethod = Literal[ "ngram", "medusa", @@ -47,6 +48,7 @@ "draft_model", "suffix", EagleModelTypes, + NgramGPUTypes, ] @@ -260,6 +262,8 @@ def __post_init__(self): self.quantization = self.target_model_config.quantization elif self.method in ("ngram", "[ngram]"): self.model = "ngram" + elif self.method == "ngram_gpu": + self.model = "ngram_gpu" elif self.method == "suffix": 
self.model = "suffix" else: @@ -274,9 +278,10 @@ def __post_init__(self): ): self.method = "ngram" - if self.method in ("ngram", "[ngram]"): + if self.method in ("ngram", "[ngram]", "ngram_gpu"): # Unified to "ngram" internally - self.method = "ngram" + if self.method in ("ngram", "[ngram]"): + self.method = "ngram" # Set default values if not provided if self.prompt_lookup_min is None and self.prompt_lookup_max is None: # TODO(woosuk): Tune these values. They are arbitrarily chosen. diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py index d64e315b4fe3..3c883079a6b3 100644 --- a/vllm/config/vllm.py +++ b/vllm/config/vllm.py @@ -21,7 +21,7 @@ from pydantic.dataclasses import dataclass import vllm.envs as envs -from vllm.config.speculative import EagleModelTypes +from vllm.config.speculative import EagleModelTypes, NgramGPUTypes from vllm.logger import enable_trace_function_call, init_logger from vllm.transformers_utils.runai_utils import is_runai_obj_uri from vllm.utils import random_uuid @@ -378,10 +378,12 @@ def __post_init__(self): # Currently, async scheduling only support eagle speculative # decoding. if self.speculative_config is not None: - if self.speculative_config.method not in get_args(EagleModelTypes): + if self.speculative_config.method not in get_args( + EagleModelTypes + ) and self.speculative_config.method not in get_args(NgramGPUTypes): raise ValueError( "Currently, async scheduling is only supported " - "with EAGLE/MTP kind of speculative decoding" + "with EAGLE/MTP/NGram GPU kind of speculative decoding" ) if self.speculative_config.disable_padded_drafter_batch: raise ValueError( diff --git a/vllm/v1/spec_decode/ngram_proposer_gpu.py b/vllm/v1/spec_decode/ngram_proposer_gpu.py new file mode 100644 index 000000000000..47f880e10c7a --- /dev/null +++ b/vllm/v1/spec_decode/ngram_proposer_gpu.py @@ -0,0 +1,475 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +GPU-accelerated N-gram proposer using fully async PyTorch tensor operations. + +This version uses a fully vectorized approach with unfold and argmax for +finding the first match across all sequences in parallel. +""" +import torch +from torch import nn +import numpy as np + +from vllm.compilation.decorators import support_torch_compile +from vllm.config import CompilationConfig, CompilationMode, CUDAGraphMode, VllmConfig +from vllm.utils.platform_utils import is_pin_memory_available +from vllm.v1.attention.backends.utils import ( + CommonAttentionMetadata, +) +from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch + +from vllm.config import set_current_vllm_config +from vllm.forward_context import set_forward_context +from vllm.v1.utils import CpuGpuBuffer + +@support_torch_compile( + dynamic_arg_dims={ + "num_tokens_no_spec": 0, # batch dimension is dynamic + "token_ids_gpu": [0, 1], # both batch and sequence length are dynamic + "sampled_flags": 0, # batch dimension is dynamic + "valid_mask": 0, # batch dimension is dynamic + } +) +class NgramGPUKernel(nn.Module): + """ + GPU-accelerated N-gram proposer using fully async tensor operations. + + Interface: All inputs are GPU tensors (no lists, no numpy arrays) + + PERFORMANCE OPTIMIZATION WITH TORCH.COMPILE: + + 1. Tensor Allocation Strategy: + - DO: Allocate tensors inside forward() - torch.compile will optimize this + - DON'T: Pre-allocate buffers as class attributes - breaks compilation + - WHY: torch.compile fuses allocations into the compiled graph for efficiency + + 2. 
Dynamic Shapes: + - Batch size (dim 0) and sequence length (dim 1) are marked as dynamic + - torch.compile generates specialized kernels for different shapes + - The first call with a new shape will trigger recompilation (cached) + + 3. Graph Compilation: + - Uses fullgraph=True mode for maximum optimization + - All operations are tensor-based (no Python loops or conditionals) + - The entire forward pass is compiled into a single CUDA graph + + 4. Memory Efficiency: + - torch.compile's memory planning optimizes temporary allocations + - Fusion of operations reduces memory bandwidth requirements + - No manual memory management needed - compiler handles it + """ + + def __init__( + self, vllm_config: VllmConfig, prefix: str = "", device: torch.device = "cuda" + ): + super().__init__() + + assert vllm_config.speculative_config is not None + assert vllm_config.speculative_config.prompt_lookup_min is not None + assert vllm_config.speculative_config.prompt_lookup_max is not None + + self.min_n = vllm_config.speculative_config.prompt_lookup_min + self.max_n = vllm_config.speculative_config.prompt_lookup_max + self.k = vllm_config.speculative_config.num_speculative_tokens + self.max_model_len = vllm_config.model_config.max_model_len + self.max_num_seqs = vllm_config.scheduler_config.max_num_seqs + self.vocab_size = vllm_config.model_config.get_vocab_size() + self.device = device + + def _find_first_and_extract_all_n_parallel( + self, + data: torch.Tensor, + seq_lengths: torch.Tensor, + min_pattern_len: int, + max_pattern_len: int, + result_len: int, + ) -> torch.Tensor: + """ + Process all pattern lengths in parallel, selecting the longest match. + Completely free of data-dependent control flow, suitable for + torch.compile optimization. + """ + batch_size = data.shape[0] + device = data.device + max_seq_len = data.shape[1] + num_patterns = max_pattern_len - min_pattern_len + 1 + + all_windows = data.unfold(1, max_pattern_len, 1) # [B, num_windows, max_n] + num_windows = all_windows.shape[1] + window_starts = torch.arange(num_windows, device=device) + + all_first_matches = torch.full( + (batch_size, num_patterns), -1, dtype=torch.long, device=device + ) + + for i, pattern_len in enumerate(range(min_pattern_len, max_pattern_len + 1)): + offset = max_pattern_len - pattern_len + + # Extract pattern from the end of each sequence + pattern_starts = seq_lengths - pattern_len + pattern_indices = pattern_starts.unsqueeze(1) + torch.arange( + pattern_len, device=device + ) + patterns = torch.gather(data, 1, pattern_indices.clamp(min=0)) + + # Slice windows and perform matching + current_windows = all_windows[..., offset:] + matches = (current_windows == patterns.unsqueeze(1)).all(dim=-1) + + # Validity check: ensure enough space for result extraction + max_valid_start = seq_lengths - pattern_len - result_len + valid_mask = window_starts <= max_valid_start.unsqueeze(1) + final_matches = matches & valid_mask + + # Find first match + # (if no match, argmax returns 0, but we verify with has_match) + first_indices = torch.argmax(final_matches.int(), dim=1) + has_match = final_matches[torch.arange(batch_size), first_indices] + + # Store valid match positions + all_first_matches[:, i] = torch.where(has_match, first_indices, -1) + + # Select the longest valid match, + # from back to front, prioritizing longer patterns + best_pattern_idx = (all_first_matches >= 0).int().flip(dims=[1]).argmax(dim=1) + best_pattern_idx = num_patterns - 1 - best_pattern_idx # Flip back + + # Extract corresponding results + batch_idx = 
torch.arange(batch_size, device=device) + best_match_pos = all_first_matches[batch_idx, best_pattern_idx] + + # Handle matched cases - completely avoid data-dependent branching + has_any_match = best_match_pos >= 0 + + # Calculate result start positions, invalid positions will be + # clamped to valid range. Since all windows have size max_pattern_len, + # and patterns are matched at the END of windows (due to offset), + # the result starts after the full window + result_starts = torch.where( + has_any_match, + best_match_pos + max_pattern_len, + torch.zeros_like(best_match_pos), + ) + + # Create gather indices + result_indices = result_starts.unsqueeze(1) + torch.arange( + result_len, device=device + ) + # Ensure indices are within valid range + result_indices = result_indices.clamp(min=0, max=max_seq_len - 1) + + # Always execute gather (even for invalid data) + extracted_sequences = torch.gather(data, 1, result_indices) + + # Use where to zero out invalid results + results = torch.where( + has_any_match.unsqueeze(1), + extracted_sequences, + torch.zeros_like(extracted_sequences), + ) + + return results + + def forward( + self, + num_tokens_no_spec: torch.Tensor, # [batch_size] on GPU + token_ids_gpu: torch.Tensor, # [batch_size, max_len] on GPU + sampled_flags: torch.Tensor, # [batch_size] bool on GPU + valid_mask: torch.Tensor, # [batch_size] bool on GPU + ) -> torch.Tensor: + """ + Forward pass for N-gram proposal using GPU tensor operations. + + This is the core computation method that will be compiled by torch.compile + via the @support_torch_compile decorator. + + Args: + num_tokens_no_spec: Number of tokens for each sequence [batch_size] + token_ids_gpu: Token IDs [batch_size, max_len] + sampled_flags: Whether each sequence has sampled tokens [batch_size] + valid_mask: Whether each sequence is valid for spec decode [batch_size] + + Returns: + draft_tokens: [batch_size, k] on GPU + """ + assert token_ids_gpu.device == self.device + assert num_tokens_no_spec.device == self.device + assert sampled_flags.device == self.device + assert valid_mask.device == self.device + + # Compute combined mask for valid sequences + combined_mask = ( + sampled_flags + & valid_mask + & (num_tokens_no_spec < self.max_model_len) + & (num_tokens_no_spec >= self.min_n) + ) + + batch_size = token_ids_gpu.size(0) + device = token_ids_gpu.device + + # Initialize output tensor - torch.compile will optimize this allocation + # NOTE(patchy): Do NOT pre-allocate this as a buffer + # it would break torch.compile + draft_tokens = torch.zeros( + (batch_size, self.k), dtype=torch.int32, device=device + ) + + results = self._find_first_and_extract_all_n_parallel( + token_ids_gpu, + num_tokens_no_spec, + min_pattern_len=self.min_n, + max_pattern_len=self.max_n, + result_len=self.k, + ) + + # Apply combined mask to results + draft_tokens = torch.where(combined_mask.unsqueeze(1), results, draft_tokens) + + return draft_tokens + + def load_model(self, *args, **kwargs): + """No model to load for N-gram proposer.""" + pass + +class NgramProposerGPU: + def __init__(self, vllm_config: VllmConfig, device: torch.device, runner=None): + assert vllm_config.speculative_config is not None + assert vllm_config.speculative_config.prompt_lookup_min is not None + assert vllm_config.speculative_config.prompt_lookup_max is not None + + compilation_config = CompilationConfig( + level=3, + custom_ops=["none"], + splitting_ops=[], + compile_sizes=[], + inductor_compile_config={ + "enable_auto_functionalized_v2": False, + "max_autotune": 
True, + "aggressive_fusion": True, + "triton.autotune_pointwise": True, + "coordinate_descent_tuning": True, + "use_mixed_mm": False, + }, + use_cudagraph=False, + ) + + self.vllm_config = VllmConfig( + compilation_config=compilation_config + ) + + self.min_n = vllm_config.speculative_config.prompt_lookup_min + self.max_n = vllm_config.speculative_config.prompt_lookup_max + self.k = vllm_config.speculative_config.num_speculative_tokens + self.max_model_len = vllm_config.model_config.max_model_len + self.max_num_seqs = vllm_config.scheduler_config.max_num_seqs + self.vocab_size = vllm_config.model_config.get_vocab_size() + self.device = device + + with set_current_vllm_config(self.vllm_config, check_compile=False): + self.kernel = NgramGPUKernel(vllm_config=vllm_config, prefix="ngram_gpu_kernel", device=device) + self.device = device + self.kernel.to(device) + self.kernel.eval() + max_batch_size = vllm_config.scheduler_config.max_num_seqs + self.backup_next_token_ids = CpuGpuBuffer( + max_batch_size, + dtype=torch.int32, + pin_memory=is_pin_memory_available(), + device=device, + with_numpy=True, + ) + + self._dummy_run() + + def _dummy_run(self): + with set_current_vllm_config(self.vllm_config, check_compile=False): + token_ids, num_tokens, sampled_flags, valid_mask = self._generate_dummy_data( + batch_size=self.max_num_seqs, + max_seq_len=min(self.max_model_len, 1024), + vocab_size=self.vocab_size, + pattern_len=self.k, + repetition_rate=0.5, + device=self.device + ) + + for _ in range(3): + with set_forward_context(None, self.vllm_config): + output = self.kernel(num_tokens, token_ids, sampled_flags, valid_mask) + + def _generate_dummy_data( + self, + batch_size: int, + max_seq_len: int, + vocab_size: int = 152064, + pattern_len: int = 3, + repetition_rate: float = 0.5, + device: str = "cuda", + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Generate random test data with n-gram repetitions. 
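+
+        The repetitions are injected by copying the last ``pattern_len``
+        tokens of each sequence to random earlier positions, mimicking the
+        repeated n-grams the proposer searches for at decode time.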
+ + Args: + batch_size: Number of sequences in the batch + max_seq_len: Maximum sequence length + vocab_size: Vocabulary size for random token generation + pattern_len: Length of patterns to inject for matching + repetition_rate: Rate of n-gram repetitions to inject + device: Device to place tensors on + + Returns: + token_ids: [batch_size, max_seq_len] tensor + num_tokens: [batch_size] tensor + sampled_flags: [batch_size] bool tensor + valid_mask: [batch_size] bool tensor + """ + # Generate random token IDs + token_ids = torch.randint( + 0, vocab_size, (batch_size, max_seq_len), + dtype=torch.int32, device=device + ) + + # Generate random sequence lengths + min_len = max(pattern_len * 2 + 3, max_seq_len // 2) + num_tokens = torch.randint( + min_len, max_seq_len, (batch_size,), + dtype=torch.int32, device=device + ) + + # Inject n-gram repetitions using the tail pattern of each sequence + for i in range(batch_size): + seq_len = num_tokens[i].item() + if seq_len > pattern_len * 2: + # Pattern is the last pattern_len tokens of the valid sequence + src_pos = seq_len - pattern_len + num_reps = int(seq_len * repetition_rate / pattern_len) + for _ in range(num_reps): + # Place the copied tail pattern somewhere before the tail + tgt_pos = torch.randint(0, seq_len - pattern_len, (1,)).item() + if tgt_pos == src_pos: + continue + + token_ids[i, tgt_pos:tgt_pos + pattern_len] = \ + token_ids[i, src_pos:src_pos + pattern_len].clone() + + # All sequences have sampled tokens and are valid + sampled_flags = torch.ones(batch_size, dtype=torch.bool, device=device) + valid_mask = torch.ones(batch_size, dtype=torch.bool, device=device) + + return token_ids, num_tokens, sampled_flags, valid_mask + + def propose( + self, + num_tokens_no_spec: torch.Tensor, # [batch_size] on GPU + token_ids_gpu: torch.Tensor, # [batch_size, max_len] on GPU + sampled_flags: torch.Tensor, # [batch_size] bool on GPU + valid_mask: torch.Tensor, # [batch_size] bool on GPU + ) -> torch.Tensor: + with set_current_vllm_config(self.vllm_config, check_compile=False): + with set_forward_context(None, self.vllm_config): + return self.kernel(num_tokens_no_spec, token_ids_gpu, sampled_flags, valid_mask) + + def prepare_next_token_ids_cpu( + self, + sampled_token_ids: list[np.ndarray], + requests: dict[str, CachedRequestState], + gpu_input_batch: InputBatch, + num_scheduled_tokens: dict[str, int], + ) -> torch.Tensor: + """ + This function is used to prepare the inputs for speculative decoding. + It calculates the next token ids for each request based on the sampled + token ids from the CPU. If a request has no sampled token ids (e.g., + during the initial decoding steps), it falls back to using the request + state to get the next token id. + """ + req_ids = gpu_input_batch.req_ids + next_token_ids: list[int] = [] + for i, token_ids in enumerate(sampled_token_ids): + if token_ids.shape[0] > 0: + # Common case. + next_token_id = token_ids[-1] + else: + # Partial prefill (rare case). + # Get the next token id from the request state. 
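+                # No token was sampled for this request in this step, so
+                # the "next" token is the one already known at position
+                # num_computed_tokens + num_scheduled_tokens of the
+                # request's prompt/output history.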
+ req_id = req_ids[i] + req_state = requests[req_id] + seq_len = req_state.num_computed_tokens + num_scheduled_tokens[req_id] + next_token_id = req_state.get_token_id(seq_len) + next_token_ids.append(next_token_id) + return torch.tensor(next_token_ids, dtype=torch.int32, device=self.device) + + def prepare_next_token_ids_padded( + self, + common_attn_metadata: CommonAttentionMetadata, + sampled_token_ids: torch.Tensor, + requests: dict[str, CachedRequestState], + gpu_input_batch: InputBatch, + discard_request_indices: torch.Tensor, + num_discarded_requests: int, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + This function is used to prepare the inputs for speculative decoding. + It calculates the next token ids and the number of valid sampled tokens + for each request, considering the "discarded" requests whose next token + is not sampled and comes from `request.get_token_id()` instead. + It also accounts for the rejected tokens in `sampled_token_ids`. + This function must use device functions to operate on the inputs, and + should not introduce any blocking CPU-GPU synchronization. + """ + # TODO(Ben): Combine this into a custom fused kernel + # Precompute get_token_id for when there is no valid next token + num_reqs = gpu_input_batch.num_reqs + # Batch convert seq_lens to avoid multiple .item() calls + seq_lens_list = common_attn_metadata.seq_lens_cpu[:num_reqs].tolist() + + # Now use the pre-converted list to avoid .item() calls in the loop + self.backup_next_token_ids.np[:num_reqs] = np.array( + [ + requests[gpu_input_batch.req_ids[i]].get_token_id(seq_lens_list[i]) + for i in range(num_reqs) + ] + ) + self.backup_next_token_ids.copy_to_gpu(num_reqs) + + # Mask out the sampled tokens indices that should not be sampled. + discard_sampled_tokens_req_indices = discard_request_indices[ + :num_discarded_requests + ] + + valid_sampled_token_ids_gpu = sampled_token_ids.clone() + valid_sampled_token_ids_gpu.index_fill_( + 0, discard_sampled_tokens_req_indices, -1 + ) + + # Generate a mask for all valid tokens within those requests + valid_mask = (valid_sampled_token_ids_gpu != -1) & ( + valid_sampled_token_ids_gpu < gpu_input_batch.vocab_size + ) + + # Count the number of valid tokens in each request + valid_sampled_tokens_count = valid_mask.sum(dim=1) + + # Get the rightmost valid index per row + last_valid_indices = valid_sampled_tokens_count - 1 + last_valid_indices_safe = torch.clamp(last_valid_indices, min=0) + + # Get last valid token from each row + # (assume undefined state where there is no valid token) + selected_tokens = torch.gather( + valid_sampled_token_ids_gpu, 1, last_valid_indices_safe.unsqueeze(1) + ).squeeze(1) + + # Use last token if valid, pre-computed backup if not + batch_size = valid_sampled_token_ids_gpu.shape[0] + next_token_ids = torch.where( + last_valid_indices != -1, + selected_tokens, + self.backup_next_token_ids.gpu[:batch_size], + ) + + return next_token_ids, valid_sampled_tokens_count, valid_sampled_token_ids_gpu + + def load_model(self, *args, **kwargs): + with set_current_vllm_config(self.vllm_config, check_compile=False): + self.kernel.load_model(*args, **kwargs) \ No newline at end of file diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py index 7b4bc1d2a224..160a4372589f 100644 --- a/vllm/v1/worker/gpu_input_batch.py +++ b/vllm/v1/worker/gpu_input_batch.py @@ -112,6 +112,9 @@ def __init__( pin_memory=False, ) self.token_ids_cpu = self.token_ids_cpu_tensor.numpy() + self.token_ids_gpu_tensor = torch.zeros( + 
max_num_reqs, max_model_len, dtype=torch.int32, device=device + ) self.is_token_ids_tensor = torch.zeros( (max_num_reqs, max_model_len), device="cpu", dtype=bool, pin_memory=False ) @@ -122,6 +125,9 @@ def __init__( self.req_prompt_embeds: dict[int, torch.Tensor] = {} self.num_tokens = np.zeros(max_num_reqs, dtype=np.int32) self.num_tokens_no_spec = np.zeros(max_num_reqs, dtype=np.int32) + self.num_tokens_no_spec_gpu = torch.zeros( + max_num_reqs, dtype=torch.int32, device=device + ) self.num_prompt_tokens = np.zeros(max_num_reqs, dtype=np.int32) self.num_computed_tokens_cpu_tensor = torch.zeros( (max_num_reqs,), diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 4c65a5e9b029..ce90aded2785 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -135,6 +135,7 @@ from vllm.v1.spec_decode.medusa import MedusaProposer from vllm.v1.spec_decode.metadata import SpecDecodeMetadata from vllm.v1.spec_decode.ngram_proposer import NgramProposer +from vllm.v1.spec_decode.ngram_proposer_gpu import NgramProposerGPU from vllm.v1.spec_decode.suffix_decoding import SuffixDecodingProposer from vllm.v1.structured_output.utils import apply_grammar_bitmask from vllm.v1.utils import CpuGpuBuffer, record_function_or_nullcontext @@ -366,10 +367,16 @@ def __init__( # layers in the draft model. if self.speculative_config and get_pp_group().is_last_rank: self.drafter: ( - NgramProposer | SuffixDecodingProposer | EagleProposer | MedusaProposer + NgramProposer + | NgramProposerGPU + | SuffixDecodingProposer + | EagleProposer + | MedusaProposer ) if self.speculative_config.method == "ngram": self.drafter = NgramProposer(self.vllm_config) + elif self.speculative_config.method == "ngram_gpu": + self.drafter = NgramProposerGPU(self.vllm_config, self.device, self) elif self.speculative_config.method == "suffix": self.drafter = SuffixDecodingProposer(self.vllm_config) elif self.speculative_config.use_eagle(): @@ -959,6 +966,146 @@ def _update_states_after_model_execute( for i, num_tokens in enumerate(num_accepted_tokens): self.input_batch.num_accepted_tokens_cpu[i] = num_tokens + def _update_ngram_gpu_tensors(self, scheduler_output: "SchedulerOutput") -> None: + """Incrementally update token_ids_gpu_tensor and num_tokens_no_spec_gpu + for ngram GPU proposer to avoid redundant CPU-GPU transfers. + + This follows a similar pattern to _prepare_input_ids for efficient + batch updates when requests change between iterations. 
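+
+        Cases handled below:
+          * No previous batch: every row is initialized from the CPU-side
+            CachedRequestState.
+          * Identical batch (same request ids at the same indices): no-op.
+          * Same requests reordered, or a subset after some finished: rows
+            are gathered from their previous GPU positions and freed rows
+            are zeroed.
+          * New (or resumed) requests: carried-over rows are moved first,
+            then the new rows are copied from the CPU-side request state.
+
+        For example (hypothetical request ids): if request "B" moves from
+        row 2 to row 0 and request "C" is newly scheduled at row 1, row 2 of
+        the previous token tensors is gathered into row 0 and row 1 is
+        filled from "C"'s CachedRequestState.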
+ """ + prev_req_id_to_index = self.input_batch.prev_req_id_to_index + curr_req_id_to_index = self.input_batch.req_id_to_index + + # If no previous batch or batch is empty, initialize all from scratch + if prev_req_id_to_index is None or not curr_req_id_to_index: + if curr_req_id_to_index: + # Initialize all token_ids from requests + for req_id, idx in curr_req_id_to_index.items(): + req_state = self.requests[req_id] + # Get prompt_token_ids + output_token_ids + prompt_token_ids = ( + req_state.prompt_token_ids + if req_state.prompt_token_ids is not None + else [] + ) + all_token_ids = prompt_token_ids + req_state.output_token_ids + num_tokens = len(all_token_ids) + # Copy to GPU tensor + self.input_batch.token_ids_gpu_tensor[idx, :num_tokens].copy_( + torch.tensor( + all_token_ids, dtype=torch.int32, device=self.device + ), + non_blocking=True, + ) + self.input_batch.num_tokens_no_spec_gpu[idx] = num_tokens + return + + # Case 1: Batch hasn't changed at all (same req_ids and same indices) + if prev_req_id_to_index == curr_req_id_to_index: + return + + # Case 2, 3 & 4: Batch has changed - analyze the changes + common_req_indices = [] + prev_indices = [] + new_req_indices = [] + indices_match = True + + for req_id, curr_idx in curr_req_id_to_index.items(): + if req_id in prev_req_id_to_index: + prev_idx = prev_req_id_to_index[req_id] + common_req_indices.append(curr_idx) + prev_indices.append(prev_idx) + indices_match &= prev_idx == curr_idx + else: + new_req_indices.append((req_id, curr_idx)) + + # Case 2: Only common requests (subset or same set), may need reordering or clearing + if not new_req_indices: + # If indices haven't changed and it's the exact same set, already handled by Case 1 + # So here we either have reordering or a subset (some requests finished) + if not indices_match or len(common_req_indices) < len(prev_req_id_to_index): + # Need to reorder or clear finished requests + curr_indices_tensor = torch.tensor( + common_req_indices, dtype=torch.long, device=self.device + ) + prev_indices_tensor = torch.tensor( + prev_indices, dtype=torch.long, device=self.device + ) + + # Create temporary tensors for scatter operation (zeros will clear unused positions) + temp_token_ids = torch.zeros_like(self.input_batch.token_ids_gpu_tensor) + temp_num_tokens = torch.zeros_like( + self.input_batch.num_tokens_no_spec_gpu + ) + + # Scatter token_ids - copy entire rows (already up-to-date from prepare_next_token_ids_padded) + temp_token_ids[curr_indices_tensor] = ( + self.input_batch.token_ids_gpu_tensor[prev_indices_tensor] + ) + temp_num_tokens[curr_indices_tensor] = ( + self.input_batch.num_tokens_no_spec_gpu[prev_indices_tensor] + ) + + # Update in-place + self.input_batch.token_ids_gpu_tensor.copy_( + temp_token_ids, non_blocking=True + ) + self.input_batch.num_tokens_no_spec_gpu.copy_( + temp_num_tokens, non_blocking=True + ) + return + + # Case 3: Has new requests (or preempted requests that are resuming) + if new_req_indices: + # First handle common requests with scatter if any + if common_req_indices: + curr_indices_tensor = torch.tensor( + common_req_indices, dtype=torch.long, device=self.device + ) + prev_indices_tensor = torch.tensor( + prev_indices, dtype=torch.long, device=self.device + ) + + # Create temporary tensors for vectorized update + temp_token_ids = torch.zeros_like(self.input_batch.token_ids_gpu_tensor) + temp_num_tokens = torch.zeros_like( + self.input_batch.num_tokens_no_spec_gpu + ) + + # Scatter existing requests to new positions + 
temp_token_ids[curr_indices_tensor] = ( + self.input_batch.token_ids_gpu_tensor[prev_indices_tensor] + ) + temp_num_tokens[curr_indices_tensor] = ( + self.input_batch.num_tokens_no_spec_gpu[prev_indices_tensor] + ) + + # Copy back to persistent tensors + self.input_batch.token_ids_gpu_tensor.copy_( + temp_token_ids, non_blocking=True + ) + self.input_batch.num_tokens_no_spec_gpu.copy_( + temp_num_tokens, non_blocking=True + ) + + # Then handle new requests + for req_id, curr_idx in new_req_indices: + req_state = self.requests[req_id] + # Get prompt_token_ids + output_token_ids + prompt_token_ids = ( + req_state.prompt_token_ids + if req_state.prompt_token_ids is not None + else [] + ) + all_token_ids = prompt_token_ids + req_state.output_token_ids + num_tokens = len(all_token_ids) + # Copy to GPU tensor with non-blocking + self.input_batch.token_ids_gpu_tensor[curr_idx, :num_tokens].copy_( + torch.tensor(all_token_ids, dtype=torch.int32, device=self.device), + non_blocking=True, + ) + self.input_batch.num_tokens_no_spec_gpu[curr_idx] = num_tokens + def _init_mrope_positions(self, req_state: CachedRequestState): model = self.get_model() assert supports_mrope(model), "M-RoPE support is not implemented." @@ -1345,6 +1492,10 @@ def _prepare_inputs( cu_num_tokens, ) + # For ngram GPU proposer: update token_ids and num_tokens incrementally + if self.speculative_config and self.speculative_config.method == "ngram_gpu": + self._update_ngram_gpu_tensors(scheduler_output) + if self.uses_mrope: # Only relevant for models using M-RoPE (e.g, Qwen2-VL) self.mrope_positions.gpu[:, :total_num_scheduled_tokens].copy_( @@ -2950,6 +3101,11 @@ def propose_draft_token_ids(sampled_token_ids): and spec_config.use_eagle() and not spec_config.disable_padded_drafter_batch ) + use_padded_batch_for_ngram = ( + self.speculative_config + and self.speculative_config.method == "ngram_gpu" + and not self.speculative_config.disable_padded_drafter_batch + ) effective_drafter_max_model_len = self.max_model_len if effective_drafter_max_model_len is None: effective_drafter_max_model_len = self.model_config.max_model_len @@ -2989,6 +3145,27 @@ def propose_draft_token_ids(sampled_token_ids): next_token_ids, valid_sampled_tokens_count ) + if use_padded_batch_for_ngram: + sampled_token_ids = sampler_output.sampled_token_ids + if input_fits_in_drafter: + # Fast path: GPU-only operation when input fits in drafter + propose_draft_token_ids(sampled_token_ids) + elif self.valid_sampled_token_count_event is not None: + # Slow path: prepare tokens with async transfer + next_token_ids, valid_sampled_tokens_count, _ = ( + self.drafter.prepare_next_token_ids_padded( + spec_decode_common_attn_metadata, + sampled_token_ids, + self.requests, + self.input_batch, + self.discard_request_indices.gpu, + self.num_discarded_requests, + ) + ) + self._copy_valid_sampled_token_count( + next_token_ids, valid_sampled_tokens_count + ) + with record_function_or_nullcontext("gpu_model_runner: bookkeep"): ( num_nans_in_logits, @@ -3009,7 +3186,7 @@ def propose_draft_token_ids(sampled_token_ids): if ( self.speculative_config - and not use_padded_batch_for_eagle + and not (use_padded_batch_for_eagle or use_padded_batch_for_ngram) and input_fits_in_drafter ): # ngram and other speculative decoding methods use the sampled @@ -3114,15 +3291,93 @@ def propose_draft_token_ids( num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens spec_config = self.speculative_config assert spec_config is not None - if spec_config.method == "ngram": - assert 
isinstance(sampled_token_ids, list) - assert isinstance(self.drafter, NgramProposer) + if self.speculative_config.method == "ngram": + # TODO:(patchy) NGram GPU proposal + if isinstance(self.drafter, NgramProposer): + assert isinstance(sampled_token_ids, list), ( + "sampled_token_ids should be a python list whenngram is used." + ) + draft_token_ids = self.drafter.propose( + sampled_token_ids, + self.input_batch.req_ids, + self.input_batch.num_tokens_no_spec, + self.input_batch.token_ids_cpu, + self.input_batch.spec_decode_unsupported_reqs, + ) + elif self.speculative_config.method == "ngram_gpu": + # GPU-accelerated ngram proposer + assert isinstance(self.drafter, NgramProposerGPU) + assert isinstance(sampled_token_ids, torch.Tensor), ( + "sampled_token_ids should be a torch.Tensor for ngram_gpu" + ) + next_token_ids, valid_sampled_tokens_count, valid_sampled_token_ids_gpu = ( + self.drafter.prepare_next_token_ids_padded( + common_attn_metadata, + sampled_token_ids, + self.requests, + self.input_batch, + self.discard_request_indices.gpu, + self.num_discarded_requests, + ) + ) + self._copy_valid_sampled_token_count( + next_token_ids, valid_sampled_tokens_count + ) + + batch_size = next_token_ids.shape[0] + max_new_tokens = valid_sampled_token_ids_gpu.shape[1] # num_spec_tokens + 1 + + current_lens = self.input_batch.num_tokens_no_spec_gpu[:batch_size] + offsets = torch.arange(max_new_tokens, device=self.device) + + write_positions = current_lens.unsqueeze(1) + offsets.unsqueeze(0) + valid_write_mask = offsets.unsqueeze( + 0 + ) < valid_sampled_tokens_count.unsqueeze(1) + combined_mask = valid_write_mask & (valid_sampled_token_ids_gpu != -1) + + token_ids_slice = self.input_batch.token_ids_gpu_tensor[:batch_size] + write_positions_long = write_positions.long() + existing_values = token_ids_slice.gather(1, write_positions_long) + + tokens_cast = valid_sampled_token_ids_gpu.to(token_ids_slice.dtype) + tokens_to_scatter = torch.where( + combined_mask, + tokens_cast, + existing_values, + ) + token_ids_slice.scatter_(1, write_positions_long, tokens_to_scatter) + + self.input_batch.num_tokens_no_spec_gpu[:batch_size] += ( + valid_sampled_tokens_count + ) + + sampled_flags = valid_sampled_tokens_count > 0 + valid_mask = torch.ones(batch_size, dtype=torch.bool, device=self.device) + + if self.input_batch.spec_decode_unsupported_reqs: + unsupported_ids = torch.tensor( + list(self.input_batch.spec_decode_unsupported_reqs), + dtype=torch.long, + device=self.device, + ) + + batch_req_ids = torch.tensor( + self.input_batch.req_ids[:batch_size], + dtype=torch.long, + device=self.device, + ) + + is_unsupported = ( + batch_req_ids.unsqueeze(1) == unsupported_ids.unsqueeze(0) + ).any(dim=1) + valid_mask = valid_mask & ~is_unsupported + draft_token_ids = self.drafter.propose( - sampled_token_ids, - self.input_batch.req_ids, - self.input_batch.num_tokens_no_spec, - self.input_batch.token_ids_cpu, - self.input_batch.spec_decode_unsupported_reqs, + self.input_batch.num_tokens_no_spec_gpu[:batch_size], + self.input_batch.token_ids_gpu_tensor[:batch_size], + sampled_flags, + valid_mask, ) elif spec_config.method == "suffix": assert isinstance(sampled_token_ids, list)