diff --git a/vllm/v1/worker/gpu/async_utils.py b/vllm/v1/worker/gpu/async_utils.py index e523090aa217..421fb29a7f87 100644 --- a/vllm/v1/worker/gpu/async_utils.py +++ b/vllm/v1/worker/gpu/async_utils.py @@ -2,7 +2,6 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from contextlib import contextmanager -import numpy as np import torch from vllm.v1.outputs import ( @@ -18,7 +17,7 @@ def __init__( self, model_runner_output: ModelRunnerOutput, sampler_output: SamplerOutput, - num_sampled_tokens: np.ndarray, + num_sampled_tokens: torch.Tensor, copy_stream: torch.cuda.Stream, copy_event: torch.cuda.Event, ): @@ -52,6 +51,7 @@ def __init__( ) else: self.logprobs_tensors = None + self.num_sampled_tokens = num_sampled_tokens.to("cpu", non_blocking=True) self.prompt_logprobs_dict: dict[str, LogprobsTensors | None] = {} if self.model_runner_output.prompt_logprobs_dict: for k, v in self.model_runner_output.prompt_logprobs_dict.items(): @@ -63,6 +63,7 @@ def __init__( def get_output(self) -> ModelRunnerOutput: self.copy_event.synchronize() + num_sampled_tokens_np = self.num_sampled_tokens.numpy() # NOTE(woosuk): The following code is to ensure compatibility with # the existing model runner. @@ -71,7 +72,7 @@ def get_output(self) -> ModelRunnerOutput: sampled_token_ids: list[list[int]] = self.sampled_token_ids.tolist() num_reqs = len(sampled_token_ids) for i in range(num_reqs): - del sampled_token_ids[i][self.num_sampled_tokens[i] :] + del sampled_token_ids[i][num_sampled_tokens_np[i] :] self.model_runner_output.sampled_token_ids = sampled_token_ids if self.logprobs_tensors is not None: diff --git a/vllm/v1/worker/gpu/attn_utils.py b/vllm/v1/worker/gpu/attn_utils.py index 222db565dff1..4510a1c5ca1e 100644 --- a/vllm/v1/worker/gpu/attn_utils.py +++ b/vllm/v1/worker/gpu/attn_utils.py @@ -3,6 +3,7 @@ from collections.abc import Sequence from typing import Any, cast +import numpy as np import torch from vllm.attention.backends.abstract import AttentionBackend @@ -145,8 +146,9 @@ def build_attn_metadata( num_reqs: int, num_tokens: int, query_start_loc: CpuGpuBuffer, - seq_lens: CpuGpuBuffer, - num_computed_tokens_cpu: torch.Tensor, + seq_lens: torch.Tensor, + seq_lens_np: np.ndarray, + num_computed_tokens_cpu: torch.Tensor | None, block_tables: Sequence[torch.Tensor], slot_mappings: torch.Tensor, kv_cache_config: KVCacheConfig, @@ -154,9 +156,9 @@ def build_attn_metadata( query_start_loc_gpu = query_start_loc.gpu[: num_reqs + 1] query_start_loc_cpu = query_start_loc.cpu[: num_reqs + 1] max_query_len = int(query_start_loc.np[: num_reqs + 1].max()) - seq_lens_gpu = seq_lens.gpu[:num_reqs] - seq_lens_cpu = seq_lens.cpu[:num_reqs] - max_seq_len = int(seq_lens.np[:num_reqs].max()) + seq_lens = seq_lens[:num_reqs] + seq_lens_cpu = torch.from_numpy(seq_lens_np) + max_seq_len = int(seq_lens_np.max()) attn_metadata: dict[str, Any] = {} kv_cache_groups = kv_cache_config.kv_cache_groups @@ -167,7 +169,7 @@ def build_attn_metadata( common_attn_metadata = CommonAttentionMetadata( query_start_loc=query_start_loc_gpu, query_start_loc_cpu=query_start_loc_cpu, - seq_lens=seq_lens_gpu, + seq_lens=seq_lens, seq_lens_cpu=seq_lens_cpu, max_seq_len=max_seq_len, num_computed_tokens_cpu=num_computed_tokens_cpu, diff --git a/vllm/v1/worker/gpu/block_table.py b/vllm/v1/worker/gpu/block_table.py index ff24e88ede2c..b31e9b179d26 100644 --- a/vllm/v1/worker/gpu/block_table.py +++ b/vllm/v1/worker/gpu/block_table.py @@ -3,10 +3,9 @@ from collections.abc import Iterable import torch -import triton -import 
triton.language as tl from vllm.attention.backends.utils import PAD_SLOT_ID +from vllm.triton_utils import tl, triton from vllm.utils.math_utils import cdiv from vllm.v1.utils import CpuGpuBuffer diff --git a/vllm/v1/worker/gpu/cudagraph_utils.py b/vllm/v1/worker/gpu/cudagraph_utils.py index 31a706475243..4948d5717b95 100644 --- a/vllm/v1/worker/gpu/cudagraph_utils.py +++ b/vllm/v1/worker/gpu/cudagraph_utils.py @@ -97,14 +97,13 @@ def capture_graph( # Prepare dummy inputs. input_ids = input_buffers.input_ids.gpu[:batch_size] - positions = input_buffers.positions.gpu[:batch_size] + positions = input_buffers.positions[:batch_size] input_buffers.query_start_loc.np[: batch_size + 1] = np.arange(batch_size + 1) input_buffers.query_start_loc.np[batch_size:] = batch_size input_buffers.query_start_loc.copy_to_gpu() - input_buffers.seq_lens.np[:batch_size] = self.max_model_len - input_buffers.seq_lens.np[batch_size:] = 0 - input_buffers.seq_lens.copy_to_gpu() + input_buffers.seq_lens[:batch_size] = self.max_model_len + input_buffers.seq_lens[batch_size:] = 0 input_block_tables = [x[:batch_size] for x in block_tables.input_block_tables] slot_mappings = block_tables.slot_mappings[:, :batch_size] @@ -115,6 +114,7 @@ def capture_graph( num_tokens=batch_size, query_start_loc=input_buffers.query_start_loc, seq_lens=input_buffers.seq_lens, + seq_lens_np=np.full(batch_size, self.max_model_len, dtype=np.int32), num_computed_tokens_cpu=None, # FIXME block_tables=input_block_tables, slot_mappings=slot_mappings, diff --git a/vllm/v1/worker/gpu/input_batch.py b/vllm/v1/worker/gpu/input_batch.py index 89f375649146..b671c093113b 100644 --- a/vllm/v1/worker/gpu/input_batch.py +++ b/vllm/v1/worker/gpu/input_batch.py @@ -7,9 +7,8 @@ import numba.types as types import numpy as np import torch -import triton -import triton.language as tl +from vllm.triton_utils import tl, triton from vllm.utils import random_uuid from vllm.utils.math_utils import cdiv from vllm.v1.utils import CpuGpuBuffer @@ -33,9 +32,9 @@ def __init__( self.idx_mapping = self._make_buffer(max_num_reqs, dtype=torch.int32) self.input_ids = self._make_buffer(max_num_tokens, dtype=torch.int32) - self.positions = self._make_buffer(max_num_tokens, dtype=torch.int64) + self.positions = torch.zeros(max_num_tokens, dtype=torch.int64, device=device) self.query_start_loc = self._make_buffer(max_num_reqs + 1, dtype=torch.int32) - self.seq_lens = self._make_buffer(max_num_reqs, dtype=torch.int32) + self.seq_lens = torch.zeros(max_num_reqs, dtype=torch.int32, device=device) # Structured outputs. 
self.bitmask_indices = self._make_buffer(max_num_reqs, dtype=torch.int32) @@ -108,13 +107,15 @@ def make_dummy( query_start_loc_np = input_buffers.query_start_loc.np[: num_reqs + 1] query_start_loc = input_buffers.query_start_loc.copy_to_gpu()[: num_reqs + 1] # seq_len equals to query_len - input_buffers.seq_lens.np[:num_reqs] = num_scheduled_tokens - input_buffers.seq_lens.np[num_reqs:] = 0 - seq_lens_np = input_buffers.seq_lens.np[:num_reqs] - seq_lens = input_buffers.seq_lens.copy_to_gpu()[:num_reqs] + seq_lens_np = np.full(num_reqs, num_tokens // num_reqs, dtype=np.int32) + seq_lens_np[-1] += num_tokens % num_reqs + input_buffers.seq_lens[:num_reqs] = num_tokens // num_reqs + input_buffers.seq_lens[num_reqs - 1] += num_tokens % num_reqs + input_buffers.seq_lens[num_reqs:] = 0 + seq_lens = input_buffers.seq_lens[:num_reqs] input_ids = input_buffers.input_ids.copy_to_gpu(num_tokens) - positions = input_buffers.positions.copy_to_gpu(num_tokens) + positions = input_buffers.positions[:num_tokens] # attn_metadata = defaultdict(lambda: None) logits_indices = query_start_loc[1:] - 1 return cls( @@ -142,27 +143,25 @@ def make_dummy( [ types.none( types.int32[:], # idx_mapping - types.int32[:, :], # token_ids - types.int32[:], # num_computed_tokens types.int32[:], # num_scheduled_tokens + types.int32[:, :], # prefill_token_ids + types.int32[:], # num_computed_prefill_tokens + types.int32[:], # prefill_len types.int32[:], # input_ids - types.int64[:], # positions types.int32[:], # query_start_loc - types.int32[:], # seq_lens ) ], nopython=True, cache=True, ) -def _prepare_inputs( +def _prepare_prefill_inputs( idx_mapping: np.ndarray, # batch_idx -> req_idx - token_ids: np.ndarray, # [N, max_model_len] - num_computed_tokens: np.ndarray, # [N] num_scheduled_tokens: np.ndarray, # [B] + prefill_token_ids: np.ndarray, # [N, max_model_len] + num_computed_prefill_tokens: np.ndarray, # [N] + prefill_len: np.ndarray, # [N] input_ids: np.ndarray, # [num_input_tokens] - positions: np.ndarray, # [num_input_tokens] query_start_loc: np.ndarray, # [B + 1] - seq_lens: np.ndarray, # [B] ) -> None: num_reqs = num_scheduled_tokens.shape[0] query_start_loc[0] = 0 @@ -171,62 +170,112 @@ def _prepare_inputs( for i in range(num_reqs): req_idx = idx_mapping[i] query_len = num_scheduled_tokens[i] - start = num_computed_tokens[req_idx] - end = start + query_len - seq_lens[i] = end + + start = num_computed_prefill_tokens[req_idx] + end = min(start + query_len, prefill_len[req_idx]) + n = end - start start_idx = cu_num_tokens - end_idx = start_idx + query_len - input_ids[start_idx:end_idx] = token_ids[req_idx, start:end] - positions[start_idx:end_idx] = np.arange(start, end, dtype=np.int64) + input_ids[start_idx : start_idx + n] = prefill_token_ids[req_idx, start:end] - cu_num_tokens = end_idx + cu_num_tokens = start_idx + query_len query_start_loc[i + 1] = cu_num_tokens # Pad the inputs for CUDA graphs. # Note: pad query_start_loc to be non-decreasing, as kernels # like FlashAttention requires that query_start_loc[num_reqs + 1 :].fill(cu_num_tokens) - # Fill unused with 0 for full cuda graph mode. 
- seq_lens[num_reqs:].fill(0) -def prepare_inputs( +def prepare_prefill_inputs( idx_mapping: np.ndarray, - prefill_token_ids: np.ndarray, - num_computed_tokens: np.ndarray, num_scheduled_tokens: np.ndarray, + total_num_tokens: int, + prefill_token_ids: np.ndarray, + num_computed_prefill_tokens: np.ndarray, + prefill_len: np.ndarray, input_ids: CpuGpuBuffer, - positions: CpuGpuBuffer, query_start_loc: CpuGpuBuffer, - seq_lens: CpuGpuBuffer, - num_tokens: int, ) -> None: - _prepare_inputs( + _prepare_prefill_inputs( idx_mapping, - prefill_token_ids, - num_computed_tokens, num_scheduled_tokens, + prefill_token_ids, + num_computed_prefill_tokens, + prefill_len, input_ids.np, - positions.np, query_start_loc.np, - seq_lens.np, ) - input_ids.copy_to_gpu(num_tokens) - positions.copy_to_gpu(num_tokens) + input_ids.copy_to_gpu(total_num_tokens) # NOTE(woosuk): We should copy the whole query_start_loc and seq_lens # tensors from CPU to GPU, because they may include paddings needed # for full CUDA graph mode. query_start_loc.copy_to_gpu() - seq_lens.copy_to_gpu() - return @triton.jit -def _combine_last_token_ids_kernel( +def _prepare_pos_seq_lens_kernel( + pos_ptr, + seq_lens_ptr, + idx_mapping_ptr, + query_start_loc_ptr, + num_computed_tokens_ptr, + max_num_reqs, + BLOCK_SIZE: tl.constexpr, +): + req_id = tl.program_id(0) + num_reqs = tl.num_programs(0) - 1 + if req_id == num_reqs: + # Pad unused seq_lens as 0 for full CUDA graphs. + for i in tl.range(num_reqs, max_num_reqs, BLOCK_SIZE): + block = i + tl.arange(0, BLOCK_SIZE) + mask = block < max_num_reqs + tl.store(seq_lens_ptr + block, 0, mask=mask) + return + + req_state_idx = tl.load(idx_mapping_ptr + req_id) + num_computed_tokens = tl.load(num_computed_tokens_ptr + req_state_idx) + + start = tl.load(query_start_loc_ptr + req_id) + end = tl.load(query_start_loc_ptr + req_id + 1) + query_len = end - start + + seq_len = num_computed_tokens + query_len + tl.store(seq_lens_ptr + req_id, seq_len) + + for i in tl.range(0, query_len, BLOCK_SIZE): + block = i + tl.arange(0, BLOCK_SIZE) + mask = block < query_len + pos = num_computed_tokens + block + tl.store(pos_ptr + start + block, pos, mask=mask) + + +def prepare_pos_seq_lens( + idx_mapping: torch.Tensor, + query_start_loc: torch.Tensor, + num_computed_tokens: torch.Tensor, + pos: torch.Tensor, + seq_lens: torch.Tensor, +) -> None: + num_reqs = idx_mapping.shape[0] + # NOTE(woosuk): We do +1 because the last thread block is used + # to pad unused seq_lens as 0 for full CUDA graphs. + _prepare_pos_seq_lens_kernel[(num_reqs + 1,)]( + pos, + seq_lens, + idx_mapping, + query_start_loc, + num_computed_tokens, + seq_lens.shape[0], + BLOCK_SIZE=1024, + ) + + +@triton.jit +def _combine_sampled_and_draft_tokens_kernel( input_ids_ptr, idx_mapping_ptr, - last_token_ids_ptr, + last_sampled_tokens_ptr, query_start_loc_ptr, seq_lens_ptr, prefill_len_ptr, @@ -240,26 +289,56 @@ def _combine_last_token_ids_kernel( # Handling prefill tokens. 
return - last_token_id = tl.load(last_token_ids_ptr + req_state_idx) + last_token_id = tl.load(last_sampled_tokens_ptr + req_state_idx) end = tl.load(query_start_loc_ptr + batch_idx + 1) tl.store(input_ids_ptr + end - 1, last_token_id) -def combine_last_token_ids( +def combine_sampled_and_draft_tokens( input_ids: torch.Tensor, idx_mapping: torch.Tensor, - last_token_ids: torch.Tensor, + last_sampled_tokens: torch.Tensor, query_start_loc: torch.Tensor, seq_lens: torch.Tensor, prefill_len: torch.Tensor, ) -> torch.Tensor: num_reqs = seq_lens.shape[0] - _combine_last_token_ids_kernel[(num_reqs,)]( + _combine_sampled_and_draft_tokens_kernel[(num_reqs,)]( input_ids, idx_mapping, - last_token_ids, + last_sampled_tokens, query_start_loc, seq_lens, prefill_len, ) return input_ids + + +@triton.jit +def _update_num_computed_tokens_kernel( + idx_mapping_ptr, + num_computed_tokens_ptr, + query_start_loc_ptr, +): + req_id = tl.program_id(0) + req_state_idx = tl.load(idx_mapping_ptr + req_id) + + start = tl.load(query_start_loc_ptr + req_id) + end = tl.load(query_start_loc_ptr + req_id + 1) + query_len = end - start + + n = tl.load(num_computed_tokens_ptr + req_state_idx) + tl.store(num_computed_tokens_ptr + req_state_idx, n + query_len) + + +def update_num_computed_tokens( + idx_mapping: torch.Tensor, + num_computed_tokens: torch.Tensor, + query_start_loc: torch.Tensor, +) -> None: + num_reqs = idx_mapping.shape[0] + _update_num_computed_tokens_kernel[(num_reqs,)]( + idx_mapping, + num_computed_tokens, + query_start_loc, + ) diff --git a/vllm/v1/worker/gpu/model_runner.py b/vllm/v1/worker/gpu/model_runner.py index 9ca37ff282d8..9306f9b3181f 100644 --- a/vllm/v1/worker/gpu/model_runner.py +++ b/vllm/v1/worker/gpu/model_runner.py @@ -39,8 +39,10 @@ from vllm.v1.worker.gpu.input_batch import ( InputBatch, InputBuffers, - combine_last_token_ids, - prepare_inputs, + combine_sampled_and_draft_tokens, + prepare_pos_seq_lens, + prepare_prefill_inputs, + update_num_computed_tokens, ) from vllm.v1.worker.gpu.sampler import Sampler, compute_prompt_logprobs from vllm.v1.worker.gpu.states import RequestState, SamplingMetadata @@ -196,8 +198,8 @@ def prepare_dummy_attn_metadata(self, input_batch: InputBatch) -> None: slot_mappings = self.block_tables.get_dummy_slot_mappings( input_batch.num_tokens ) - num_computed_tokens_cpu = torch.zeros( - input_batch.num_reqs, dtype=torch.int32, device="cpu" + num_computed_tokens = torch.zeros( + input_batch.num_reqs, dtype=torch.int32, device=self.device ) attn_metadata = build_attn_metadata( attn_metadata_builders=self.attn_metadata_builders, @@ -205,7 +207,8 @@ def prepare_dummy_attn_metadata(self, input_batch: InputBatch) -> None: num_tokens=input_batch.num_tokens, query_start_loc=self.input_buffers.query_start_loc, seq_lens=self.input_buffers.seq_lens, - num_computed_tokens_cpu=num_computed_tokens_cpu, + seq_lens_np=input_batch.seq_lens_np, + num_computed_tokens_cpu=num_computed_tokens, block_tables=block_tables, slot_mappings=slot_mappings, kv_cache_config=self.kv_cache_config, @@ -367,6 +370,9 @@ def update_states(self, scheduler_output: SchedulerOutput) -> None: cu_num_new_blocks[i].append(x + len(block_ids)) new_block_ids[i].extend(block_ids) overwrite.append(True) + # Update the GPU tensors for request states. + if scheduler_output.scheduled_new_reqs: + self.req_states.prefill_len.copy_to_gpu() # Add new blocks for the existing requests. 
cached_reqs = scheduler_output.scheduled_cached_reqs @@ -420,46 +426,59 @@ def prepare_inputs( # Block tables: num_kv_cache_groups x [num_reqs, max_num_blocks] block_tables = self.block_tables.gather_block_tables(idx_mapping) - prepare_inputs( + # Copy prefill tokens from CPU to GPU and get query_start_loc. + prepare_prefill_inputs( idx_mapping_np, - self.req_states.prefill_token_ids, - self.req_states.num_computed_tokens, num_scheduled_tokens, + num_tokens, + self.req_states.prefill_token_ids, + self.req_states.num_computed_prefill_tokens, + self.req_states.prefill_len.np, self.input_buffers.input_ids, - self.input_buffers.positions, self.input_buffers.query_start_loc, - self.input_buffers.seq_lens, - num_tokens, ) - query_start_loc = self.input_buffers.query_start_loc query_start_loc_gpu = query_start_loc.gpu[: num_reqs + 1] query_start_loc_np = query_start_loc.np[: num_reqs + 1] - seq_lens_gpu = self.input_buffers.seq_lens.gpu[:num_reqs] - seq_lens_np = self.input_buffers.seq_lens.np[:num_reqs] - # Some input token ids are directly read from the last sampled tokens. - combine_last_token_ids( + # Prepare positions and seq_lens. + prepare_pos_seq_lens( + idx_mapping, + query_start_loc_gpu, + self.req_states.num_computed_tokens, + self.input_buffers.positions, + self.input_buffers.seq_lens, + ) + seq_lens = self.input_buffers.seq_lens[:num_reqs] + + # Some input token ids are directly read from the last sampled tokens + # and draft tokens. + combine_sampled_and_draft_tokens( self.input_buffers.input_ids.gpu, idx_mapping, self.req_states.last_sampled_tokens, query_start_loc_gpu, - seq_lens_gpu, - self.req_states.prefill_len.copy_to_gpu(), + seq_lens, + self.req_states.prefill_len.gpu, ) # Compute slot mappings: [num_kv_cache_groups, num_tokens] slot_mappings = self.block_tables.compute_slot_mappings( - query_start_loc_gpu, self.input_buffers.positions.gpu[:num_tokens] - ) - - num_computed_tokens_cpu = torch.from_numpy( - self.req_states.num_computed_tokens[idx_mapping_np] + query_start_loc_gpu, self.input_buffers.positions[:num_tokens] ) # Logits indices to sample next token from. logits_indices = query_start_loc_gpu[1:] - 1 + # Get num_computed_tokens. + # HACK(woosuk): Here, we use num_computed_tokens on GPU instead of + # num_computed_tokens_cpu. This works for most cases. + num_computed_tokens = self.req_states.num_computed_tokens[idx_mapping] + # HACK(woosuk): Only GPU has the exact seq_lens because at this point + # CPU does not know how many draft tokens are accepted/rejected in the + # previous step. Therefore, we use max_model_len to be safe. + seq_lens_np = np.full(num_reqs, self.max_model_len, dtype=np.int32) + # Layer name -> attention metadata. 
attn_metadata = build_attn_metadata( attn_metadata_builders=self.attn_metadata_builders, @@ -467,14 +486,15 @@ def prepare_inputs( num_tokens=num_tokens, query_start_loc=self.input_buffers.query_start_loc, seq_lens=self.input_buffers.seq_lens, - num_computed_tokens_cpu=num_computed_tokens_cpu, + seq_lens_np=seq_lens_np, + num_computed_tokens_cpu=num_computed_tokens, block_tables=block_tables, slot_mappings=slot_mappings, kv_cache_config=self.kv_cache_config, ) input_ids = self.input_buffers.input_ids.gpu[:num_tokens_after_padding] - positions = self.input_buffers.positions.gpu[:num_tokens_after_padding] + positions = self.input_buffers.positions[:num_tokens_after_padding] return InputBatch( req_ids=req_ids, num_reqs=num_reqs, @@ -485,7 +505,7 @@ def prepare_inputs( num_tokens_after_padding=num_tokens_after_padding, query_start_loc=query_start_loc_gpu, query_start_loc_np=query_start_loc_np, - seq_lens=seq_lens_gpu, + seq_lens=seq_lens, seq_lens_np=seq_lens_np, input_ids=input_ids, positions=positions, @@ -499,11 +519,12 @@ def sample( input_batch: InputBatch, sampling_metadata: SamplingMetadata, grammar_output: GrammarOutput | None, - ) -> SamplerOutput: + ) -> tuple[SamplerOutput, torch.Tensor]: sample_hidden_states = hidden_states[input_batch.logits_indices] logits = self.model.compute_logits(sample_hidden_states) if grammar_output is not None: # Apply grammar bitmask to the logits in-place. + # TODO(woosuk): Make compatible with spec decoding. with async_barrier(self.structured_outputs_event): apply_grammar_bitmask( logits, @@ -512,8 +533,14 @@ def sample( grammar_output.grammar_bitmask, self.input_buffers, ) + sampler_output = self.sampler(logits, sampling_metadata) - return sampler_output + # Get the number of sampled tokens. + # 0 if chunked-prefilling, 1 if not. + prefill_len = self.req_states.prefill_len.gpu[input_batch.idx_mapping] + is_chunked_prefilling = input_batch.seq_lens < prefill_len + num_sampled = (~is_chunked_prefilling).int() + return sampler_output, num_sampled def compute_prompt_logprobs( self, @@ -526,11 +553,11 @@ def compute_prompt_logprobs( # No request asks for prompt logprobs. return {} - num_computed_tokens = self.req_states.num_computed_tokens[idx_mapping_np] prompt_lens = self.req_states.prompt_len[idx_mapping_np] # NOTE(woosuk): -1 because the last prompt token's hidden state is not # needed for prompt logprobs. - includes_prompt = num_computed_tokens < prompt_lens - 1 + computed_prefill = self.req_states.num_computed_prefill_tokens[idx_mapping_np] + includes_prompt = computed_prefill < prompt_lens - 1 # NOTE(woosuk): If the request was resumed after preemption, its prompt # logprobs must have been computed before preemption. Skip. resumed_after_prompt = ( @@ -549,8 +576,8 @@ def compute_prompt_logprobs( token_ids[n - 1] = 0 # Handle chunked prompts. - seq_lens = self.input_buffers.seq_lens.np[: input_batch.num_reqs] - is_prompt_chunked = seq_lens < prompt_lens + pos_after_step = computed_prefill + input_batch.num_scheduled_tokens + is_prompt_chunked = pos_after_step < prompt_lens prefill_token_ids = self.req_states.prefill_token_ids query_start_loc = self.input_buffers.query_start_loc.np for i, req_id in enumerate(input_batch.req_ids): @@ -560,7 +587,7 @@ def compute_prompt_logprobs( continue # The prompt is chunked. Get the next prompt token. 
req_idx = input_batch.idx_mapping_np[i] - next_prompt_token = int(prefill_token_ids[req_idx, seq_lens[i]]) + next_prompt_token = int(prefill_token_ids[req_idx, pos_after_step[i]]) idx = int(query_start_loc[i + 1] - 1) # Set the next prompt token. # NOTE(woosuk): This triggers a GPU operation. @@ -616,48 +643,27 @@ def compute_prompt_logprobs( def postprocess( self, - sampler_output: SamplerOutput, - prompt_logprobs_dict: dict[str, LogprobsTensors], input_batch: InputBatch, - ) -> AsyncOutput | ModelRunnerOutput: - # Store the last sampled token ids. - self.req_states.last_sampled_tokens[input_batch.idx_mapping] = ( - sampler_output.sampled_token_ids + sampled_tokens: torch.Tensor, + num_sampled: torch.Tensor, + ) -> None: + # Update the number of computed tokens. + update_num_computed_tokens( + input_batch.idx_mapping, + self.req_states.num_computed_tokens, + input_batch.query_start_loc, ) - # Get the number of sampled tokens. - # 0 if chunked-prefilling, 1 if not. idx_mapping_np = input_batch.idx_mapping_np - is_chunked_prefilling = ( - input_batch.seq_lens_np < self.req_states.num_tokens[idx_mapping_np] - ) - num_sampled_tokens = (~is_chunked_prefilling).astype(np.int32) - # Increment the number of tokens. - self.req_states.num_tokens[idx_mapping_np] += num_sampled_tokens - # Increment the number of computed tokens. - self.req_states.num_computed_tokens[idx_mapping_np] += ( - input_batch.num_scheduled_tokens + computed_prefill = self.req_states.num_computed_prefill_tokens + # TODO(woosuk): Simplify this. + computed_prefill[idx_mapping_np] = np.minimum( + computed_prefill[idx_mapping_np] + input_batch.num_scheduled_tokens, + self.req_states.prefill_len.np[idx_mapping_np], ) - model_runner_output = ModelRunnerOutput( - req_ids=input_batch.req_ids, - req_id_to_index={req_id: i for i, req_id in enumerate(input_batch.req_ids)}, - sampled_token_ids=None, # type: ignore - logprobs=None, - prompt_logprobs_dict=prompt_logprobs_dict, # type: ignore - pooler_output=[], - kv_connector_output=None, - num_nans_in_logits=None, - ) - async_output = AsyncOutput( - model_runner_output=model_runner_output, - sampler_output=sampler_output, - num_sampled_tokens=num_sampled_tokens, - copy_stream=self.output_copy_stream, - copy_event=self.output_copy_event, - ) - if self.use_async_scheduling: - return async_output - return async_output.get_output() + # Store the last sampled token ids. + last_sampled = sampled_tokens + self.req_states.last_sampled_tokens[input_batch.idx_mapping] = last_sampled def get_cudagraph_and_dp_padding( self, @@ -781,6 +787,7 @@ def execute_model( ) else: # Run PyTorch model in eager mode. + # TODO(woosuk): Support piecewise CUDA graph. with set_forward_context( input_batch.attn_metadata, self.vllm_config, @@ -806,13 +813,41 @@ def sample_tokens( self.execute_model_state = None # type: ignore assert sampling_metadata is not None - sampler_output = self.sample( + sampler_output, num_sampled_tokens = self.sample( hidden_states, input_batch, sampling_metadata, grammar_output ) prompt_logprobs_dict = self.compute_prompt_logprobs(hidden_states, input_batch) - output = self.postprocess( - sampler_output, - prompt_logprobs_dict, - input_batch, + + # Prepare the model runner output. + model_runner_output = ModelRunnerOutput( + req_ids=input_batch.req_ids, + # NOTE(woosuk): req_id_to_index is unused in this model runner. + # Only for compatibility with the existing model runner and scheduler. 
+ req_id_to_index={req_id: i for i, req_id in enumerate(input_batch.req_ids)}, + sampled_token_ids=None, # type: ignore + logprobs=None, + prompt_logprobs_dict=prompt_logprobs_dict, # type: ignore + pooler_output=[], + kv_connector_output=None, + num_nans_in_logits=None, + ) + async_output = AsyncOutput( + model_runner_output=model_runner_output, + sampler_output=sampler_output, + num_sampled_tokens=num_sampled_tokens, + copy_stream=self.output_copy_stream, + copy_event=self.output_copy_event, + ) + + # Postprocess results and update request states. + # NOTE: This is intentionally done after creating the AsyncOutput, + # ensuring that `copy_event` is recorded before calling postprocess. + # This sequencing may slightly reduce latency as async D2H copy does not + # need to wait for the postprocess to finish. + self.postprocess( + input_batch, sampler_output.sampled_token_ids, num_sampled_tokens ) - return output + + if self.use_async_scheduling: + return async_output + return async_output.get_output() diff --git a/vllm/v1/worker/gpu/states.py b/vllm/v1/worker/gpu/states.py index 5d05c3f57790..e8a3207a3a53 100644 --- a/vllm/v1/worker/gpu/states.py +++ b/vllm/v1/worker/gpu/states.py @@ -85,8 +85,12 @@ def __init__( dtype=np.int32, ) self.prefill_len = self._make_buffer(self.max_num_reqs, dtype=torch.int32) - self.num_tokens = np.zeros(self.max_num_reqs, dtype=np.int32) - self.num_computed_tokens = np.zeros(self.max_num_reqs, dtype=np.int32) + + # Number of computed tokens. + self.num_computed_prefill_tokens = np.zeros(self.max_num_reqs, dtype=np.int32) + self.num_computed_tokens = torch.zeros( + self.max_num_reqs, dtype=torch.int32, device=device + ) # Last sampled tokens. self.last_sampled_tokens = torch.zeros( @@ -145,7 +149,10 @@ def add_request( ) self.prefill_len.np[req_idx] = prefill_len self.prefill_token_ids[req_idx, :prefill_len] = prefill_token_ids - self.num_tokens[req_idx] = prefill_len + + self.num_computed_prefill_tokens[req_idx] = num_computed_tokens + # FIXME(woosuk): This triggers a GPU operation whenever adding a new request. + # Optimize this. self.num_computed_tokens[req_idx] = num_computed_tokens if lora_request is not None:
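
The AsyncOutput change above copies num_sampled_tokens device-to-host with non_blocking=True on a dedicated copy stream and only synchronizes the copy event later, inside get_output(); sample_tokens() likewise builds the AsyncOutput (recording copy_event) before running postprocess() so the D2H copy never waits on the state updates. A minimal, self-contained sketch of that stream/event pattern is below. It is illustrative only: the helpers async_copy_to_cpu and read_on_cpu are invented for this sketch and are not vLLM APIs, and a truly asynchronous copy additionally assumes the host side uses pinned memory.

import numpy as np
import torch


def async_copy_to_cpu(
    gpu_tensor: torch.Tensor,
    copy_stream: torch.cuda.Stream,
    copy_event: torch.cuda.Event,
) -> torch.Tensor:
    # Make the copy stream wait for work already queued on the current
    # stream, so the D2H copy reads fully written data.
    copy_stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(copy_stream):
        # Non-blocking D2H copy; it is only truly asynchronous when the
        # host buffer is pinned, otherwise CUDA may fall back to a
        # synchronous copy.
        cpu_tensor = gpu_tensor.to("cpu", non_blocking=True)
        # Record the event on the copy stream so the consumer can wait
        # for exactly this copy, not for the whole device.
        copy_event.record()
    return cpu_tensor


def read_on_cpu(cpu_tensor: torch.Tensor, copy_event: torch.cuda.Event) -> np.ndarray:
    # Deferred until the values are actually needed, mirroring the
    # copy_event.synchronize() call in AsyncOutput.get_output() above.
    copy_event.synchronize()
    return cpu_tensor.numpy()


if __name__ == "__main__" and torch.cuda.is_available():
    stream = torch.cuda.Stream()
    event = torch.cuda.Event()
    num_sampled = torch.ones(8, dtype=torch.int32, device="cuda")
    pending = async_copy_to_cpu(num_sampled, stream, event)
    # Other GPU work (e.g. the postprocess step) can be launched here
    # without blocking on the copy.
    print(read_on_cpu(pending, event))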