diff --git a/vllm_ascend/spec_decode/__init__.py b/vllm_ascend/spec_decode/__init__.py new file mode 100644 index 0000000000..64076c2bda --- /dev/null +++ b/vllm_ascend/spec_decode/__init__.py @@ -0,0 +1,33 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# Copyright 2023 The vLLM team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# This file is a part of the vllm-ascend project. +# Adapted from vllm-project/vllm/vllm/worker/gpu_model_runner.py +# +from vllm_ascend.spec_decode.eagle_proposer import EagleProposer +from vllm_ascend.spec_decode.mtp_proposer import MtpProposer +from vllm_ascend.spec_decode.ngram_proposer import NgramProposer + + +def get_spec_decode_method(method, vllm_config, device, runner): + if method == "ngram": + return NgramProposer(vllm_config, device, runner) + elif method in ["eagle", "eagle3"]: + return EagleProposer(vllm_config, device, runner) + elif method == 'deepseek_mtp': + return MtpProposer(vllm_config, device, runner) + else: + raise ValueError("Unknown speculative decoding method: " + f"{method}") diff --git a/vllm_ascend/spec_decode/eagle_proposer.py b/vllm_ascend/spec_decode/eagle_proposer.py new file mode 100644 index 0000000000..c7c35bbdf2 --- /dev/null +++ b/vllm_ascend/spec_decode/eagle_proposer.py @@ -0,0 +1,643 @@ +# SPDX-License-Identifier: Apache-2.0 +import os +from typing import Optional + +import numpy as np +import torch +import torch.nn as nn +from vllm.attention.layer import Attention +from vllm.config import (CompilationLevel, VllmConfig, + get_layers_from_vllm_config) +from vllm.distributed.parallel_state import get_pp_group +from vllm.logger import logger +from vllm.model_executor.model_loader import get_model +from vllm.model_executor.models import supports_multimodal +from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM +from vllm.v1.core.sched.output import SchedulerOutput +from vllm.v1.sample.metadata import SamplingMetadata +from vllm.v1.spec_decode.metadata import SpecDecodeMetadata + +from vllm_ascend.ascend_forward_context import set_ascend_forward_context +from vllm_ascend.attention.attention_mask import AttentionMaskBuilder +from vllm_ascend.attention.attention_v1 import AscendAttentionState +from vllm_ascend.attention.utils import AscendCommonAttentionMetadata +from vllm_ascend.spec_decode.interface import Proposer, SpecDcodeType + +PADDING_SLOT_ID = -1 + + +class EagleProposer(Proposer): + + def __init__(self, + vllm_config: VllmConfig, + device: torch.device, + runner=None): + self.name = SpecDcodeType.EAGLE if vllm_config.speculative_config.method == "eagle" else SpecDcodeType.EAGLE3 + self.vllm_config = vllm_config + self.device = device + self.runner = runner + + self.block_size = vllm_config.cache_config.block_size + # We need to get the hidden size from the draft model config because + # the draft model's hidden size can be different from the target model's + # hidden size (e.g., Llama 3.3 70B). 
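+        # For EAGLE3, the concatenated auxiliary hidden states taken from the
+        # target model are projected back down to this draft hidden size by
+        # combine_hidden_states() before being fed to the draft head (see
+        # _propose below).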
+ self.hidden_size = vllm_config.speculative_config.draft_model_config.get_hidden_size( + ) + + self.use_cuda_graph = (self.vllm_config.compilation_config.level + == CompilationLevel.PIECEWISE and + not self.vllm_config.model_config.enforce_eager) + self.cudagraph_batch_sizes = list( + reversed( + self.vllm_config.compilation_config.cudagraph_capture_sizes)) + + # persistent buffers for cuda graph + self.input_ids = torch.zeros( + self.vllm_config.scheduler_config.max_num_batched_tokens, + dtype=torch.int32, + device=device) + self.positions = torch.zeros( + self.vllm_config.scheduler_config.max_num_batched_tokens, + dtype=torch.int64, + device=device) + self.hidden_states = torch.zeros( + (self.vllm_config.scheduler_config.max_num_batched_tokens, + self.hidden_size), + dtype=self.vllm_config.model_config.dtype, + device=device) + # We need +1 here because the arange is used to set query_start_loc, + # which has one more element than batch_size. + self.arange = torch.arange(vllm_config.scheduler_config.max_num_seqs + + 1, + device=device, + dtype=torch.int32) + attn_mask_len = min(self.vllm_config.model_config.max_model_len, + int(os.getenv("PAGED_ATTENTION_MASK_LEN", 10000))) + self.attn_mask_builder = AttentionMaskBuilder( + attn_mask_len, self.vllm_config.model_config.dtype) + + def load_model(self, model: nn.Module) -> None: + self.model = get_model(vllm_config=self.vllm_config, + model_config=self.vllm_config. + speculative_config.draft_model_config) + + target_attn_layer_names = set( + get_layers_from_vllm_config(self.vllm_config, Attention).keys()) + draft_attn_layer_names = ( + get_layers_from_vllm_config(self.vllm_config, Attention).keys() - + target_attn_layer_names) + self.attn_layer_name = next(iter(draft_attn_layer_names)) + + # share embed_tokens with the target model if needed + if get_pp_group().world_size == 1: + logger.info( + "The EAGLE head shares the same vocab embedding" \ + " with the target model." + ) + self.model.model.embed_tokens = model.model.embed_tokens + else: + logger.info( + "Since PP > 1, the EAGLE head loaded its own vocab embedding" \ + " weights instead of sharing them with the target model." 
+ ) + + # share lm_head with the target model if needed + # some model definition do not define lm_head explicitly + # and reuse embed_tokens for lm_head, e.g., CohereForCausalLM + if self.name == SpecDcodeType.EAGLE and hasattr(model, "lm_head"): + logger.info("Loading EAGLE LM head weights from the target model.") + if supports_multimodal(model): + self.model.lm_head = model.get_language_model().lm_head + else: + self.model.lm_head = model.lm_head + + @torch.inference_mode() + def dummy_run(self, + num_tokens: int, + with_prefill: bool = False, + skip_attn: bool = False, + num_reqs: int = 0, + num_tokens_across_dp: Optional[torch.Tensor] = None): + with set_ascend_forward_context(None, + self.vllm_config, + num_tokens=num_tokens): + self.model( + input_ids=self.input_ids[:num_tokens], + positions=self.positions[:num_tokens], + hidden_states=self.hidden_states[:num_tokens], + ) + + def generate_token_ids(self, + valid_sampled_token_ids: list[list[int]], + sampling_metadata: SamplingMetadata = None, + scheduler_output: SchedulerOutput = None, + spec_decode_metadata: SpecDecodeMetadata = None, + positions: torch.Tensor = None, + num_scheduled_tokens: int = 0, + hidden_states: torch.Tensor = None, + attn_metadata=None, + aux_hidden_states: torch.Tensor = None): + if self.name == SpecDcodeType.EAGLE: + raise NotImplementedError("Eagle Is Not Supported Yet.") + + attn_metadata = self._get_eagle_atten_dict(scheduler_output) + next_token_ids: list[int] = [] + for i, token_ids in enumerate(valid_sampled_token_ids): + if token_ids: + # Common case. + next_token_id = token_ids[-1] + else: + # Partial prefill (rare case). + # Get the next token id from the request state. + req_id = self.runner.input_batch.req_ids[i] + req_state = self.runner.requests[req_id] + seq_len = (req_state.num_computed_tokens + + scheduler_output.num_scheduled_tokens[req_id]) + + next_token_id = req_state.get_token_id(seq_len) + next_token_ids.append(next_token_id) + next_token_ids = torch.tensor(next_token_ids, + dtype=torch.int32, + device=self.device) + eagle_attn_metadata = attn_metadata[self.attn_layer_name] + if spec_decode_metadata is None: + # input_ids can be None for multimodal models. 
+ target_token_ids = self.runner.input_ids[:num_scheduled_tokens] + target_positions = positions[:num_scheduled_tokens] + if self.name == SpecDcodeType.EAGLE3: + target_hidden_states = torch.cat( + [h[:num_scheduled_tokens] for h in aux_hidden_states], + dim=-1) + else: + target_hidden_states = hidden_states[:num_scheduled_tokens] + target_slot_mapping = eagle_attn_metadata.slot_mapping + cu_num_tokens = eagle_attn_metadata.query_start_loc + else: + num_draft_tokens = spec_decode_metadata.num_draft_tokens + num_rejected_tokens = [ + n + 1 - len(valid_sampled_token_ids[i]) if n > 0 else 0 + for i, n in enumerate(num_draft_tokens) + ] + num_rejected_tokens = torch.tensor( + num_rejected_tokens, + dtype=torch.int32, + device=self.device, + ) + num_tokens = num_scheduled_tokens - sum(num_rejected_tokens) + cu_num_tokens, token_indices = self._prepare_inputs( + eagle_attn_metadata.query_start_loc, num_rejected_tokens, + num_tokens) + target_token_ids = self.runner.input_ids[token_indices] + target_positions = positions[token_indices] + if self.name == SpecDcodeType.EAGLE3: + target_hidden_states = torch.cat( + [h[token_indices] for h in aux_hidden_states], dim=-1) + else: + target_hidden_states = hidden_states[token_indices] + target_slot_mapping = eagle_attn_metadata.slot_mapping[ + token_indices] + + draft_token_ids = self._propose( + target_token_ids=target_token_ids, + target_positions=target_positions, + target_hidden_states=target_hidden_states, + target_slot_mapping=target_slot_mapping, + next_token_ids=next_token_ids, + cu_num_tokens=cu_num_tokens, + block_table=eagle_attn_metadata.block_tables, + sampling_metadata=sampling_metadata, + ) + spec_token_ids = draft_token_ids.tolist() + return spec_token_ids + + def _get_eagle_atten_dict( + self, + scheduler_output: "SchedulerOutput", + ): + total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens + assert total_num_scheduled_tokens > 0 + num_reqs = self.runner.input_batch.num_reqs + assert num_reqs > 0 + + # OPTIMIZATION: Start copying the block table first. + # This way, we can overlap the copy with the following CPU operations. + self.runner.input_batch.block_table.commit_block_table(num_reqs) + + # Get the number of scheduled tokens for each request. + req_ids = self.runner.input_batch.req_ids + tokens = [scheduler_output.num_scheduled_tokens[i] for i in req_ids] + num_scheduled_tokens = np.array(tokens, dtype=np.int32) + max_num_scheduled_tokens = max(tokens) + self.runner.query_lens = torch.from_numpy(num_scheduled_tokens) + # Get request indices. + # E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2] + req_indices = np.repeat(self.runner.arange_np[:num_reqs], + num_scheduled_tokens) + + # cu_num_tokens: [2, 5, 3] -> [2, 7, 10] + # arange: [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] + cu_num_tokens, arange = self._get_cumsum_and_arange( + num_scheduled_tokens) + + # Get positions. + positions_np = self.runner.positions_np[:total_num_scheduled_tokens] + np.add(self.runner.input_batch.num_computed_tokens_cpu[req_indices], + arange, + out=positions_np) + + # Calculate M-RoPE positions. + # Only relevant for models using M-RoPE (e.g, Qwen2-VL) + if self.runner.uses_mrope: + self.runner._calc_mrope_positions(scheduler_output) + + # Get token indices. + # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] + # -> [0, 1, M, M + 1, M + 2, M + 3, M + 4, 2 * M, 2 * M + 1, 2 * M + 2] + # where M is the max_model_len. 
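+        # token_ids_cpu stores one row of max_model_len token slots per
+        # request, so a token's flat index is its position plus its request's
+        # row offset.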
+ token_indices = ( + positions_np + + req_indices * self.runner.input_batch.token_ids_cpu.shape[1]) + + # NOTE(woosuk): We use torch.index_select instead of np.take here + # because torch.index_select is much faster than np.take for large + # tensors. + torch.index_select( + self.runner.input_batch.token_ids_cpu_tensor.flatten(), + 0, + torch.from_numpy(token_indices), + out=self.runner.input_ids_cpu[:total_num_scheduled_tokens]) + + # Prepare the attention metadata for each KV cache group and make layers + # in the same group share the same metadata. + # NOTE(Chen): there is exactly one KV cache group that contains all + # attetnion layers in the model for now, so the current logic for + # getting attn_metadata is not related to kv_cache_group information. + # Will extend this part to support multiple KV cache groups later. + for kv_cache_group_id, kv_cache_group_spec in enumerate( + self.runner.kv_cache_config.kv_cache_groups): + block_size = kv_cache_group_spec.kv_cache_spec.block_size + block_table = self.runner.input_batch.block_table[ + kv_cache_group_id] + # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] + # -> [0, 0, K, K, K + 1, K + 1, K + 2, 2 * K, 2 * K, 2 * K + 1] + # where K is the max_num_blocks_per_req and the block size is 2. + # NOTE(woosuk): We can't simply use `token_indices // block_size` + # here because M (max_model_len) is not necessarily divisible by + # block_size. + block_table_indices = ( + req_indices * block_table.max_num_blocks_per_req + + positions_np // block_size) + block_table_cpu = block_table.get_cpu_tensor() + block_numbers = block_table_cpu.flatten( + )[block_table_indices].numpy() + block_offsets = positions_np % block_size + np.add( + block_numbers * block_size, + block_offsets, + out=block_table.slot_mapping_np[:total_num_scheduled_tokens]) + + # Prepare the attention metadata. + self.runner.query_start_loc_np[0] = 0 + self.runner.query_start_loc_np[1:num_reqs + 1] = cu_num_tokens + + self.runner.seq_lens_np[:num_reqs] = ( + self.runner.input_batch.num_computed_tokens_cpu[:num_reqs] + + num_scheduled_tokens) + + # Copy the tensors to the NPU. + self.runner.input_ids[:total_num_scheduled_tokens].copy_( + self.runner.input_ids_cpu[:total_num_scheduled_tokens], + non_blocking=True) + if self.runner.uses_mrope: + # Only relevant for models using M-RoPE (e.g, Qwen2-VL) + self.runner.mrope_positions[:, :total_num_scheduled_tokens].copy_( + self.runner. + mrope_positions_cpu[:, :total_num_scheduled_tokens], + non_blocking=True) + else: + # Common case (1D positions) + self.runner.positions[:total_num_scheduled_tokens].copy_( + self.runner.positions_cpu[:total_num_scheduled_tokens], + non_blocking=True) + + self.runner.query_start_loc[:num_reqs + 1].copy_( + self.runner.query_start_loc_cpu[:num_reqs + 1], non_blocking=True) + self.runner.seq_lens[:num_reqs].copy_( + self.runner.seq_lens_cpu[:num_reqs], non_blocking=True) + + # Fill unused with -1. Needed for reshape_and_cache + self.runner.seq_lens[num_reqs:].fill_(0) + self.runner.query_start_loc[num_reqs + 1:].fill_(-1) + + attn_metadata = {} + # Prepare the attention metadata for each KV cache group and make layers + # in the same group share the same metadata. 
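+        # The builder consumes the device-side query_start_loc / seq_lens /
+        # slot_mapping prepared above and returns one metadata object per
+        # KV cache group, shared by every layer in that group.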
+ for kv_cache_group_id, kv_cache_group_spec in enumerate( + self.runner.kv_cache_config.kv_cache_groups): + common_attn_metadata = AscendCommonAttentionMetadata( + query_start_loc=self.runner.query_start_loc[:num_reqs + 1], + query_start_loc_cpu=self.runner.query_start_loc_cpu[:num_reqs + + 1], + seq_lens_cpu=self.runner.seq_lens_cpu, + num_reqs=num_reqs, + max_query_len=max_num_scheduled_tokens, + num_actual_tokens=total_num_scheduled_tokens, + actual_seq_lengths_q=self.runner.actual_seq_lengths_q, + block_table_tensor=self.runner.input_batch.block_table[0]. + get_device_tensor(), + slot_mapping_cpu=self.runner.slot_mapping_cpu, + positions=self.runner.positions, + attn_mask=self.runner.attn_mask, + spec_attn_mask=self.runner.spec_attn_mask, + attn_state=self.runner.attn_state, + decode_token_per_req=self.runner.decode_token_per_req, + ) + attn_metadata_i = self.runner.attn_metadata_builder.build( + common_attn_metadata, self.runner.get_model()) + for layer_name in kv_cache_group_spec.layer_names: + attn_metadata[layer_name] = attn_metadata_i + + return attn_metadata + + def _get_cumsum_and_arange( + self, + num_tokens: np.ndarray, + cumsum_dtype: Optional[np.dtype] = None, + ) -> tuple[np.ndarray, np.ndarray]: + """Get the cumulative sum and batched arange of the given array. + # E.g., [2, 5, 3] -> ([2, 7, 10], [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]) + # Equivalent to but faster than: + # np.concatenate([np.arange(n) for n in num_tokens]) + """ + # Step 1. [2, 5, 3] -> [2, 7, 10] + cu_num_tokens = np.cumsum(num_tokens, dtype=cumsum_dtype) + total_num_tokens = cu_num_tokens[-1] + # Step 2. [2, 7, 10] -> [0, 0, 2, 2, 2, 2, 2, 7, 7, 7] + cumsums_offsets = np.repeat(cu_num_tokens - num_tokens, num_tokens) + # Step 3. [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] + arange = self.runner.arange_np[:total_num_tokens] - cumsums_offsets + + return cu_num_tokens, arange + + def _propose( + self, + # [num_tokens] + target_token_ids: torch.Tensor, + # [num_tokens] + target_positions: torch.Tensor, + # [num_tokens, hidden_size] + target_hidden_states: torch.Tensor, + # [num_tokens] + target_slot_mapping: torch.Tensor, + # [batch_size] + next_token_ids: torch.Tensor, + # [batch_size + 1] starting with 0 + cu_num_tokens: torch.Tensor, + # [batch_size, max_num_blocks_per_req] + block_table: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> torch.Tensor: + device = cu_num_tokens.device + cu_num_tokens = cu_num_tokens.cpu() + block_table = block_table.cpu() + num_tokens = target_token_ids.shape[0] + batch_size = next_token_ids.shape[0] + last_token_indices = cu_num_tokens[1:] - 1 + target_positions = target_positions.cpu() + if self.name == SpecDcodeType.EAGLE3: + assert isinstance(self.model, Eagle3LlamaForCausalLM) + target_hidden_states = self.model.combine_hidden_states( + target_hidden_states) + assert target_hidden_states.shape[-1] == self.hidden_size + + # Shift the input ids by one token. + # E.g., [a1, b1, b2, c1, c2, c3] -> [b1, b2, c1, c2, c3, c3] + self.input_ids[:num_tokens - 1] = target_token_ids[1:] + # Replace the last token with the next token. 
+ # E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4] + self.input_ids[last_token_indices] = next_token_ids[0] + + query_lens = cu_num_tokens[1:] - cu_num_tokens[:-1] + max_query_len = query_lens.max().item() + + common_attn_metadata = AscendCommonAttentionMetadata( + query_start_loc=self.runner.query_start_loc[:batch_size + 1], + query_start_loc_cpu=self.runner.query_start_loc_cpu[:batch_size + + 1], + seq_lens_cpu=self.runner.seq_lens_cpu, + max_query_len=max_query_len, + num_reqs=batch_size, + num_actual_tokens=num_tokens, + actual_seq_lengths_q=self.runner.actual_seq_lengths_q, + block_table_tensor=self.runner.input_batch.block_table[0]. + get_device_tensor(), + slot_mapping_cpu=target_slot_mapping, + positions=target_positions, + attn_mask=self.runner.attn_mask, + spec_attn_mask=self.runner.spec_attn_mask, + attn_state=self.runner.attn_state, + decode_token_per_req=self.runner.decode_token_per_req, + ) + # FIXME(woosuk): The below two ops cause synchronization. Optimize. + attn_metadata = self.runner.attn_metadata_builder.build( + common_attn_metadata, self.runner.model) + if self.use_cuda_graph and \ + num_tokens <= self.cudagraph_batch_sizes[-1]: + num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens) + else: + num_input_tokens = num_tokens + # copy inputs to buffer for cudagraph + self.positions[:num_tokens] = target_positions.to(device) + self.hidden_states[:num_tokens] = target_hidden_states + attn_metadata.block_tables = block_table.to(device) + with set_ascend_forward_context(attn_metadata, + self.vllm_config, + num_tokens=num_input_tokens): + last_hidden_states, hidden_states = self.model( + input_ids=self.input_ids[:num_input_tokens], + positions=self.positions[:num_input_tokens], + hidden_states=self.hidden_states[:num_input_tokens], + ) + sample_hidden_states = last_hidden_states[last_token_indices] + logits = self.model.compute_logits(sample_hidden_states, None) + draft_token_ids = logits.argmax(dim=-1) + + # Early exit if there is only one draft token to be generated. + if self.vllm_config.speculative_config.num_speculative_tokens == 1: + # [batch_size, 1] + return draft_token_ids.view(-1, 1) + + # Generate the remaining draft tokens. + draft_token_ids_tensor = torch.zeros( + (self.vllm_config.speculative_config.num_speculative_tokens, + *draft_token_ids.shape), + dtype=draft_token_ids.dtype) + draft_token_ids_tensor[0] = draft_token_ids + + positions_cpu = target_positions[last_token_indices].cpu().to( + torch.int64) + hidden_states = hidden_states[last_token_indices] + if self.use_cuda_graph and \ + batch_size <= self.cudagraph_batch_sizes[-1]: + input_batch_size = self.vllm_config.pad_for_cudagraph(batch_size) + else: + input_batch_size = batch_size + attn_metadata.num_actual_tokens = batch_size + attn_metadata.max_query_len = 1 + attn_metadata.query_start_loc = self.arange[:batch_size + 1] + + if self.vllm_config.speculative_config.num_speculative_tokens > 2: + raise ValueError("Speculative tokens > 2 are not supported yet.") + + attn_metadata.attn_state = AscendAttentionState.ChunkedPrefill + for now_speculative in range( + self.vllm_config.speculative_config.num_speculative_tokens - + 1): + # Update the inputs. + # cast to int32 is crucial when eagle model is compiled. + # tensor.argmax() returns int64 by default. + input_ids = draft_token_ids_tensor[now_speculative].to(device) + positions_cpu += 1 + + # NOTE(woosuk): We should handle the case where the draft model + # generates tokens beyond the max model length. 
Since it is complex + # to remove such requests from the batch, we keep them in the batch + # but adjust the position ids and slot mappings to avoid the + # out-of-range access during the model execution. The draft tokens + # generated with this adjustment should be ignored. + exceeds_max_model_len = positions_cpu >= self.vllm_config.model_config.max_model_len + # Mask out the position ids that exceed the max model length. + # Otherwise, we may get out-of-range error in RoPE. + clamped_positions_cpu = torch.where(exceeds_max_model_len, 0, + positions_cpu) + clamped_positions = clamped_positions_cpu.to(device) + + # TODO: Increment the sequence lengths. + + attn_metadata.seq_lens += 1 + # TODO: Consider max model length. + # attn_metadata.max_seq_len = min(attn_metadata.max_seq_len, + # self.max_model_len) + # For the requests that exceed the max model length, we set the + # TODO: sequence length to 1 to minimize their overheads in attention. + + # Compute the slot mapping. + block_numbers = (clamped_positions_cpu // self.block_size) + block_ids = block_table.gather(dim=1, + index=block_numbers.view(-1, 1)) + block_ids = block_ids.view(-1) + slot_mapping_cpu = ( + block_ids * self.vllm_config.cache_config.block_size + + clamped_positions_cpu % self.block_size) + + # Mask out the slot mappings that exceed the max model length. + # Otherwise, the KV cache will be inadvertently updated with the + # padding tokens. + slot_mapping_cpu.masked_fill_(exceeds_max_model_len, + PADDING_SLOT_ID) + # NOTE: ASCEND slot_mapping must on cpu + attn_metadata.slot_mapping = slot_mapping_cpu.to( + torch.int32).to(device) + # copy inputs to buffer for cudagraph + self.input_ids[:batch_size] = input_ids + self.positions[:batch_size] = clamped_positions + self.hidden_states[:batch_size] = hidden_states + positions = positions_cpu.to(device) + attn_mask = self.attn_mask_builder.get_splitfuse_attn_mask( + attn_metadata.seq_lens, attn_metadata.max_query_len, positions, + self.vllm_config.model_config.dtype, self.device) + + attn_metadata.attn_mask = attn_mask + attn_metadata.block_tables = block_table.to(device) + # Run the model. 
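+            # Each loop iteration is a single-token decode step for the whole
+            # batch (max_query_len == 1), producing the next draft token for
+            # every request.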
+ with set_ascend_forward_context(attn_metadata, + self.vllm_config, + num_tokens=input_batch_size): + + last_hidden_states, hidden_states = self.model( + input_ids=self.input_ids[:input_batch_size], + positions=self.positions[:input_batch_size], + hidden_states=self.hidden_states[:input_batch_size], + ) + hidden_states = hidden_states[:batch_size] + logits = self.model.compute_logits(last_hidden_states[:batch_size], + None) + + # TODO(wenlong): get more than one token for tree attention + draft_token_ids = logits.argmax(dim=-1) + draft_token_ids_tensor[now_speculative + 1] = draft_token_ids.cpu() + + # [batch_size, num_speculative_tokens] + draft_token_ids = draft_token_ids_tensor.swapaxes(0, 1) + return draft_token_ids + + def _prepare_inputs( + self, + # [batch_size + 1] + cu_target_query_lens: torch.Tensor, + # [batch_size] + num_rejected_tokens: torch.Tensor, + num_tokens: int, + ) -> tuple[torch.Tensor, torch.Tensor]: + # cu_target_query_lens: [0, a, a + b, a + b + c] + # num_rejected_tokens: [n1, n2, n3] + # num_tokens_per_req: [a - n1, b - n2, c - n3] + # cu_num_tokens: [0, a - n1, a + b - n1 - n2, a + b + c - n1 - n2 - n3] + # token_indices: [0, 1, ..., a - n1 - 1, + # a, a + 1, ..., a + b - n2 - 1, + # a + b, a + b + 1, ..., a + b + c - n3 - 1] + + # [0, a, a + b, a + b + c] -> [a, b, c] + query_len_per_req = (cu_target_query_lens[1:] - + cu_target_query_lens[:-1]) + # [a, b, c] -> [a - n1, b - n2, c - n3] + num_tokens_per_req = query_len_per_req - num_rejected_tokens + + # [a - n1, b - n2, c - n3] -> + # [0, a - n1, a + b - n1 - n2, a + b + c - n1 - n2 - n3] + cu_num_tokens = torch.zeros_like(cu_target_query_lens) + torch.cumsum(num_tokens_per_req, dim=0, out=cu_num_tokens[1:]) + token_indices = torch.empty( + num_tokens, + dtype=torch.int32, + device=cu_target_query_lens.device, + ) + BLOCK_SIZE = 1024 + self._prepare_eagle_input_sequential( + token_indices, + cu_target_query_lens, + cu_num_tokens, + block_size=BLOCK_SIZE, + ) + return cu_num_tokens, token_indices + + def _prepare_eagle_input_sequential(self, out_tensor: torch.Tensor, + cu_query_lens: torch.Tensor, + cu_num_tokens: torch.Tensor, + block_size: int): + num_programs = len(cu_num_tokens) - 1 + for pid in range(num_programs): + start_pos = cu_num_tokens[pid].item() + end_pos = cu_num_tokens[pid + 1].item() + num_tokens = end_pos - start_pos + index_start = cu_query_lens[pid].item() + num_blocks = int( + torch.ceil(torch.tensor(num_tokens / block_size)).item()) + + for i in range(num_blocks): + offset_tensor = torch.arange(0, + block_size, + dtype=torch.int32, + device=out_tensor.device) + global_start_offset = i * block_size + target_indices = torch.tensor( + start_pos + global_start_offset, + dtype=torch.int32, + device=out_tensor.device) + offset_tensor + values_to_store = torch.tensor( + index_start, dtype=torch.int32, + device=out_tensor.device) + offset_tensor + mask = (target_indices >= start_pos) & \ + (target_indices < end_pos) & \ + (offset_tensor < num_tokens) + out_tensor[target_indices[mask]] = values_to_store[mask] diff --git a/vllm_ascend/spec_decode/interface.py b/vllm_ascend/spec_decode/interface.py new file mode 100644 index 0000000000..0efe93de33 --- /dev/null +++ b/vllm_ascend/spec_decode/interface.py @@ -0,0 +1,51 @@ +import enum +from typing import Optional + +import torch +from vllm.config import VllmConfig +from vllm.v1.core.sched.output import SchedulerOutput +from vllm.v1.sample.metadata import SamplingMetadata +from vllm.v1.spec_decode.metadata import SpecDecodeMetadata + + +class 
SpecDcodeType(enum.Enum): + NGRAM = 0 + EAGLE = 1 + EAGLE3 = 2 + MTP = 4 + + +class Proposer: + + def __init__(self, + vllm_config: VllmConfig, + device: torch.device = None, + runner=None): + pass + + def load_model(self, model): + """Called by load_model in model_runner""" + raise NotImplementedError + + @torch.inference_mode() + def dummy_run(self, + num_tokens: int, + with_prefill: bool = False, + skip_attn: bool = False, + num_reqs: int = 0, + num_tokens_across_dp: Optional[torch.Tensor] = None): + """Called by dummy_run in modle_runner""" + raise NotImplementedError + + def generate_token_ids(self, + valid_sampled_token_ids: list[list[int]], + sampling_metadata: SamplingMetadata = None, + scheduler_output: SchedulerOutput = None, + spec_decode_metadata: SpecDecodeMetadata = None, + positions: torch.Tensor = None, + num_scheduled_tokens: int = 0, + hidden_states: torch.Tensor = None, + attn_metadata=None, + aux_hidden_states: torch.Tensor = None): + """Called by execute_model in model_runner""" + raise NotImplementedError diff --git a/vllm_ascend/worker/mtp_proposer_v1.py b/vllm_ascend/spec_decode/mtp_proposer.py similarity index 77% rename from vllm_ascend/worker/mtp_proposer_v1.py rename to vllm_ascend/spec_decode/mtp_proposer.py index 1ec1436372..2a691b5f5c 100644 --- a/vllm_ascend/worker/mtp_proposer_v1.py +++ b/vllm_ascend/spec_decode/mtp_proposer.py @@ -12,47 +12,223 @@ from vllm.model_executor.model_loader import get_model_loader from vllm.model_executor.model_loader.utils import ( process_weights_after_loading, set_default_torch_dtype) +from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.sample.metadata import SamplingMetadata +from vllm.v1.spec_decode.metadata import SpecDecodeMetadata from vllm_ascend.ascend_config import get_ascend_config from vllm_ascend.ascend_forward_context import set_ascend_forward_context from vllm_ascend.attention.utils import AscendCommonAttentionMetadata from vllm_ascend.models.deepseek_mtp import CustomDeepSeekMTP +from vllm_ascend.spec_decode.interface import Proposer, SpecDcodeType from vllm_ascend.torchair.utils import TorchairCommonAttentionMetadata from vllm_ascend.utils import ProfileExecuteDuration -class MtpProposer: +class MtpProposer(Proposer): def __init__( self, vllm_config: VllmConfig, + device, runner, ): + self.name = SpecDcodeType.MTP self.vllm_config = vllm_config - self.num_speculative_tokens = ( - vllm_config.speculative_config.num_speculative_tokens) - self.block_size = vllm_config.cache_config.block_size - self.hidden_size = vllm_config.model_config.get_hidden_size() + self.device = device self.runner = runner + # persistent buffers for graph self.input_ids = torch.zeros(self.runner.max_num_tokens, dtype=torch.int32, - device=self.runner.device) + device=self.device) self.positions = torch.zeros(self.runner.max_num_tokens, dtype=torch.int64, - device=self.runner.device) + device=self.device) self.hidden_states = torch.zeros( - (self.runner.max_num_tokens, self.hidden_size), + (self.runner.max_num_tokens, + vllm_config.model_config.get_hidden_size()), dtype=self.runner.dtype, - device=self.runner.device) + device=self.device) self.torchair_compiled_model = None # type: ignore self.torchair_compiled_models = {} # type: ignore self.torchair_graph_enabled = get_ascend_config( ).torchair_graph_config.enabled - @staticmethod - def prepare_inputs( + def load_model(self, model) -> None: + loader = get_model_loader(self.vllm_config.load_config) + + target_attn_layer_names = set( + 
get_layers_from_vllm_config(self.vllm_config, Attention).keys()) + draft_model_config = \ + self.vllm_config.speculative_config.draft_model_config + target_device = self.vllm_config.device_config.device + + with set_default_torch_dtype( + draft_model_config.dtype), set_current_vllm_config( + self.vllm_config): + self.model = CustomDeepSeekMTP( + vllm_config=self.vllm_config).to(target_device) + + draft_attn_layer_names = ( + get_layers_from_vllm_config(self.vllm_config, Attention).keys() - + target_attn_layer_names) + + assert len(draft_attn_layer_names) == 1 + self.attn_layer_name = next(iter(draft_attn_layer_names)) + + self.model.load_weights( + loader.get_all_weights( + self.vllm_config.speculative_config.draft_model_config, + self.model)) + process_weights_after_loading(self.model, draft_model_config, + target_device) + + @torch.inference_mode() + def dummy_run(self, + num_tokens: int, + with_prefill: bool = False, + skip_attn: bool = False, + num_reqs: int = 0, + num_tokens_across_dp=None) -> None: + if not self.torchair_graph_enabled: + # TODO: adapt enable_dbo later + (num_tokens, num_tokens_across_dp, with_prefill, + _) = self.runner._get_forward_metadata_across_dp_and_pad( + num_tokens, with_prefill, False) + is_running_torchair = self.torchair_graph_enabled and \ + not with_prefill + + if is_running_torchair: + skip_attn = False + if skip_attn: + attn_metadata = None + else: + common_attn_metadata = TorchairCommonAttentionMetadata( + num_reqs=num_reqs, + num_actual_tokens=1, + actual_seq_lengths_q=self.runner.actual_seq_lengths_q, + attn_mask=self.runner.attn_mask, + spec_attn_mask=self.runner.spec_attn_mask, + decode_token_per_req=self.runner.decode_token_per_req, + ) + attn_metadata = self.runner.attn_metadata_builder.build_torchair_graph_dummy( + common_attn_metadata) + + input_ids = self.input_ids[:num_tokens] + positions = self.positions[:num_tokens] + previous_hidden_states = self.hidden_states[:num_tokens] + with set_ascend_forward_context( + attn_metadata, + self.vllm_config, + num_tokens=num_tokens, + with_prefill=with_prefill, + num_tokens_across_dp=num_tokens_across_dp, + reserved_mc2_mask=self.runner.reserved_mc2_mask, + in_profile_run=self.runner.in_profile_run, + num_actual_tokens=0): + if is_running_torchair: + assert attn_metadata is not None + torch._dynamo.mark_static(input_ids) + torch._dynamo.mark_static(positions) + torch._dynamo.mark_static(previous_hidden_states) + torch._dynamo.mark_static(attn_metadata.decode.block_table) + torch._dynamo.mark_static(attn_metadata.decode.input_positions) + if hasattr(attn_metadata.decode, "sin"): + torch._dynamo.mark_static(attn_metadata.decode.sin) + torch._dynamo.mark_static(attn_metadata.decode.cos) + torch._dynamo.mark_static(get_forward_context().mc2_mask) + torch._dynamo.mark_static(attn_metadata.slot_mapping) + torch._dynamo.mark_static(attn_metadata.decode.attn_mask) + torchair_compiled_model = self._get_torchair_lazy_compiled_model( + num_tokens) + torchair_compiled_model( + input_ids=input_ids, + positions=positions, + previous_hidden_states=previous_hidden_states, + inputs_embeds=None, + intermediate_tensors=None, + attn_metadata=attn_metadata, + kv_caches=self.runner.kv_caches[-1:], + spec_step_idx=0) + else: + self.model(input_ids=input_ids, + positions=positions, + previous_hidden_states=previous_hidden_states) + + def generate_token_ids(self, + valid_sampled_token_ids: list[list[int]], + sampling_metadata: SamplingMetadata = None, + scheduler_output: SchedulerOutput = None, + spec_decode_metadata: 
SpecDecodeMetadata = None, + positions: torch.Tensor = None, + num_scheduled_tokens: int = 0, + hidden_states: torch.Tensor = None, + attn_metadata=None, + aux_hidden_states: torch.Tensor = None): + next_token_ids: list[int] = [] + for i, token_ids in enumerate(valid_sampled_token_ids): + if token_ids: + # Common case. + next_token_id = token_ids[-1] + else: + # Partial prefill (rare case). + # Get the next token id from the request state. + req_id = self.runner.input_batch.req_ids[i] + req_state = self.runner.requests[req_id] + seq_len = (req_state.num_computed_tokens + + scheduler_output.num_scheduled_tokens[req_id]) + next_token_id = req_state.get_token_id(seq_len) + next_token_ids.append(next_token_id) + next_token_ids = torch.tensor(next_token_ids, + dtype=torch.int32, + device=self.device) + accepted_token_indices = None + if spec_decode_metadata is None: + # input_ids can be None for multimodal models. + target_token_ids = self.runner.input_ids[:num_scheduled_tokens] + target_positions = positions[:num_scheduled_tokens] + target_hidden_states = hidden_states[:num_scheduled_tokens] + target_slot_mapping = attn_metadata.slot_mapping + cu_num_tokens = attn_metadata.query_start_loc + else: + # TODO(woosuk): Refactor this. + num_draft_tokens = spec_decode_metadata.num_draft_tokens + num_rejected_tokens = [ + n + 1 - len(valid_sampled_token_ids[i]) if n > 0 else 0 + for i, n in enumerate(num_draft_tokens) + ] + num_rejected_tokens = torch.tensor( + num_rejected_tokens, + dtype=torch.int32, + device=self.device, + ) + cu_num_tokens, accepted_token_indices, target_token_ids, \ + target_positions, target_hidden_states, target_slot_mapping = self._prepare_inputs( + attn_metadata.query_start_loc, + num_rejected_tokens, + self.runner.input_ids[:num_scheduled_tokens], + positions[:num_scheduled_tokens], + hidden_states[:num_scheduled_tokens], + attn_metadata.slot_mapping[:num_scheduled_tokens], + is_torchair_graph=self.runner._build_drafter_prepare_inputs_torchair_param(), + ) + + draft_token_ids = self._propose( + target_token_ids=target_token_ids, + target_positions=target_positions, + target_hidden_states=target_hidden_states, + target_slot_mapping=target_slot_mapping, + next_token_ids=next_token_ids, + cu_num_tokens=cu_num_tokens, + block_table=attn_metadata.block_tables, + sampling_metadata=sampling_metadata, + token_indices=accepted_token_indices) + spec_token_ids = draft_token_ids.tolist() + return spec_token_ids + + def _prepare_inputs( + self, # [batch_size + 1] cu_target_query_lens: torch.Tensor, # [batch_size] @@ -99,7 +275,7 @@ def prepare_inputs( ) BLOCK_SIZE = 1024 - prepare_input_kernel( + self._prepare_input_kernel( token_indices, cu_target_query_lens, cu_num_tokens, @@ -111,7 +287,7 @@ def prepare_inputs( target_slot_mapping = slot_mapping[token_indices] return cu_num_tokens, token_indices, target_token_ids, target_positions, target_hidden_states, target_slot_mapping - def propose( + def _propose( self, # [num_tokens] target_token_ids: torch.Tensor, @@ -242,107 +418,6 @@ def propose( # [batch_size, 1] return draft_token_ids.view(-1, 1) - def load_model(self) -> None: - loader = get_model_loader(self.vllm_config.load_config) - - target_attn_layer_names = set( - get_layers_from_vllm_config(self.vllm_config, Attention).keys()) - draft_model_config = \ - self.vllm_config.speculative_config.draft_model_config - target_device = self.vllm_config.device_config.device - - with set_default_torch_dtype( - draft_model_config.dtype), set_current_vllm_config( - self.vllm_config): - self.model 
= CustomDeepSeekMTP( - vllm_config=self.vllm_config).to(target_device) - - draft_attn_layer_names = ( - get_layers_from_vllm_config(self.vllm_config, Attention).keys() - - target_attn_layer_names) - - assert len(draft_attn_layer_names) == 1 - self.attn_layer_name = next(iter(draft_attn_layer_names)) - - self.model.load_weights( - loader.get_all_weights( - self.vllm_config.speculative_config.draft_model_config, - self.model)) - process_weights_after_loading(self.model, draft_model_config, - target_device) - - @torch.inference_mode() - def dummy_run(self, - num_tokens: int, - with_prefill: bool = False, - skip_attn: bool = False, - num_reqs: int = 0, - num_tokens_across_dp=None) -> None: - if not self.torchair_graph_enabled: - # TODO: adapt enable_dbo later - (num_tokens, num_tokens_across_dp, with_prefill, - _) = self.runner._get_forward_metadata_across_dp_and_pad( - num_tokens, with_prefill, False) - is_running_torchair = self.torchair_graph_enabled and \ - not with_prefill - - if is_running_torchair: - skip_attn = False - if skip_attn: - attn_metadata = None - else: - common_attn_metadata = TorchairCommonAttentionMetadata( - num_reqs=num_reqs, - num_actual_tokens=1, - actual_seq_lengths_q=self.runner.actual_seq_lengths_q, - attn_mask=self.runner.attn_mask, - spec_attn_mask=self.runner.spec_attn_mask, - decode_token_per_req=self.runner.decode_token_per_req, - ) - attn_metadata = self.runner.attn_metadata_builder.build_torchair_graph_dummy( - common_attn_metadata) - - input_ids = self.input_ids[:num_tokens] - positions = self.positions[:num_tokens] - previous_hidden_states = self.hidden_states[:num_tokens] - with set_ascend_forward_context( - attn_metadata, - self.vllm_config, - num_tokens=num_tokens, - with_prefill=with_prefill, - num_tokens_across_dp=num_tokens_across_dp, - reserved_mc2_mask=self.runner.reserved_mc2_mask, - in_profile_run=self.runner.in_profile_run, - num_actual_tokens=0): - if is_running_torchair: - assert attn_metadata is not None - torch._dynamo.mark_static(input_ids) - torch._dynamo.mark_static(positions) - torch._dynamo.mark_static(previous_hidden_states) - torch._dynamo.mark_static(attn_metadata.decode.block_table) - torch._dynamo.mark_static(attn_metadata.decode.input_positions) - if hasattr(attn_metadata.decode, "sin"): - torch._dynamo.mark_static(attn_metadata.decode.sin) - torch._dynamo.mark_static(attn_metadata.decode.cos) - torch._dynamo.mark_static(get_forward_context().mc2_mask) - torch._dynamo.mark_static(attn_metadata.slot_mapping) - torch._dynamo.mark_static(attn_metadata.decode.attn_mask) - torchair_compiled_model = self._get_torchair_lazy_compiled_model( - num_tokens) - torchair_compiled_model( - input_ids=input_ids, - positions=positions, - previous_hidden_states=previous_hidden_states, - inputs_embeds=None, - intermediate_tensors=None, - attn_metadata=attn_metadata, - kv_caches=self.runner.kv_caches[-1:], - spec_step_idx=0) - else: - self.model(input_ids=input_ids, - positions=positions, - previous_hidden_states=previous_hidden_states) - def _get_torchair_lazy_compiled_model(self, batch_size: int): if batch_size < 0 or batch_size > self.runner.torchair_graph_batch_sizes[ -1]: @@ -397,23 +472,23 @@ def _get_torchair_lazy_compiled_model(self, batch_size: int): ge_cache=False) return self.torchair_compiled_models[batch_size] + # TODO Using torch instead of triton may result in poor performance + def _prepare_input_kernel(self, out_ptr: torch.Tensor, + cu_query_lens: torch.Tensor, + cu_num_tokens: torch.Tensor, block_size: int): + device = 
cu_query_lens.device + dtype = out_ptr.dtype -# TODO Using torch instead of triton may result in poor performance -def prepare_input_kernel(out_ptr: torch.Tensor, cu_query_lens: torch.Tensor, - cu_num_tokens: torch.Tensor, block_size: int): - device = cu_query_lens.device - dtype = out_ptr.dtype - - offsets = torch.arange(block_size, device=device, dtype=dtype) - start_pos = cu_num_tokens[:-1] - end_pos = cu_num_tokens[1:] - num_tokens = end_pos - start_pos + offsets = torch.arange(block_size, device=device, dtype=dtype) + start_pos = cu_num_tokens[:-1] + end_pos = cu_num_tokens[1:] + num_tokens = end_pos - start_pos - global_indices = (start_pos.view(-1, 1) + offsets.view(1, -1)) - values = (cu_query_lens[:-1].view(-1, 1) + offsets.view(1, -1)) + global_indices = (start_pos.view(-1, 1) + offsets.view(1, -1)) + values = (cu_query_lens[:-1].view(-1, 1) + offsets.view(1, -1)) - mask = (offsets.view(1, -1) < num_tokens.view(-1, 1)) + mask = (offsets.view(1, -1) < num_tokens.view(-1, 1)) - global_indices_flat = global_indices[mask] - values_flat = values[mask] - out_ptr[global_indices_flat] = values_flat + global_indices_flat = global_indices[mask] + values_flat = values[mask] + out_ptr[global_indices_flat] = values_flat diff --git a/vllm_ascend/spec_decode/ngram_proposer.py b/vllm_ascend/spec_decode/ngram_proposer.py new file mode 100644 index 0000000000..f4a9337be6 --- /dev/null +++ b/vllm_ascend/spec_decode/ngram_proposer.py @@ -0,0 +1,65 @@ +import torch +from vllm.v1.spec_decode.ngram_proposer import \ + NgramProposer as VllmNgramProposer + +from vllm_ascend.spec_decode.interface import Proposer, SpecDcodeType + + +class NgramProposer(VllmNgramProposer, Proposer): + + def __init__(self, vllm_config, device, runner): + super().__init__(vllm_config) + self.name = SpecDcodeType.NGRAM + self.device = device + self.runner = runner + + def load_model(self, *args, **kwargs): + # No model to load. + pass + + @torch.inference_mode() + def dummy_run(self, + num_tokens, + with_prefill=None, + skip_attn=None, + num_reqs=None, + num_tokens_across_dp=None): + pass + + def generate_token_ids(self, + valid_sampled_token_ids, + sampling_metadata=None, + scheduler_output=None, + spec_decode_metadata=None, + positions=None, + num_scheduled_tokens=None, + hidden_states=None, + attn_metadata=None, + aux_hidden_states=None) -> list[list[int]]: + # TODO(woosuk): Optimize. + draft_token_ids: list[list[int]] = [] + for i, sampled_ids in enumerate(valid_sampled_token_ids): + num_sampled_ids = len(sampled_ids) + if not num_sampled_ids: + # Skip speculative decoding. + draft_token_ids.append([]) + continue + + # Skip requests that require top-p, top-k, etc. + req_id = self.input_batch.req_ids[i] + if req_id in self.input_batch.spec_decode_unsupported_reqs: + draft_token_ids.append([]) + continue + + # Add sampled_token_ids to token_ids_cpu. 
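+            # Append this step's accepted tokens to the request's CPU token
+            # history so the n-gram matcher sees them when proposing drafts.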
+ start_idx = self.input_batch.num_tokens_no_spec[i] + end_idx = start_idx + num_sampled_ids + self.input_batch.token_ids_cpu[i, start_idx:end_idx] = sampled_ids + assert isinstance(self.drafter, NgramProposer) + drafter_output = self.drafter.propose( + self.input_batch.token_ids_cpu[i, :end_idx]) + if drafter_output is None or len(drafter_output) == 0: + draft_token_ids.append([]) + else: + draft_token_ids.append(drafter_output.tolist()) + return draft_token_ids diff --git a/vllm_ascend/worker/eagle_proposer_v1.py b/vllm_ascend/worker/eagle_proposer_v1.py deleted file mode 100644 index 895649327c..0000000000 --- a/vllm_ascend/worker/eagle_proposer_v1.py +++ /dev/null @@ -1,400 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -import os - -import torch -import torch.nn as nn -from vllm.attention.layer import Attention -from vllm.config import (CompilationLevel, VllmConfig, - get_layers_from_vllm_config) -from vllm.distributed.parallel_state import get_pp_group -from vllm.logger import logger -from vllm.model_executor.model_loader import get_model -from vllm.model_executor.models import supports_multimodal -from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM -from vllm.v1.sample.metadata import SamplingMetadata - -from vllm_ascend.ascend_forward_context import set_ascend_forward_context -from vllm_ascend.attention.attention_mask import AttentionMaskBuilder -from vllm_ascend.attention.attention_v1 import AscendAttentionState -from vllm_ascend.attention.utils import AscendCommonAttentionMetadata - -PADDING_SLOT_ID = -1 - - -class EagleProposer: - - def __init__(self, - vllm_config: VllmConfig, - device: torch.device, - runner=None): - self.vllm_config = vllm_config - self.speculative_config = vllm_config.speculative_config - self.draft_model_config = self.speculative_config.draft_model_config - self.method = self.speculative_config.method - self.runner = runner - self.model_config = vllm_config.model_config - self.dtype = vllm_config.model_config.dtype - self.max_model_len = vllm_config.model_config.max_model_len - self.block_size = vllm_config.cache_config.block_size - self.num_speculative_tokens = ( - self.speculative_config.num_speculative_tokens) - self.max_num_tokens = ( - vllm_config.scheduler_config.max_num_batched_tokens) - self.device = device - # We need to get the hidden size from the draft model config because - # the draft model's hidden size can be different from the target model's - # hidden size (e.g., Llama 3.3 70B). - self.hidden_size = self.draft_model_config.get_hidden_size() - - self.use_cuda_graph = (self.vllm_config.compilation_config.level - == CompilationLevel.PIECEWISE and - not self.vllm_config.model_config.enforce_eager) - self.cudagraph_batch_sizes = list( - reversed( - self.vllm_config.compilation_config.cudagraph_capture_sizes)) - - # persistent buffers for cuda graph - self.input_ids = torch.zeros(self.max_num_tokens, - dtype=torch.int32, - device=device) - self.positions = torch.zeros(self.max_num_tokens, - dtype=torch.int64, - device=device) - self.hidden_states = torch.zeros( - (self.max_num_tokens, self.hidden_size), - dtype=self.dtype, - device=device) - # We need +1 here because the arange is used to set query_start_loc, - # which has one more element than batch_size. 
- self.arange = torch.arange(vllm_config.scheduler_config.max_num_seqs + - 1, - device=device, - dtype=torch.int32) - mask_len = os.getenv("PAGED_ATTENTION_MASK_LEN", 10000) - self.attn_mask_len = min(self.model_config.max_model_len, - int(mask_len)) - self.attn_mask_builder = AttentionMaskBuilder(self.attn_mask_len, - self.dtype) - - def _make_attention_mask( - self, - seq_lens, - query_lens, - position, - ) -> torch.Tensor: - return self.attn_mask_builder.get_splitfuse_attn_mask( - seq_lens, query_lens, position, self.dtype, self.device) - - def propose( - self, - # [num_tokens] - target_token_ids: torch.Tensor, - # [num_tokens] - target_positions: torch.Tensor, - # [num_tokens, hidden_size] - target_hidden_states: torch.Tensor, - # [num_tokens] - target_slot_mapping: torch.Tensor, - # [batch_size] - next_token_ids: torch.Tensor, - # [batch_size + 1] starting with 0 - cu_num_tokens: torch.Tensor, - # [batch_size, max_num_blocks_per_req] - block_table: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> torch.Tensor: - device = cu_num_tokens.device - cu_num_tokens = cu_num_tokens.cpu() - block_table = block_table.cpu() - num_tokens = target_token_ids.shape[0] - batch_size = next_token_ids.shape[0] - last_token_indices = cu_num_tokens[1:] - 1 - target_positions = target_positions.cpu() - if self.method == "eagle3": - assert isinstance(self.model, Eagle3LlamaForCausalLM) - target_hidden_states = self.model.combine_hidden_states( - target_hidden_states) - assert target_hidden_states.shape[-1] == self.hidden_size - - # Shift the input ids by one token. - # E.g., [a1, b1, b2, c1, c2, c3] -> [b1, b2, c1, c2, c3, c3] - self.input_ids[:num_tokens - 1] = target_token_ids[1:] - # Replace the last token with the next token. - # E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4] - self.input_ids[last_token_indices] = next_token_ids[0] - - query_lens = cu_num_tokens[1:] - cu_num_tokens[:-1] - max_query_len = query_lens.max().item() - - common_attn_metadata = AscendCommonAttentionMetadata( - query_start_loc=self.runner.query_start_loc[:batch_size + 1], - query_start_loc_cpu=self.runner.query_start_loc_cpu[:batch_size + - 1], - seq_lens_cpu=self.runner.seq_lens_cpu, - max_query_len=max_query_len, - num_reqs=batch_size, - num_actual_tokens=num_tokens, - actual_seq_lengths_q=self.runner.actual_seq_lengths_q, - block_table_tensor=self.runner.input_batch.block_table[0]. - get_device_tensor(), - slot_mapping_cpu=target_slot_mapping, - positions=target_positions, - attn_mask=self.runner.attn_mask, - spec_attn_mask=self.runner.spec_attn_mask, - attn_state=self.runner.attn_state, - decode_token_per_req=self.runner.decode_token_per_req, - ) - # FIXME(woosuk): The below two ops cause synchronization. Optimize. 
- attn_metadata = self.runner.attn_metadata_builder.build( - common_attn_metadata, self.runner.model) - if self.use_cuda_graph and \ - num_tokens <= self.cudagraph_batch_sizes[-1]: - num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens) - else: - num_input_tokens = num_tokens - # copy inputs to buffer for cudagraph - self.positions[:num_tokens] = target_positions.to(device) - self.hidden_states[:num_tokens] = target_hidden_states - attn_metadata.block_tables = block_table.to(device) - with set_ascend_forward_context(attn_metadata, - self.vllm_config, - num_tokens=num_input_tokens): - last_hidden_states, hidden_states = self.model( - input_ids=self.input_ids[:num_input_tokens], - positions=self.positions[:num_input_tokens], - hidden_states=self.hidden_states[:num_input_tokens], - ) - sample_hidden_states = last_hidden_states[last_token_indices] - logits = self.model.compute_logits(sample_hidden_states, None) - draft_token_ids = logits.argmax(dim=-1) - - # Early exit if there is only one draft token to be generated. - if self.num_speculative_tokens == 1: - # [batch_size, 1] - return draft_token_ids.view(-1, 1) - - # Generate the remaining draft tokens. - draft_token_ids_tensor = torch.zeros( - (self.num_speculative_tokens, *draft_token_ids.shape), - dtype=draft_token_ids.dtype) - draft_token_ids_tensor[0] = draft_token_ids - - positions_cpu = target_positions[last_token_indices].cpu().to( - torch.int64) - hidden_states = hidden_states[last_token_indices] - if self.use_cuda_graph and \ - batch_size <= self.cudagraph_batch_sizes[-1]: - input_batch_size = self.vllm_config.pad_for_cudagraph(batch_size) - else: - input_batch_size = batch_size - attn_metadata.num_actual_tokens = batch_size - attn_metadata.max_query_len = 1 - attn_metadata.query_start_loc = self.arange[:batch_size + 1] - - if self.num_speculative_tokens > 2: - raise ValueError("Speculative tokens > 2 are not supported yet.") - - attn_metadata.attn_state = AscendAttentionState.ChunkedPrefill - for now_speculative in range(self.num_speculative_tokens - 1): - # Update the inputs. - # cast to int32 is crucial when eagle model is compiled. - # tensor.argmax() returns int64 by default. - input_ids = draft_token_ids_tensor[now_speculative].to(device) - positions_cpu += 1 - - # NOTE(woosuk): We should handle the case where the draft model - # generates tokens beyond the max model length. Since it is complex - # to remove such requests from the batch, we keep them in the batch - # but adjust the position ids and slot mappings to avoid the - # out-of-range access during the model execution. The draft tokens - # generated with this adjustment should be ignored. - exceeds_max_model_len = positions_cpu >= self.max_model_len - # Mask out the position ids that exceed the max model length. - # Otherwise, we may get out-of-range error in RoPE. - clamped_positions_cpu = torch.where(exceeds_max_model_len, 0, - positions_cpu) - clamped_positions = clamped_positions_cpu.to(device) - - # TODO: Increment the sequence lengths. - - attn_metadata.seq_lens += 1 - # TODO: Consider max model length. - # attn_metadata.max_seq_len = min(attn_metadata.max_seq_len, - # self.max_model_len) - # For the requests that exceed the max model length, we set the - # TODO: sequence length to 1 to minimize their overheads in attention. - - # Compute the slot mapping. 
- block_numbers = (clamped_positions_cpu // self.block_size) - block_ids = block_table.gather(dim=1, - index=block_numbers.view(-1, 1)) - block_ids = block_ids.view(-1) - slot_mapping_cpu = (block_ids * self.block_size + - clamped_positions_cpu % self.block_size) - - # Mask out the slot mappings that exceed the max model length. - # Otherwise, the KV cache will be inadvertently updated with the - # padding tokens. - slot_mapping_cpu.masked_fill_(exceeds_max_model_len, - PADDING_SLOT_ID) - # NOTE: ASCEND slot_mapping must on cpu - attn_metadata.slot_mapping = slot_mapping_cpu.to( - torch.int32).to(device) - # copy inputs to buffer for cudagraph - self.input_ids[:batch_size] = input_ids - self.positions[:batch_size] = clamped_positions - self.hidden_states[:batch_size] = hidden_states - positions = positions_cpu.to(device) - attn_mask = self._make_attention_mask( - seq_lens=attn_metadata.seq_lens, - query_lens=attn_metadata.max_query_len, - position=positions, - ) - attn_metadata.attn_mask = attn_mask - attn_metadata.block_tables = block_table.to(device) - # Run the model. - with set_ascend_forward_context(attn_metadata, - self.vllm_config, - num_tokens=input_batch_size): - - last_hidden_states, hidden_states = self.model( - input_ids=self.input_ids[:input_batch_size], - positions=self.positions[:input_batch_size], - hidden_states=self.hidden_states[:input_batch_size], - ) - hidden_states = hidden_states[:batch_size] - logits = self.model.compute_logits(last_hidden_states[:batch_size], - None) - - # TODO(wenlong): get more than one token for tree attention - draft_token_ids = logits.argmax(dim=-1) - draft_token_ids_tensor[now_speculative + 1] = draft_token_ids.cpu() - - # [batch_size, num_speculative_tokens] - draft_token_ids = draft_token_ids_tensor.swapaxes(0, 1) - return draft_token_ids - - @staticmethod - def prepare_inputs( - # [batch_size + 1] - cu_target_query_lens: torch.Tensor, - # [batch_size] - num_rejected_tokens: torch.Tensor, - num_tokens: int, - ) -> tuple[torch.Tensor, torch.Tensor]: - # cu_target_query_lens: [0, a, a + b, a + b + c] - # num_rejected_tokens: [n1, n2, n3] - # num_tokens_per_req: [a - n1, b - n2, c - n3] - # cu_num_tokens: [0, a - n1, a + b - n1 - n2, a + b + c - n1 - n2 - n3] - # token_indices: [0, 1, ..., a - n1 - 1, - # a, a + 1, ..., a + b - n2 - 1, - # a + b, a + b + 1, ..., a + b + c - n3 - 1] - - # [0, a, a + b, a + b + c] -> [a, b, c] - query_len_per_req = (cu_target_query_lens[1:] - - cu_target_query_lens[:-1]) - # [a, b, c] -> [a - n1, b - n2, c - n3] - num_tokens_per_req = query_len_per_req - num_rejected_tokens - - # [a - n1, b - n2, c - n3] -> - # [0, a - n1, a + b - n1 - n2, a + b + c - n1 - n2 - n3] - cu_num_tokens = torch.zeros_like(cu_target_query_lens) - torch.cumsum(num_tokens_per_req, dim=0, out=cu_num_tokens[1:]) - token_indices = torch.empty( - num_tokens, - dtype=torch.int32, - device=cu_target_query_lens.device, - ) - BLOCK_SIZE = 1024 - prepare_eagle_input_sequential( - token_indices, - cu_target_query_lens, - cu_num_tokens, - block_size=BLOCK_SIZE, - ) - return cu_num_tokens, token_indices - - def load_model(self, target_model: nn.Module) -> None: - draft_model_config = \ - self.vllm_config.speculative_config.draft_model_config - target_attn_layer_names = set( - get_layers_from_vllm_config(self.vllm_config, Attention).keys()) - - self.model = get_model(vllm_config=self.vllm_config, - model_config=draft_model_config) - - draft_attn_layer_names = ( - get_layers_from_vllm_config(self.vllm_config, Attention).keys() - - 
target_attn_layer_names) - - self.attn_layer_names = list(draft_attn_layer_names) - self.attn_layer_name = next(iter(draft_attn_layer_names)) - # share embed_tokens with the target model if needed - if get_pp_group().world_size == 1: - logger.info( - "The EAGLE head shares the same vocab embedding" \ - " with the target model." - ) - self.model.model.embed_tokens = target_model.model.embed_tokens - else: - logger.info( - "Since PP > 1, the EAGLE head loaded its own vocab embedding" \ - " weights instead of sharing them with the target model." - ) - - # share lm_head with the target model if needed - # some model definition do not define lm_head explicitly - # and reuse embed_tokens for lm_head, e.g., CohereForCausalLM - if self.vllm_config.speculative_config.method != "eagle3" and \ - hasattr(target_model, "lm_head"): - logger.info("Loading EAGLE LM head weights from the target model.") - if supports_multimodal(target_model): - self.model.lm_head = target_model.get_language_model().lm_head - else: - self.model.lm_head = target_model.lm_head - - @torch.inference_mode() - def dummy_run( - self, - num_tokens: int, - ) -> None: - with set_ascend_forward_context(None, - self.vllm_config, - num_tokens=num_tokens): - self.model( - input_ids=self.input_ids[:num_tokens], - positions=self.positions[:num_tokens], - hidden_states=self.hidden_states[:num_tokens], - ) - - -def prepare_eagle_input_sequential(out_tensor: torch.Tensor, - cu_query_lens: torch.Tensor, - cu_num_tokens: torch.Tensor, - block_size: int): - num_programs = len(cu_num_tokens) - 1 - for pid in range(num_programs): - start_pos = cu_num_tokens[pid].item() - end_pos = cu_num_tokens[pid + 1].item() - num_tokens = end_pos - start_pos - index_start = cu_query_lens[pid].item() - num_blocks = int( - torch.ceil(torch.tensor(num_tokens / block_size)).item()) - - for i in range(num_blocks): - offset_tensor = torch.arange(0, - block_size, - dtype=torch.int32, - device=out_tensor.device) - global_start_offset = i * block_size - target_indices = torch.tensor( - start_pos + global_start_offset, - dtype=torch.int32, - device=out_tensor.device) + offset_tensor - values_to_store = torch.tensor( - index_start, dtype=torch.int32, - device=out_tensor.device) + offset_tensor - mask = (target_indices >= start_pos) & \ - (target_indices < end_pos) & \ - (offset_tensor < num_tokens) - out_tensor[target_indices[mask]] = values_to_store[mask] diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index 03486b0555..fbef29b2fd 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -91,13 +91,15 @@ from vllm_ascend.multistream.ms_split import compute_split_seq_index from vllm_ascend.platform import NPUPlatform from vllm_ascend.sample.rejection_sampler import AscendRejectionSampler +from vllm_ascend.spec_decode import get_spec_decode_method +from vllm_ascend.spec_decode.eagle_proposer import EagleProposer +from vllm_ascend.spec_decode.interface import SpecDcodeType +from vllm_ascend.spec_decode.mtp_proposer import MtpProposer from vllm_ascend.torchair.torchair_attention import AscendTorchairMetadata from vllm_ascend.torchair.torchair_mla import AscendMLATorchairMetadata from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ, ProfileExecuteDuration, is_310p, vllm_version_is) -from vllm_ascend.worker.eagle_proposer_v1 import EagleProposer -from vllm_ascend.worker.mtp_proposer_v1 import MtpProposer from vllm_ascend.worker.npu_input_batch import CachedRequestState, 
InputBatch if not vllm_version_is("0.10.1.1"): @@ -237,16 +239,12 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device): int(os.getenv("PAGED_ATTENTION_MASK_LEN", 10000))), self.dtype) # Set up speculative decoding. - self.use_aux_hidden_state_outputs = False - self.use_spec_decode = False self.spec_attn_mask = None - self.use_eagle = False self.drafter: Optional[Union[NgramProposer, EagleProposer, MtpProposer]] = None self.actual_seq_lengths_q = [] self.decode_token_per_req = 1 if self.speculative_config: - self.use_spec_decode = True spec_token_num = self.speculative_config.num_speculative_tokens assert spec_token_num > 0 self.decode_token_per_req = 1 + spec_token_num @@ -260,19 +258,9 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device): dtype=torch.bool), diagonal=1).to(self.device) if get_pp_group().is_last_rank: - if self.speculative_config.method == "ngram": - self.drafter = NgramProposer(self.vllm_config) - elif self.speculative_config.method in ["eagle", "eagle3"]: - self.use_eagle = True - self.drafter = EagleProposer(self.vllm_config, self.device, - self) # type: ignore - if self.speculative_config.method == "eagle3": - self.use_aux_hidden_state_outputs = True - elif self.speculative_config.method == 'deepseek_mtp': - self.drafter = MtpProposer(self.vllm_config, self) - else: - raise ValueError("Unknown speculative decoding method: " - f"{self.speculative_config.method}") + self.drafter = get_spec_decode_method( + self.speculative_config.method, self.vllm_config, + self.device, self) self.rejection_sampler = AscendRejectionSampler() # Persistent batch. @@ -640,152 +628,6 @@ def _check_dbo_is_valid(self, query_lens: torch.Tensor, return False return True - def get_eagle_atten_dict( - self, - scheduler_output: "SchedulerOutput", - ) -> dict[str, Union[AscendMetadata, AscendMLAMetadata, - AscendTorchairMetadata, AscendMLATorchairMetadata]]: - total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens - assert total_num_scheduled_tokens > 0 - num_reqs = self.input_batch.num_reqs - assert num_reqs > 0 - - # OPTIMIZATION: Start copying the block table first. - # This way, we can overlap the copy with the following CPU operations. - self.input_batch.block_table.commit_block_table(num_reqs) - - # Get the number of scheduled tokens for each request. - req_ids = self.input_batch.req_ids - tokens = [scheduler_output.num_scheduled_tokens[i] for i in req_ids] - num_scheduled_tokens = np.array(tokens, dtype=np.int32) - max_num_scheduled_tokens = max(tokens) - self.query_lens = torch.from_numpy(num_scheduled_tokens) - # Get request indices. - # E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2] - req_indices = np.repeat(self.arange_np[:num_reqs], - num_scheduled_tokens) - - # cu_num_tokens: [2, 5, 3] -> [2, 7, 10] - # arange: [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] - cu_num_tokens, arange = self._get_cumsum_and_arange( - num_scheduled_tokens) - - # Get positions. - positions_np = self.positions_np[:total_num_scheduled_tokens] - np.add(self.input_batch.num_computed_tokens_cpu[req_indices], - arange, - out=positions_np) - - # Calculate M-RoPE positions. - # Only relevant for models using M-RoPE (e.g, Qwen2-VL) - if self.uses_mrope: - self._calc_mrope_positions(scheduler_output) - - # Get token indices. - # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] - # -> [0, 1, M, M + 1, M + 2, M + 3, M + 4, 2 * M, 2 * M + 1, 2 * M + 2] - # where M is the max_model_len. 
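The removed `get_eagle_atten_dict` starts from the per-request scheduled-token counts and expands them into per-token request indices and absolute positions. A small NumPy sketch of that bookkeeping with hypothetical values (the naive `np.concatenate` form is used here for readability; the vectorized variant appears further below):

```python
# Sketch of the per-token request/position bookkeeping in the removed
# get_eagle_atten_dict() (hypothetical values; NumPy only).
import numpy as np

num_scheduled_tokens = np.array([2, 5, 3], dtype=np.int32)
num_computed_tokens = np.array([10, 0, 7], dtype=np.int32)  # per request

# Request index of every scheduled token:
# [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2]
req_indices = np.repeat(np.arange(len(num_scheduled_tokens)),
                        num_scheduled_tokens)

# Intra-request offsets: [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
arange = np.concatenate([np.arange(n) for n in num_scheduled_tokens])

# Absolute position of every scheduled token inside its sequence.
positions = num_computed_tokens[req_indices] + arange
print(positions.tolist())  # [10, 11, 0, 1, 2, 3, 4, 7, 8, 9]
```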
- token_indices = (positions_np + - req_indices * self.input_batch.token_ids_cpu.shape[1]) - - # NOTE(woosuk): We use torch.index_select instead of np.take here - # because torch.index_select is much faster than np.take for large - # tensors. - torch.index_select(self.input_batch.token_ids_cpu_tensor.flatten(), - 0, - torch.from_numpy(token_indices), - out=self.input_ids_cpu[:total_num_scheduled_tokens]) - - # Prepare the attention metadata for each KV cache group and make layers - # in the same group share the same metadata. - # NOTE(Chen): there is exactly one KV cache group that contains all - # attetnion layers in the model for now, so the current logic for - # getting attn_metadata is not related to kv_cache_group information. - # Will extend this part to support multiple KV cache groups later. - for kv_cache_group_id, kv_cache_group_spec in enumerate( - self.kv_cache_config.kv_cache_groups): - block_size = kv_cache_group_spec.kv_cache_spec.block_size - block_table = self.input_batch.block_table[kv_cache_group_id] - # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] - # -> [0, 0, K, K, K + 1, K + 1, K + 2, 2 * K, 2 * K, 2 * K + 1] - # where K is the max_num_blocks_per_req and the block size is 2. - # NOTE(woosuk): We can't simply use `token_indices // block_size` - # here because M (max_model_len) is not necessarily divisible by - # block_size. - block_table_indices = ( - req_indices * block_table.max_num_blocks_per_req + - positions_np // block_size) - block_table_cpu = block_table.get_cpu_tensor() - block_numbers = block_table_cpu.flatten( - )[block_table_indices].numpy() - block_offsets = positions_np % block_size - np.add( - block_numbers * block_size, - block_offsets, - out=block_table.slot_mapping_np[:total_num_scheduled_tokens]) - - # Prepare the attention metadata. - self.query_start_loc_np[0] = 0 - self.query_start_loc_np[1:num_reqs + 1] = cu_num_tokens - - self.seq_lens_np[:num_reqs] = ( - self.input_batch.num_computed_tokens_cpu[:num_reqs] + - num_scheduled_tokens) - - # Copy the tensors to the NPU. - self.input_ids[:total_num_scheduled_tokens].copy_( - self.input_ids_cpu[:total_num_scheduled_tokens], non_blocking=True) - if self.uses_mrope: - # Only relevant for models using M-RoPE (e.g, Qwen2-VL) - self.mrope_positions[:, :total_num_scheduled_tokens].copy_( - self.mrope_positions_cpu[:, :total_num_scheduled_tokens], - non_blocking=True) - else: - # Common case (1D positions) - self.positions[:total_num_scheduled_tokens].copy_( - self.positions_cpu[:total_num_scheduled_tokens], - non_blocking=True) - - self.query_start_loc[:num_reqs + 1].copy_( - self.query_start_loc_cpu[:num_reqs + 1], non_blocking=True) - self.seq_lens[:num_reqs].copy_(self.seq_lens_cpu[:num_reqs], - non_blocking=True) - - # Fill unused with -1. Needed for reshape_and_cache - self.seq_lens[num_reqs:].fill_(0) - self.query_start_loc[num_reqs + 1:].fill_(-1) - - attn_metadata: dict[str, Union[AscendMetadata, AscendMLAMetadata, - AscendTorchairMetadata, - AscendMLATorchairMetadata]] = {} - # Prepare the attention metadata for each KV cache group and make layers - # in the same group share the same metadata. 
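The same removed method then translates per-token positions into KV-cache slots through the block table: physical block number times block size plus the in-block offset. A sketch with a hypothetical two-request block table:

```python
# Sketch of the slot-mapping computation in the removed get_eagle_atten_dict()
# (hypothetical block table and sizes).
import numpy as np

block_size = 2
max_num_blocks_per_req = 4
# block_table[req, j] = physical block id of the j-th logical block of req
block_table = np.array([[5, 9, 0, 0],
                        [3, 7, 8, 0]], dtype=np.int32)

req_indices = np.array([0, 0, 1, 1, 1], dtype=np.int32)
positions = np.array([2, 3, 0, 1, 2], dtype=np.int32)

block_table_indices = (req_indices * max_num_blocks_per_req +
                       positions // block_size)
block_numbers = block_table.flatten()[block_table_indices]
block_offsets = positions % block_size
slot_mapping = block_numbers * block_size + block_offsets
print(slot_mapping.tolist())  # [18, 19, 6, 7, 14]
```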
- for kv_cache_group_id, kv_cache_group_spec in enumerate( - self.kv_cache_config.kv_cache_groups): - common_attn_metadata = AscendCommonAttentionMetadata( - query_start_loc=self.query_start_loc[:num_reqs + 1], - query_start_loc_cpu=self.query_start_loc_cpu[:num_reqs + 1], - seq_lens_cpu=self.seq_lens_cpu, - num_reqs=num_reqs, - max_query_len=max_num_scheduled_tokens, - num_actual_tokens=total_num_scheduled_tokens, - actual_seq_lengths_q=self.actual_seq_lengths_q, - block_table_tensor=self.input_batch.block_table[0]. - get_device_tensor(), - slot_mapping_cpu=self.slot_mapping_cpu, - positions=self.positions, - attn_mask=self.attn_mask, - spec_attn_mask=self.spec_attn_mask, - attn_state=self.attn_state, - decode_token_per_req=self.decode_token_per_req, - ) - attn_metadata_i = self.attn_metadata_builder.build( - common_attn_metadata, self.get_model()) - for layer_name in kv_cache_group_spec.layer_names: - attn_metadata[layer_name] = attn_metadata_i - - return attn_metadata - def get_model(self) -> nn.Module: # get raw model out of the aclgraph wrapper. if isinstance(self.model, ACLGraphWrapper): @@ -1328,7 +1170,8 @@ def _build_attn_state(self, num_reqs, num_scheduled_tokens, attn_state = AscendAttentionState.SpecDecoding # Speculative decoding. elif np.all(num_valid_tokens == 1): - if self.use_eagle: + if self.drafter and (self.drafter.name == SpecDcodeType.EAGLE + or self.drafter.name == SpecDcodeType.EAGLE3): attn_state = AscendAttentionState.ChunkedPrefill else: attn_state = AscendAttentionState.SpecDecoding @@ -1349,26 +1192,6 @@ def _update_input_ids_and_positions(self, input_ids, positions, positions = self.mrope_positions[:, :num_input_tokens] return input_ids, positions - def _get_cumsum_and_arange( - self, - num_tokens: np.ndarray, - cumsum_dtype: Optional[np.dtype] = None, - ) -> tuple[np.ndarray, np.ndarray]: - """Get the cumulative sum and batched arange of the given array. - # E.g., [2, 5, 3] -> ([2, 7, 10], [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]) - # Equivalent to but faster than: - # np.concatenate([np.arange(n) for n in num_tokens]) - """ - # Step 1. [2, 5, 3] -> [2, 7, 10] - cu_num_tokens = np.cumsum(num_tokens, dtype=cumsum_dtype) - total_num_tokens = cu_num_tokens[-1] - # Step 2. [2, 7, 10] -> [0, 0, 2, 2, 2, 2, 2, 7, 7, 7] - cumsums_offsets = np.repeat(cu_num_tokens - num_tokens, num_tokens) - # Step 3. [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] - arange = self.arange_np[:total_num_tokens] - cumsums_offsets - - return cu_num_tokens, arange - def _calc_spec_decode_metadata( self, num_draft_tokens: np.ndarray, @@ -1522,24 +1345,14 @@ def propose_draft_token_ids( AscendMLATorchairMetadata], aux_hidden_states: torch.Tensor = None, ) -> Optional[list[list[int]]]: - if not self.use_spec_decode: + if not self.drafter: # Speculative decoding is not enabled. 
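The removed `_get_cumsum_and_arange` helper builds the batched arange from a cumulative sum instead of concatenating many small aranges. A standalone sketch of the same trick:

```python
# Standalone version of the trick in the removed _get_cumsum_and_arange()
# (NumPy only; a sketch, not the runner method itself).
import numpy as np

def cumsum_and_arange(num_tokens: np.ndarray):
    # Step 1. [2, 5, 3] -> [2, 7, 10]
    cu_num_tokens = np.cumsum(num_tokens)
    total = cu_num_tokens[-1]
    # Step 2. [2, 7, 10] -> [0, 0, 2, 2, 2, 2, 2, 7, 7, 7]
    offsets = np.repeat(cu_num_tokens - num_tokens, num_tokens)
    # Step 3. [0, 1, ..., 9] - offsets -> [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
    arange = np.arange(total) - offsets
    return cu_num_tokens, arange

cu, ar = cumsum_and_arange(np.array([2, 5, 3], dtype=np.int32))
print(cu.tolist())  # [2, 7, 10]
print(ar.tolist())  # [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
```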
draft_token_ids = None - elif self.speculative_config.method == "ngram": - draft_token_ids = self._generate_ngram_token_ids( - valid_sampled_token_ids) - elif self.speculative_config.method == "eagle": - raise NotImplementedError("Eagle Is Not Supported Yet.") - elif self.speculative_config.method == "eagle3": - draft_token_ids = self._generate_eagle3_token_ids( - valid_sampled_token_ids, sampling_metadata, scheduler_output, - spec_decode_metadata, positions, num_scheduled_tokens, - hidden_states, aux_hidden_states) - elif self.speculative_config.method == 'deepseek_mtp': - draft_token_ids = self._generate_mtp_token_ids( + else: + draft_token_ids = self.drafter.generate_token_ids( valid_sampled_token_ids, sampling_metadata, scheduler_output, spec_decode_metadata, positions, num_scheduled_tokens, - hidden_states, attn_metadata) + hidden_states, attn_metadata, aux_hidden_states) return draft_token_ids def _pool( @@ -1647,7 +1460,7 @@ def execute_model( scheduler_output) aux_hidden_states = None - if self.use_aux_hidden_state_outputs: + if self.drafter and self.drafter.name == SpecDcodeType.EAGLE3: hidden_states, aux_hidden_states = hidden_states kv_connector_output = None @@ -1924,12 +1737,10 @@ def _generate_dummy_run_hidden_states(self, with_prefill, positions=positions, intermediate_tensors=intermediate_tensors, inputs_embeds=inputs_embeds) - if self.use_aux_hidden_state_outputs: + if self.drafter and self.drafter.name == SpecDcodeType.EAGLE3: hidden_states, _ = hidden_states else: hidden_states = hidden_states - if self.use_spec_decode and isinstance(self.drafter, EagleProposer): - self.drafter.dummy_run(num_tokens) return hidden_states @torch.inference_mode() @@ -2070,8 +1881,7 @@ def _dummy_run( with_prefill, is_torchair_compile, input_ids, positions, attn_metadata, num_tokens, intermediate_tensors, inputs_embeds) - if self.speculative_config and self.speculative_config.method == "deepseek_mtp": - assert isinstance(self.drafter, MtpProposer) + if self.drafter: self.drafter.dummy_run( num_tokens=num_tokens, with_prefill=with_prefill, @@ -2202,13 +2012,11 @@ def load_model(self) -> None: module.weight.data) if self.drafter: logger.info("Loading drafter model...") - if isinstance(self.drafter, EagleProposer): - if self.use_aux_hidden_state_outputs: - self.drafter.load_model(self.model) - self.model.set_aux_hidden_state_layers( - self.model.get_eagle3_aux_hidden_state_layers()) - else: - self.drafter.load_model() + self.drafter.load_model(self.model) + if self.drafter.name == SpecDcodeType.EAGLE3: + self.model.set_aux_hidden_state_layers( + self.model.get_eagle3_aux_hidden_state_layers()) + if self.lora_config: self.model = self.load_lora_model(self.model, self.model_config, @@ -2498,193 +2306,6 @@ def capture_model(self) -> None: logger.info("Graph capturing finished in %.0f secs, took %.2f GiB", elapsed_time, npu_graph_size / (1 << 30)) - def _generate_ngram_token_ids( - self, - sampled_token_ids: list[list[int]], - ) -> list[list[int]]: - # TODO(woosuk): Optimize. - draft_token_ids: list[list[int]] = [] - for i, sampled_ids in enumerate(sampled_token_ids): - num_sampled_ids = len(sampled_ids) - if not num_sampled_ids: - # Skip speculative decoding. - draft_token_ids.append([]) - continue - - # Skip requests that require top-p, top-k, etc. - req_id = self.input_batch.req_ids[i] - if req_id in self.input_batch.spec_decode_unsupported_reqs: - draft_token_ids.append([]) - continue - - # Add sampled_token_ids to token_ids_cpu. 
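After this refactor the runner only talks to `self.drafter` through a small common surface: `name`, `load_model`, `dummy_run`, and `generate_token_ids`. The sketch below is a hypothetical reconstruction of that interface inferred from the call sites in this diff; the real definition lives in `vllm_ascend/spec_decode/interface.py` and may differ:

```python
# Hypothetical reconstruction of the drafter interface (inferred from call
# sites in this diff; not the actual vllm_ascend/spec_decode/interface.py).
from abc import ABC, abstractmethod
from enum import Enum


class SpecDcodeType(Enum):  # name as spelled in the diff
    NGRAM = 0
    EAGLE = 1
    EAGLE3 = 2
    MTP = 3


class Proposer(ABC):
    name: SpecDcodeType

    @abstractmethod
    def load_model(self, model) -> None:
        """Load the draft model, optionally sharing weights with `model`."""

    @abstractmethod
    def dummy_run(self, num_tokens: int, **kwargs) -> None:
        """Warm-up / graph-capture run with dummy inputs."""

    @abstractmethod
    def generate_token_ids(self, valid_sampled_token_ids, sampling_metadata,
                           scheduler_output, spec_decode_metadata, positions,
                           num_scheduled_tokens, hidden_states, attn_metadata,
                           aux_hidden_states=None):
        """Return per-request draft token ids for the next verification step."""
```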
- start_idx = self.input_batch.num_tokens_no_spec[i] - end_idx = start_idx + num_sampled_ids - self.input_batch.token_ids_cpu[i, start_idx:end_idx] = sampled_ids - assert isinstance(self.drafter, NgramProposer) - drafter_output = self.drafter.propose( - self.input_batch.token_ids_cpu[i, :end_idx]) - if drafter_output is None or len(drafter_output) == 0: - draft_token_ids.append([]) - else: - draft_token_ids.append(drafter_output.tolist()) - return draft_token_ids - - def _generate_eagle3_token_ids(self, - valid_sampled_token_ids: list[list[int]], - sampling_metadata: SamplingMetadata, - scheduler_output: "SchedulerOutput", - spec_decode_metadata: SpecDecodeMetadata, - positions: torch.Tensor, - num_scheduled_tokens: int, - hidden_states: torch.Tensor, - aux_hidden_states: torch.Tensor = None): - assert isinstance(self.drafter, EagleProposer) - attn_metadata = self.get_eagle_atten_dict(scheduler_output) - next_token_ids: list[int] = [] - for i, token_ids in enumerate(valid_sampled_token_ids): - if token_ids: - # Common case. - next_token_id = token_ids[-1] - else: - # Partial prefill (rare case). - # Get the next token id from the request state. - req_id = self.input_batch.req_ids[i] - req_state = self.requests[req_id] - seq_len = (req_state.num_computed_tokens + - scheduler_output.num_scheduled_tokens[req_id]) - - next_token_id = req_state.get_token_id(seq_len) - next_token_ids.append(next_token_id) - next_token_ids = torch.tensor(next_token_ids, - dtype=torch.int32, - device=self.device) - eagle_attn_metadata = attn_metadata[self.drafter.attn_layer_name] - if spec_decode_metadata is None: - # input_ids can be None for multimodal models. - target_token_ids = self.input_ids[:num_scheduled_tokens] - target_positions = positions[:num_scheduled_tokens] - if self.use_aux_hidden_state_outputs: - target_hidden_states = torch.cat( - [h[:num_scheduled_tokens] for h in aux_hidden_states], - dim=-1) - else: - target_hidden_states = hidden_states[:num_scheduled_tokens] - target_slot_mapping = eagle_attn_metadata.slot_mapping - cu_num_tokens = eagle_attn_metadata.query_start_loc - else: - num_draft_tokens = spec_decode_metadata.num_draft_tokens - num_rejected_tokens = [ - n + 1 - len(valid_sampled_token_ids[i]) if n > 0 else 0 - for i, n in enumerate(num_draft_tokens) - ] - num_rejected_tokens = torch.tensor( - num_rejected_tokens, - dtype=torch.int32, - device=self.device, - ) - num_tokens = num_scheduled_tokens - sum(num_rejected_tokens) - cu_num_tokens, token_indices = self.drafter.prepare_inputs( - eagle_attn_metadata.query_start_loc, num_rejected_tokens, - num_tokens) - target_token_ids = self.input_ids[token_indices] - target_positions = positions[token_indices] - if self.use_aux_hidden_state_outputs: - target_hidden_states = torch.cat( - [h[token_indices] for h in aux_hidden_states], dim=-1) - else: - target_hidden_states = hidden_states[token_indices] - target_slot_mapping = eagle_attn_metadata.slot_mapping[ - token_indices] - - draft_token_ids = self.drafter.propose( - target_token_ids=target_token_ids, - target_positions=target_positions, - target_hidden_states=target_hidden_states, - target_slot_mapping=target_slot_mapping, - next_token_ids=next_token_ids, - cu_num_tokens=cu_num_tokens, - block_table=eagle_attn_metadata.block_tables, - sampling_metadata=sampling_metadata, - ) - spec_token_ids = draft_token_ids.tolist() - return spec_token_ids - - def _generate_mtp_token_ids( - self, - valid_sampled_token_ids: list[list[int]], - sampling_metadata: SamplingMetadata, - scheduler_output: 
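For EAGLE3, the removed `_generate_eagle3_token_ids` gathers several auxiliary hidden-state tensors from the target model at the surviving token positions and concatenates them along the feature dimension before proposing. A shape-level sketch with hypothetical sizes:

```python
# Shape-level sketch of the EAGLE3 auxiliary-hidden-state handling
# (hypothetical layer count and sizes).
import torch

num_tokens, hidden_size = 10, 8
# e.g. three intermediate layers exposed by the target model for the EAGLE3 head
aux_hidden_states = [torch.randn(num_tokens, hidden_size) for _ in range(3)]
token_indices = torch.tensor([0, 2, 3, 4, 5, 6, 7])  # after rejection filtering

target_hidden_states = torch.cat(
    [h[token_indices] for h in aux_hidden_states], dim=-1)
print(target_hidden_states.shape)  # torch.Size([7, 24])
```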
"SchedulerOutput", - spec_decode_metadata: SpecDecodeMetadata, - positions: torch.Tensor, - num_scheduled_tokens: int, - hidden_states: torch.Tensor, - attn_metadata: Union[AscendMetadata, AscendMLAMetadata, - AscendTorchairMetadata, - AscendMLATorchairMetadata], - ): - assert isinstance(self.drafter, MtpProposer) - next_token_ids: list[int] = [] - for i, token_ids in enumerate(valid_sampled_token_ids): - if token_ids: - # Common case. - next_token_id = token_ids[-1] - else: - # Partial prefill (rare case). - # Get the next token id from the request state. - req_id = self.input_batch.req_ids[i] - req_state = self.requests[req_id] - seq_len = (req_state.num_computed_tokens + - scheduler_output.num_scheduled_tokens[req_id]) - next_token_id = req_state.get_token_id(seq_len) - next_token_ids.append(next_token_id) - next_token_ids = torch.tensor(next_token_ids, - dtype=torch.int32, - device=self.device) - accepted_token_indices = None - if spec_decode_metadata is None: - # input_ids can be None for multimodal models. - target_token_ids = self.input_ids[:num_scheduled_tokens] - target_positions = positions[:num_scheduled_tokens] - target_hidden_states = hidden_states[:num_scheduled_tokens] - target_slot_mapping = attn_metadata.slot_mapping - cu_num_tokens = attn_metadata.query_start_loc - else: - # TODO(woosuk): Refactor this. - num_draft_tokens = spec_decode_metadata.num_draft_tokens - num_rejected_tokens = [ - n + 1 - len(valid_sampled_token_ids[i]) if n > 0 else 0 - for i, n in enumerate(num_draft_tokens) - ] - num_rejected_tokens = torch.tensor( - num_rejected_tokens, - dtype=torch.int32, - device=self.device, - ) - cu_num_tokens, accepted_token_indices, target_token_ids, \ - target_positions, target_hidden_states, target_slot_mapping = self.drafter.prepare_inputs( - attn_metadata.query_start_loc, - num_rejected_tokens, - self.input_ids[:num_scheduled_tokens], - positions[:num_scheduled_tokens], - hidden_states[:num_scheduled_tokens], - attn_metadata.slot_mapping[:num_scheduled_tokens], - is_torchair_graph=self._build_drafter_prepare_inputs_torchair_param(), - ) - - draft_token_ids = self.drafter.propose( - target_token_ids=target_token_ids, - target_positions=target_positions, - target_hidden_states=target_hidden_states, - target_slot_mapping=target_slot_mapping, - next_token_ids=next_token_ids, - cu_num_tokens=cu_num_tokens, - block_table=attn_metadata.block_tables, - sampling_metadata=sampling_metadata, - token_indices=accepted_token_indices) - spec_token_ids = draft_token_ids.tolist() - return spec_token_ids - def _get_prompt_logprobs_dict( self, hidden_states: torch.Tensor,