From 3eb37c1225c4d312a88b8f9fe49c346b8112af9e Mon Sep 17 00:00:00 2001 From: 01267596 Date: Wed, 25 Mar 2026 08:55:12 +0000 Subject: [PATCH 1/4] [Async][spec decode] Zero-bubble async scheduling +spec decoding Signed-off-by: 01267596 --- vllm_ascend/attention/attention_v1.py | 2 +- vllm_ascend/attention/utils.py | 2 +- vllm_ascend/spec_decode/eagle_proposer.py | 22 ++- vllm_ascend/worker/block_table.py | 109 ++++-------- vllm_ascend/worker/model_runner_v1.py | 200 +++++++++++++++++----- vllm_ascend/worker/npu_input_batch.py | 2 +- 6 files changed, 202 insertions(+), 135 deletions(-) diff --git a/vllm_ascend/attention/attention_v1.py b/vllm_ascend/attention/attention_v1.py index 015dd90b7c0..67bb0744c72 100644 --- a/vllm_ascend/attention/attention_v1.py +++ b/vllm_ascend/attention/attention_v1.py @@ -276,7 +276,7 @@ def build( ) block_table = common_attn_metadata.block_table_tensor - seq_lens = common_attn_metadata.seq_lens_cpu[:num_reqs] + seq_lens = common_attn_metadata.seq_lens[:num_reqs] slot_mapping = common_attn_metadata.slot_mapping[:num_actual_tokens] # this slot_mapping override doesn't work since vllm will override it again. We should fix it vllm. diff --git a/vllm_ascend/attention/utils.py b/vllm_ascend/attention/utils.py index 946d5c66d4e..c5513745b0c 100644 --- a/vllm_ascend/attention/utils.py +++ b/vllm_ascend/attention/utils.py @@ -169,7 +169,7 @@ def unpadded(self, num_actual_tokens: int, num_actual_reqs: int) -> "AscendCommo query_start_loc=self.query_start_loc[: num_actual_reqs + 1], query_start_loc_cpu=self.query_start_loc_cpu[: num_actual_reqs + 1], seq_lens=self.seq_lens[:num_actual_reqs], - seq_lens_cpu=self.seq_lens_cpu[:num_actual_reqs], + seq_lens_cpu=self.seq_lens_cpu[:num_actual_reqs] if self.seq_lens_cpu is not None else None, num_computed_tokens_cpu=self.num_computed_tokens_cpu[:num_actual_reqs], num_reqs=num_actual_reqs, num_actual_tokens=num_actual_tokens, diff --git a/vllm_ascend/spec_decode/eagle_proposer.py b/vllm_ascend/spec_decode/eagle_proposer.py index 72ee9d9ca4a..cf4099eabd8 100644 --- a/vllm_ascend/spec_decode/eagle_proposer.py +++ b/vllm_ascend/spec_decode/eagle_proposer.py @@ -90,7 +90,7 @@ class SpecDecodeBaseProposer(EagleProposer): def __init__(self, vllm_config: VllmConfig, device: torch.device, pass_hidden_states_to_model: bool, runner=None): super().__init__(vllm_config, device, runner) - + self.runner = runner self.use_async_scheduling = self.vllm_config.scheduler_config.async_scheduling self.pass_hidden_states_to_model = pass_hidden_states_to_model self.decode_threshold = 1 + self.num_speculative_tokens @@ -367,7 +367,7 @@ def dummy_run( common_attn_metadata = AscendCommonAttentionMetadata( query_start_loc=self.query_start_loc.gpu[: num_reqs + 1], query_start_loc_cpu=self.query_start_loc.cpu[: num_reqs + 1], - seq_lens_cpu=self.runner.seq_lens.cpu, + seq_lens_cpu=self.runner.optimistic_seq_lens_cpu, seq_lens=self.runner.seq_lens.gpu[:num_reqs], num_reqs=num_reqs, num_actual_tokens=num_tokens, @@ -531,7 +531,7 @@ def _propose( common_attn_metadata.block_table_tensor, num_reqs_padded ) common_attn_metadata.seq_lens = self.runner.seq_lens.gpu[:num_reqs_padded] - common_attn_metadata.seq_lens_cpu = self.runner.seq_lens.cpu[:num_reqs_padded] + common_attn_metadata.seq_lens_cpu = self.runner.optimistic_seq_lens_cpu[:num_reqs_padded] if self.supports_mm_inputs: mm_embeds, is_mm_embed = mm_embed_inputs or (None, None) @@ -1177,10 +1177,10 @@ def attn_update_stack_num_spec_norm( # For the requests that exceed the max model length, we set the # sequence length to 1 to minimize their overheads in attention. common_attn_metadata.seq_lens[:batch_size].masked_fill_(exceeds_max_model_len, 1) - - common_attn_metadata.seq_lens_cpu[:batch_size] = common_attn_metadata.seq_lens_cpu[:batch_size] + 1 - exceeds_mask = common_attn_metadata.seq_lens_cpu[:batch_size] >= self.max_model_len - common_attn_metadata.seq_lens_cpu[:batch_size].masked_fill_(exceeds_mask, 1) + if common_attn_metadata.seq_lens_cpu is not None: + common_attn_metadata.seq_lens_cpu[:batch_size] = common_attn_metadata.seq_lens_cpu[:batch_size] + 1 + exceeds_mask = common_attn_metadata.seq_lens_cpu[:batch_size] >= self.max_model_len + common_attn_metadata.seq_lens_cpu[:batch_size].masked_fill_(exceeds_mask, 1) common_attn_metadata.num_computed_tokens_cpu[:batch_size] += 1 if self.uses_mrope: common_attn_metadata.positions[:batch_size].copy_(clamped_positions[0]) @@ -1244,7 +1244,7 @@ def attn_update_stack_num_spec_norm( def prepare_next_token_ids_padded( self, - common_attn_metadata: CommonAttentionMetadata, + seq_lens_cpu: torch.Tensor, sampled_token_ids: torch.Tensor, requests: dict[str, CachedRequestState], gpu_input_batch: InputBatch, @@ -1264,11 +1264,9 @@ def prepare_next_token_ids_padded( # Precompute get_token_id for when there is no valid next token num_reqs = gpu_input_batch.num_reqs + seq_lens_list = seq_lens_cpu[:num_reqs].tolist() self.backup_next_token_ids.np[:num_reqs] = np.array( - [ - requests[gpu_input_batch.req_ids[i]].get_token_id(common_attn_metadata.seq_lens_cpu[i].item()) - for i in range(num_reqs) - ] + [requests[gpu_input_batch.req_ids[i]].get_token_id(seq_lens_list[i]) for i in range(num_reqs)] ) self.backup_next_token_ids.copy_to_gpu(num_reqs) diff --git a/vllm_ascend/worker/block_table.py b/vllm_ascend/worker/block_table.py index 3c812aa4432..f4bf6d4e344 100644 --- a/vllm_ascend/worker/block_table.py +++ b/vllm_ascend/worker/block_table.py @@ -2,7 +2,9 @@ import torch from vllm.distributed import get_dcp_group, get_pcp_group from vllm.utils.math_utils import cdiv +from vllm.v1.attention.backends.utils import PAD_SLOT_ID from vllm.v1.utils import CpuGpuBuffer +from vllm.v1.worker.block_table import _compute_slot_mapping_kernel from vllm.v1.worker.cp_utils import get_total_cp_world_size @@ -117,80 +119,34 @@ def swap_row(self, src: int, tgt: int) -> None: self.block_table.np[[src, tgt]] = self.block_table.np[[tgt, src]] - def compute_slot_mapping(self, req_indices: np.ndarray, positions: np.ndarray) -> None: - # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] - # -> [0, 0, K, K, K + 1, K + 1, K + 2, 2 * K, 2 * K, 2 * K + 1] - # where K is the max_num_blocks_per_req and the block size is 2. - # NOTE(woosuk): We can't simply use `token_indices // block_size` - # here because M (max_model_len) is not necessarily divisible by - # block_size. - - if self.dcp_world_size * self.pcp_world_size > 1: - # Note(hc): The DCP implement store kvcache with an interleave - # style, the kvcache for the token whose token_idx is i is - # always stored on the GPU whose dcp_rank equals i % pcp_world_size: - - # Use a "virtual block" which equals to world_size * block_size - # for block_table_indices calculation. - virtual_block_size = self.block_size * self.dcp_world_size * self.pcp_world_size - - # IMPORTANT: In hybrid mode, positions are in logical block space, - # but we need to map them to the correct logical block table indices - logical_block_idx = positions // virtual_block_size - - # Account for the expanded logical table - # (always needed with unified tensor) - # Each physical block is split into multiple logical blocks - # The logical table has been expanded to accommodate this - block_table_indices = ( - req_indices * self.max_num_blocks_per_req * self.blocks_per_phys_block + logical_block_idx - ) - - block_numbers = self.block_table.np.ravel()[block_table_indices] - # Use virtual_block_size for mask calculation, which marks local - # tokens. - virtual_block_offsets = positions % virtual_block_size - self.current_rank = self.dcp_world_size * self.pcp_rank + self.dcp_rank - mask = ( - virtual_block_offsets // self.cp_kv_cache_interleave_size % (self.dcp_world_size * self.pcp_world_size) - == self.current_rank - ) - # Calculate local block_offsets - block_offsets = ( - virtual_block_offsets - // (self.dcp_world_size * self.pcp_world_size * self.cp_kv_cache_interleave_size) - * self.cp_kv_cache_interleave_size - + virtual_block_offsets % self.cp_kv_cache_interleave_size - ) - # Calculate slot_mapping - slot_mapping = block_numbers * self.block_size + block_offsets - # Write final slots, use -1 for not-local - self.slot_mapping.np[: req_indices.shape[0]] = np.where(mask, slot_mapping, -1) - else: - assert self.kernel_sizes is not None - if self.block_size == self.kernel_sizes[0]: - # IMPORTANT: In hybrid mode, positions are in logical block space, - # but we need to map them to the correct logical block table indices - logical_block_idx = positions // self.block_size - - # Account for the expanded logical table - # (always needed with unified tensor) - # Each physical block is split into multiple logical blocks - # The logical table has been expanded to accommodate this - block_table_indices = ( - req_indices * self.max_num_blocks_per_req * self.blocks_per_phys_block + logical_block_idx - ) - - block_numbers = self.block_table.np.ravel()[block_table_indices] - block_offsets = positions % self.block_size - np.add(block_numbers * self.block_size, block_offsets, out=self.slot_mapping.np[: req_indices.shape[0]]) + def compute_slot_mapping( + self, + num_reqs: int, + query_start_loc: torch.Tensor, + positions: torch.Tensor, + ) -> None: + num_tokens = positions.shape[0] + total_cp_world_size = self.pcp_world_size * self.dcp_world_size + total_cp_rank = self.pcp_rank * self.dcp_world_size + self.dcp_rank + _compute_slot_mapping_kernel[(num_reqs + 1,)]( + num_tokens, + self.max_num_batched_tokens, + query_start_loc, + positions, + self.block_table.gpu, + self.block_table.gpu.stride(0), + self.block_size, + self.slot_mapping.gpu, + TOTAL_CP_WORLD_SIZE=total_cp_world_size, + TOTAL_CP_RANK=total_cp_rank, + CP_KV_CACHE_INTERLEAVE_SIZE=self.cp_kv_cache_interleave_size, + PAD_ID=PAD_SLOT_ID, + BLOCK_SIZE=1024, + ) def commit_block_table(self, num_reqs: int) -> None: self.block_table.copy_to_gpu(num_reqs) - def commit_slot_mapping(self, num_tokens: int) -> None: - self.slot_mapping.copy_to_gpu(num_tokens) - def clear(self) -> None: self.block_table.fill_(0) self.block_table.cpu.fill_(0) @@ -299,18 +255,19 @@ def swap_row(self, src: int, tgt: int) -> None: for block_table in self.block_tables: block_table.swap_row(src, tgt) - def compute_slot_mapping(self, req_indices: np.ndarray, positions: np.ndarray) -> None: + def compute_slot_mapping( + self, + num_reqs: int, + query_start_loc: torch.Tensor, + positions: torch.Tensor, + ) -> None: for block_table in self.block_tables: - block_table.compute_slot_mapping(req_indices, positions) + block_table.compute_slot_mapping(num_reqs, query_start_loc, positions) def commit_block_table(self, num_reqs: int) -> None: for block_table in self.block_tables: block_table.commit_block_table(num_reqs) - def commit_slot_mapping(self, num_tokens: int) -> None: - for block_table in self.block_tables: - block_table.commit_slot_mapping(num_tokens) - def clear(self) -> None: for block_table in self.block_tables: block_table.clear() diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index 677e9925a82..6f72ca58e81 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -75,6 +75,7 @@ from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.sample.rejection_sampler import RejectionSampler from vllm.v1.spec_decode.metadata import SpecDecodeMetadata +from vllm.v1.spec_decode.utils import update_num_computed_tokens_for_batch_change from vllm.v1.structured_output.utils import apply_grammar_bitmask from vllm.v1.utils import record_function_or_nullcontext from vllm.v1.worker import mamba_utils @@ -635,13 +636,13 @@ def _prepare_inputs( self.with_prefill = with_prefill # Get positions. - positions_np = self.positions.np[:total_num_scheduled_tokens] - cu_num_tokens, arange = self._get_cumsum_and_arange(num_scheduled_tokens) - np.add(self.input_batch.num_computed_tokens_cpu[req_indices], arange, out=positions_np) - - self.input_batch.block_table.compute_slot_mapping(req_indices, positions_np) - self.input_batch.block_table.commit_slot_mapping(total_num_scheduled_tokens) - + cu_num_tokens = self._get_cumsum_and_arange( + num_scheduled_tokens, self.query_pos.np + ) + positions_np = ( + self.input_batch.num_computed_tokens_cpu[req_indices] + + self.query_pos.np[: cu_num_tokens[-1]] + ) if self.use_cp: self.pcp_manager.init_batch_info( num_scheduled_tokens, @@ -758,15 +759,28 @@ def _prepare_inputs( self.gdn_query_start_loc.np[num_reqs + 1 :].fill(cu_num_tokens[-1]) self.gdn_query_start_loc.copy_to_gpu() - self.seq_lens.np[:num_reqs] = self.input_batch.num_computed_tokens_cpu[:num_reqs] + num_scheduled_tokens - self.seq_lens.cpu[num_reqs:].fill_(0) - self.seq_lens.copy_to_gpu() + + # Compute optimistic seq_lens (assumes all draft tokens from previous + # iteration accepted). Store in optimistic_seq_lens_cpu for use by + # _build_attention_metadata (max_seq_len) and discard_request_mask. + # seq_lens (GPU) will be computed later using the same optimistic values. + torch.add( + self.input_batch.num_computed_tokens_cpu_tensor[:num_reqs], + torch.from_numpy(num_scheduled_tokens), + out=self.optimistic_seq_lens_cpu[:num_reqs], + ) + self.optimistic_seq_lens_cpu[num_reqs:].fill_(0) + + # Build prev_positions mapping: current pos -> prev pos (-1 if new). + # Used for gathering from previous iteration's GPU tensors. + prev_req_id_to_index = self.input_batch.prev_req_id_to_index + self._compute_prev_positions(num_reqs) # Fill unused with -1. Needed for reshape_and_cache in attention_cp self.query_start_loc.gpu[num_reqs + 1 :].fill_(-1) # Copy the tensors to the NPU. - self._prepare_input_ids(scheduler_output, total_num_scheduled_tokens, cu_num_tokens) + self._prepare_input_ids(scheduler_output, num_reqs, total_num_scheduled_tokens, cu_num_tokens) # Calculate M-RoPE positions. # Only relevant for models using M-RoPE (e.g, Qwen2-VL) if self.uses_mrope: @@ -783,9 +797,6 @@ def _prepare_inputs( self.xdrope_positions.cpu[:, :total_num_scheduled_tokens], non_blocking=True, ) - else: - # Common case (1D positions) - self.positions.copy_to_gpu(total_num_scheduled_tokens) # Record the index of requests that should not be sampled, # so that we could clear the sampled tokens before returning @@ -803,12 +814,86 @@ def _prepare_inputs( ) discard_requests_mask = original_seq_lens_np < num_tokens_np else: - discard_requests_mask = self.seq_lens.np[:num_reqs] < num_tokens_np + discard_requests_mask = self.optimistic_seq_lens_cpu[:num_reqs].numpy() < num_tokens_np discard_request_indices = np.nonzero(discard_requests_mask)[0] self.num_discarded_requests = len(discard_request_indices) self.discard_request_indices.np[: self.num_discarded_requests] = discard_request_indices self.discard_request_indices.copy_to_gpu(self.num_discarded_requests) + + # Sync num_accepted_tokens from CPU (set by + # _update_states_after_model_execute for hybrid models). + if self.num_accepted_tokens_event is not None: + self.num_accepted_tokens_event.synchronize() + self.num_accepted_tokens.np[:num_reqs] = ( + self.input_batch.num_accepted_tokens_cpu[:num_reqs] + ) + self.num_accepted_tokens.np[num_reqs:].fill(1) + self.num_accepted_tokens.copy_to_gpu() + else: + self.num_accepted_tokens.np.fill(1) + self.num_accepted_tokens.gpu.fill_(1) + + # Update num_computed_tokens on GPU. In async spec decode, + # CPU values are optimistic (all drafts accepted). The kernel + # corrects on GPU using the previous step's + # valid_sampled_token_count_gpu. Otherwise, just copy from CPU. + if ( + self.use_async_spec_decode + and self.valid_sampled_token_count_gpu is not None + and prev_req_id_to_index + ): + self.prev_positions.copy_to_gpu(num_reqs) + self.prev_num_draft_tokens.copy_to_gpu() + cpu_values = self.input_batch.num_computed_tokens_cpu_tensor[:num_reqs].to( + device=self.device, non_blocking=True + ) + update_num_computed_tokens_for_batch_change( + self.num_computed_tokens, + self.num_accepted_tokens.gpu[:num_reqs], + self.prev_positions.gpu[:num_reqs], + self.valid_sampled_token_count_gpu, + self.prev_num_draft_tokens.gpu, + cpu_values, + ) + else: + self.num_computed_tokens[:num_reqs].copy_( + self.input_batch.num_computed_tokens_cpu_tensor[:num_reqs], + non_blocking=True, + ) + + self.req_indices.np[:total_num_scheduled_tokens] = req_indices + self.req_indices.copy_to_gpu(total_num_scheduled_tokens) + req_indices_gpu = self.req_indices.gpu[:total_num_scheduled_tokens] + + self.query_pos.copy_to_gpu(total_num_scheduled_tokens) + self.num_scheduled_tokens.np[:num_reqs] = num_scheduled_tokens + self.num_scheduled_tokens.copy_to_gpu(num_reqs) + num_scheduled_tokens_gpu = self.num_scheduled_tokens.gpu[:num_reqs] + self.positions[:total_num_scheduled_tokens] = ( + self.num_computed_tokens[req_indices_gpu].to(torch.int64) + + self.query_pos.gpu[:total_num_scheduled_tokens] + ) + self.seq_lens[:num_reqs] = ( + self.num_computed_tokens[:num_reqs] + num_scheduled_tokens_gpu + ) + self.seq_lens[num_reqs:].fill_(0) + + self.input_batch.block_table.compute_slot_mapping( + num_reqs, + self.query_start_loc.gpu[: num_reqs + 1], + self.positions[:total_num_scheduled_tokens], + ) + + if self.use_async_spec_decode and (self.uses_mrope or self.uses_xdrope_dim > 0): + drift = self.num_computed_tokens[req_indices_gpu].to( + torch.int64 + ) - self.input_batch.num_computed_tokens_cpu_tensor[req_indices].to( + device=self.device, dtype=torch.int64, non_blocking=True + ) + target = self.mrope_positions if self.uses_mrope else self.xdrope_positions + target.gpu[:, :total_num_scheduled_tokens] += drift + use_spec_decode = len(scheduler_output.scheduled_spec_decode_tokens) > 0 if not use_spec_decode: # NOTE(woosuk): Due to chunked prefills, the batch may contain @@ -838,11 +923,12 @@ def _prepare_inputs( draft_token_ids, ) in scheduler_output.scheduled_spec_decode_tokens.items(): req_idx = self.input_batch.req_id_to_index[req_id] - num_draft_tokens[req_idx] = len(draft_token_ids) + draft_len = len(draft_token_ids) + num_draft_tokens[req_idx] = draft_len if (self.is_kv_consumer and req_id in new_schedule_reqs) or \ (self.input_batch.num_computed_tokens_cpu[req_idx] >= \ self.input_batch.num_prompt_tokens[req_idx]): - num_decode_draft_tokens[req_idx] = len(draft_token_ids) + num_decode_draft_tokens[req_idx] = draft_len else: num_decode_draft_tokens[req_idx] = -1 @@ -925,24 +1011,23 @@ def _calc_spec_decode_metadata( # Compute the logits indices. # [4, 1, 3, 1, 2] num_sampled_tokens = num_draft_tokens + 1 - # Step 1. [4, 5, 8, 9, 11] - cu_num_sampled_tokens = np.cumsum(num_sampled_tokens, dtype=np.int32) - total_num_sampled_tokens = cu_num_sampled_tokens[-1] - # Step 2. [0, 0, 0, 0, 4, 5, 5, 5, 8, 9, 9] - cumsums_offsets = np.repeat(cu_num_sampled_tokens - num_sampled_tokens, num_sampled_tokens) - # Step 3. [0, 1, 2, 3, 0, 0, 1, 2, 0, 0, 1] - arange = self.arange_np[:total_num_sampled_tokens] - cumsums_offsets - # Step 4. [0, 0, 0, 0, 103, 104, 104, 104, 206, 207, 207] + # Step 1. + # cu_num_sampled_tokens: [4, 5, 8, 9, 11] + # _arange_scratch[:11]: [0, 1, 2, 3, 0, 0, 1, 2, 0, 0, 1] + cu_num_sampled_tokens = self._get_cumsum_and_arange( + num_sampled_tokens, self._arange_scratch, cumsum_dtype=np.int32 + ) + # Step 2. [0, 0, 0, 0, 103, 104, 104, 104, 206, 207, 207] logits_indices = np.repeat(cu_num_scheduled_tokens - num_sampled_tokens, num_sampled_tokens) - # Step 5. [0, 1, 2, 3, 103, 104, 105, 106, 206, 207, 208] - logits_indices += arange + # Step 3. [0, 1, 2, 3, 103, 104, 105, 106, 206, 207, 208] + logits_indices += self._arange_scratch[: cu_num_sampled_tokens[-1]] # while pcp > 1, decode results may contain padding (from pcp all-gather), # update logits_indices after getting draft_token_ids from ori logits_indices if self.pcp_size > 1: cu_num_scheduled_tokens = cu_num_scheduled_tokens * self.pcp_size - num_pcp_pads logits_indices_pcp = np.repeat(cu_num_scheduled_tokens - num_sampled_tokens, num_sampled_tokens) - logits_indices_pcp += arange + logits_indices_pcp += self._arange_scratch[: cu_num_sampled_tokens[-1]] logits_indices_pcp = torch.from_numpy(logits_indices_pcp).pin_memory().to(self.device, non_blocking=True) # Compute the bonus logits indices. @@ -1033,7 +1118,7 @@ def propose_draft_token_ids( ) assert self.drafter is not None next_token_ids, valid_sampled_tokens_count = self.drafter.prepare_next_token_ids_padded( - common_attn_metadata, + self.optimistic_seq_lens_cpu, sampled_token_ids, self.requests, self.input_batch, @@ -1162,7 +1247,7 @@ def execute_model( with record_function_or_nullcontext("prepare input"): with self.synchronize_input_prep(): # Update persistent batch states. - self._update_states(scheduler_output) + deferred_state_corrections_fn = self._update_states(scheduler_output) if has_ec_transfer() and get_ec_transfer().is_producer: with self.maybe_get_ec_connector_output( @@ -1265,6 +1350,12 @@ def execute_model( # '_update_states_after_model_execute', which is not overridden in vLLM-Ascend. # We simply utilize the implementation in vLLM. if self.cache_config.mamba_cache_mode == "align": + # preprocess_mamba reads req_state.num_computed_tokens (CPU) + # to decide copy operations, so we must apply deferred + # corrections before it runs. + if deferred_state_corrections_fn: + deferred_state_corrections_fn() + deferred_state_corrections_fn = None mamba_utils.preprocess_mamba( scheduler_output, self.kv_cache_config, @@ -1276,6 +1367,14 @@ def execute_model( self.model.get_mamba_state_copy_func(), self._get_mamba_copy_bufs(), ) + # preprocess_mamba resets num_accepted_tokens_cpu to 1 + # for requests whose state was copied to a new block. + # Re-sync to GPU so the mamba kernel reads from the + # correct initial state slot (init_token_idx = 0). + self.num_accepted_tokens.np[:num_reqs] = ( + self.input_batch.num_accepted_tokens_cpu[:num_reqs] + ) + self.num_accepted_tokens.copy_to_gpu(num_reqs) use_spec_decode = len(scheduler_output.scheduled_spec_decode_tokens) > 0 ubatch_slices_attn = ubatch_slices_padded if pad_attn else ubatch_slices @@ -1463,6 +1562,11 @@ def execute_model( batch_desc, ) self.kv_connector_output = kv_connector_output + + # Now the batch has been launched we can wait for corrections from the + # previous model forward without breaking async scheduling. + if deferred_state_corrections_fn: + deferred_state_corrections_fn() return None @torch.inference_mode() @@ -1527,6 +1631,8 @@ def sample_tokens( assert self.sampling_done_event is not None self.sampling_done_event.record() + self.valid_sampled_token_count_gpu = None + def propose_draft_token_ids(sampled_token_ids): assert spec_decode_common_attn_metadata is not None self._draft_token_ids = self.propose_draft_token_ids( @@ -2040,11 +2146,8 @@ def _build_attention_metadata( # window size when capturing to make sure the correct kernel is selected. max_seq_len = self.max_model_len else: - max_seq_len = self.seq_lens.np[:num_reqs].max().item() - if use_spec_decode and self.need_accepted_tokens: - self.num_accepted_tokens.np[:num_reqs] = self.input_batch.num_accepted_tokens_cpu[:num_reqs] - self.num_accepted_tokens.np[num_reqs:].fill(1) - self.num_accepted_tokens.copy_to_gpu() + max_seq_len = self.optimistic_seq_lens_cpu.numpy()[:num_reqs].max().item() + kv_cache_groups = self.kv_cache_config.kv_cache_groups @@ -2109,14 +2212,21 @@ def _get_block_table_and_slot_mapping(kv_cache_gid: int): block_table_gid_0, slot_mapping_gid_0 = _get_block_table_and_slot_mapping(0) self.long_seq_metadata, block_table_gid_0 = _get_pcp_metadata(block_table_gid_0) + seq_lens_cpu = self.optimistic_seq_lens_cpu[:num_reqs_padded] + if self.use_async_spec_decode: + # GPU tensors are authoritative in async mode. + seq_lens_cpu = None + num_computed_tokens_cpu = None + cm_base = AscendCommonAttentionMetadata( query_start_loc=self.query_start_loc.gpu[: num_reqs_padded + 1], query_start_loc_cpu=self.query_start_loc.cpu[: num_reqs_padded + 1], - seq_lens=self.seq_lens.gpu[:num_reqs_padded], + seq_lens=self.seq_lens[:num_reqs_padded], # TODO - seq_lens_cpu=self.seq_lens.cpu[:num_reqs_padded], + seq_lens_cpu=seq_lens_cpu, # TODO - num_computed_tokens_cpu=self.input_batch.num_computed_tokens_cpu_tensor[:num_reqs_padded], + # num_computed_tokens_cpu=self.input_batch.num_computed_tokens_cpu_tensor[:num_reqs_padded], + num_computed_tokens_cpu=num_computed_tokens_cpu, num_reqs=num_reqs_padded, num_actual_tokens=num_tokens, max_query_len=max_query_len, @@ -2126,7 +2236,7 @@ def _get_block_table_and_slot_mapping(kv_cache_gid: int): causal=True, num_input_tokens=num_tokens_padded, actual_seq_lengths_q=self.actual_seq_lengths_q, - positions=self.positions.gpu, + positions=self.positions, attn_state=self.attn_state, decode_token_per_req=self.decode_token_per_req, prefill_context_parallel_metadata=self.long_seq_metadata, @@ -2390,11 +2500,13 @@ def _dummy_run( if is_graph_capturing and using_paged_attention(num_tokens, self.vllm_config) else max_query_len ) # type: ignore[assignment] - self.seq_lens.np[:num_reqs_padded] = seq_lens - self.seq_lens.np[num_reqs_padded:] = 0 - self.seq_lens.copy_to_gpu() - cum_num_tokens, _ = self._get_cumsum_and_arange(num_scheduled_tokens) + self.optimistic_seq_lens_cpu[:num_reqs] = seq_lens + self.optimistic_seq_lens_cpu[num_reqs:].fill_(0) + self.seq_lens.copy_(self.optimistic_seq_lens_cpu, non_blocking=True) + + cum_num_tokens = self._get_cumsum_and_arange( + num_scheduled_tokens, self.query_pos.np) self.query_start_loc.np[1 : num_reqs_padded + 1] = cum_num_tokens self.query_start_loc.copy_to_gpu() num_reqs_padded = self._pad_query_start_loc_for_fia( @@ -2436,7 +2548,7 @@ def _dummy_run( elif self.uses_xdrope_dim > 0: positions = self.xdrope_positions.gpu[:, :num_tokens_padded] else: - positions = self.positions.gpu[:num_tokens_padded] + positions = self.positions[:num_tokens_padded] # update global cos, sin update_cos_sin(positions) diff --git a/vllm_ascend/worker/npu_input_batch.py b/vllm_ascend/worker/npu_input_batch.py index b5f21a6a986..5e59ba26c8e 100644 --- a/vllm_ascend/worker/npu_input_batch.py +++ b/vllm_ascend/worker/npu_input_batch.py @@ -152,7 +152,7 @@ def __init__( # Speculative decoding self.num_accepted_tokens_cpu_tensor = torch.ones( - (max_num_reqs,), dtype=torch.int64, device="cpu", pin_memory=pin_memory + (max_num_reqs,), dtype=torch.int32, device="cpu", pin_memory=pin_memory ) self.num_accepted_tokens_cpu = self.num_accepted_tokens_cpu_tensor.numpy() From 0a5a5f40547e0ca96fdb71f8372c1a6355461f39 Mon Sep 17 00:00:00 2001 From: 01267596 Date: Wed, 25 Mar 2026 09:22:32 +0000 Subject: [PATCH 2/4] [Async][spec decode] Zero-bubble async scheduling +spec decoding Signed-off-by: 01267596 --- vllm_ascend/attention/utils.py | 2 +- vllm_ascend/spec_decode/eagle_proposer.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/vllm_ascend/attention/utils.py b/vllm_ascend/attention/utils.py index c5513745b0c..1c4245ad654 100644 --- a/vllm_ascend/attention/utils.py +++ b/vllm_ascend/attention/utils.py @@ -170,7 +170,7 @@ def unpadded(self, num_actual_tokens: int, num_actual_reqs: int) -> "AscendCommo query_start_loc_cpu=self.query_start_loc_cpu[: num_actual_reqs + 1], seq_lens=self.seq_lens[:num_actual_reqs], seq_lens_cpu=self.seq_lens_cpu[:num_actual_reqs] if self.seq_lens_cpu is not None else None, - num_computed_tokens_cpu=self.num_computed_tokens_cpu[:num_actual_reqs], + num_computed_tokens_cpu=self.num_computed_tokens_cpu[:num_actual_reqs] if self.num_computed_tokens_cpu is not None else None, num_reqs=num_actual_reqs, num_actual_tokens=num_actual_tokens, max_query_len=self.max_query_len, diff --git a/vllm_ascend/spec_decode/eagle_proposer.py b/vllm_ascend/spec_decode/eagle_proposer.py index cf4099eabd8..1545b9c4e60 100644 --- a/vllm_ascend/spec_decode/eagle_proposer.py +++ b/vllm_ascend/spec_decode/eagle_proposer.py @@ -1181,7 +1181,8 @@ def attn_update_stack_num_spec_norm( common_attn_metadata.seq_lens_cpu[:batch_size] = common_attn_metadata.seq_lens_cpu[:batch_size] + 1 exceeds_mask = common_attn_metadata.seq_lens_cpu[:batch_size] >= self.max_model_len common_attn_metadata.seq_lens_cpu[:batch_size].masked_fill_(exceeds_mask, 1) - common_attn_metadata.num_computed_tokens_cpu[:batch_size] += 1 + if common_attn_metadata.num_computed_tokens_cpu is not None: + common_attn_metadata.num_computed_tokens_cpu[:batch_size] += 1 if self.uses_mrope: common_attn_metadata.positions[:batch_size].copy_(clamped_positions[0]) else: From 34967a116c78ec8d4ee68abc8ff45c98e1dd8130 Mon Sep 17 00:00:00 2001 From: 01267596 Date: Wed, 25 Mar 2026 10:20:48 +0000 Subject: [PATCH 3/4] [Async][spec decode] Zero-bubble async scheduling +spec decoding Signed-off-by: 01267596 --- vllm_ascend/worker/model_runner_v1.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index 6f72ca58e81..75469a1c30d 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -2211,7 +2211,9 @@ def _get_block_table_and_slot_mapping(kv_cache_gid: int): block_table_gid_0, slot_mapping_gid_0 = _get_block_table_and_slot_mapping(0) self.long_seq_metadata, block_table_gid_0 = _get_pcp_metadata(block_table_gid_0) - + num_computed_tokens_cpu = self.input_batch.num_computed_tokens_cpu_tensor[ + :num_reqs_padded + ] seq_lens_cpu = self.optimistic_seq_lens_cpu[:num_reqs_padded] if self.use_async_spec_decode: # GPU tensors are authoritative in async mode. From 210a15336b8c4db4ea01a9ff07f598f84a4af040 Mon Sep 17 00:00:00 2001 From: 01267596 Date: Fri, 27 Mar 2026 02:47:31 +0000 Subject: [PATCH 4/4] optimize Signed-off-by: 01267596 --- vllm_ascend/spec_decode/utils.py | 35 +++++++++++++++++++++++++++ vllm_ascend/worker/model_runner_v1.py | 2 +- 2 files changed, 36 insertions(+), 1 deletion(-) create mode 100644 vllm_ascend/spec_decode/utils.py diff --git a/vllm_ascend/spec_decode/utils.py b/vllm_ascend/spec_decode/utils.py new file mode 100644 index 00000000000..1d8d82fdcce --- /dev/null +++ b/vllm_ascend/spec_decode/utils.py @@ -0,0 +1,35 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import torch + + +def update_num_computed_tokens_for_batch_change( + num_computed_tokens: torch.Tensor, + num_accepted_tokens: torch.Tensor, + prev_positions: torch.Tensor, + valid_sampled_token_count: torch.Tensor, + prev_num_draft_tokens: torch.Tensor, + cpu_num_computed_tokens: torch.Tensor, +) -> None: + """Correct num_computed_tokens for async spec decode drift. + + Requests that had drafts: corrected = prev_gpu + valid_count. + New requests or non-draft (e.g. prefills): use CPU value directly. + """ + # Clamp because prev_positions can be -1 for new requests + gather_indices = prev_positions.clamp(min=0) + + valid_counts = valid_sampled_token_count[gather_indices] + prev_computed = num_computed_tokens[gather_indices] + prev_drafts = prev_num_draft_tokens[gather_indices] + + participating = (prev_positions >= 0) & (prev_drafts > 0) + corrected = prev_computed + valid_counts.int() + + n = prev_positions.shape[0] + num_computed_tokens[:n].copy_( + torch.where(participating, corrected, cpu_num_computed_tokens) + ) + num_accepted_tokens.copy_( + torch.where(participating, valid_counts, num_accepted_tokens) + ) \ No newline at end of file diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index 75469a1c30d..960bd3df5d0 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -75,7 +75,6 @@ from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.sample.rejection_sampler import RejectionSampler from vllm.v1.spec_decode.metadata import SpecDecodeMetadata -from vllm.v1.spec_decode.utils import update_num_computed_tokens_for_batch_change from vllm.v1.structured_output.utils import apply_grammar_bitmask from vllm.v1.utils import record_function_or_nullcontext from vllm.v1.worker import mamba_utils @@ -117,6 +116,7 @@ from vllm_ascend.spec_decode.medusa_proposer import AscendMedusaProposer from vllm_ascend.spec_decode.ngram_proposer import AscendNgramProposer from vllm_ascend.spec_decode.suffix_proposer import AscendSuffixDecodingProposer +from vllm_ascend.spec_decode.utils import update_num_computed_tokens_for_batch_change from vllm_ascend.utils import ( calc_split_factor, check_gdn_layer,