diff --git a/vllm_ascend/spec_decode/eagle_proposer.py b/vllm_ascend/spec_decode/eagle_proposer.py
index d14dc6d2a4..11f488d2b0 100644
--- a/vllm_ascend/spec_decode/eagle_proposer.py
+++ b/vllm_ascend/spec_decode/eagle_proposer.py
@@ -12,13 +12,15 @@
 from vllm.model_executor.model_loader import get_model
 from vllm.model_executor.models import supports_multimodal
 from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
+from vllm.utils import is_pin_memory_available
 from vllm.v1.core.sched.output import SchedulerOutput
 from vllm.v1.sample.metadata import SamplingMetadata
 from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
 
 from vllm_ascend.ascend_forward_context import set_ascend_forward_context
 from vllm_ascend.attention.attention_mask import AttentionMaskBuilder
-from vllm_ascend.attention.attention_v1 import AscendAttentionState
+from vllm_ascend.attention.attention_v1 import (AscendAttentionState,
+                                                AscendMetadata)
 from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
 from vllm_ascend.spec_decode.interface import Proposer, SpecDcodeType
 from vllm_ascend.utils import vllm_version_is
@@ -65,6 +67,9 @@ def __init__(self,
                                           self.hidden_size),
                                          dtype=self.vllm_config.model_config.dtype,
                                          device=device)
+        self.max_num_tokens = (
+            vllm_config.scheduler_config.max_num_batched_tokens)
+        self.token_arange_np = np.arange(self.max_num_tokens)
         # We need +1 here because the arange is used to set query_start_loc,
         # which has one more element than batch_size.
         self.arange = torch.arange(vllm_config.scheduler_config.max_num_seqs +
@@ -182,10 +187,8 @@ def generate_token_ids(self,
                 dtype=torch.int32,
                 device=self.device,
             )
-            num_tokens = num_scheduled_tokens - sum(num_rejected_tokens)
-            cu_num_tokens, token_indices = self._prepare_inputs(
-                eagle_attn_metadata.query_start_loc, num_rejected_tokens,
-                num_tokens)
+            cu_num_tokens, token_indices = self._prepare_inputs(
+                eagle_attn_metadata, num_rejected_tokens)
             target_token_ids = self.runner.input_ids[token_indices]
             target_positions = positions[token_indices]
             if self.name == SpecDcodeType.EAGLE3:
@@ -603,72 +606,88 @@ def _propose(
 
     def _prepare_inputs(
         self,
-        # [batch_size + 1]
-        cu_target_query_lens: torch.Tensor,
+        eagle_attn_metadata: AscendMetadata,
         # [batch_size]
         num_rejected_tokens: torch.Tensor,
-        num_tokens: int,
     ) -> tuple[torch.Tensor, torch.Tensor]:
-        # cu_target_query_lens: [0, a, a + b, a + b + c]
-        # num_rejected_tokens: [n1, n2, n3]
-        # num_tokens_per_req: [a - n1, b - n2, c - n3]
-        # cu_num_tokens: [0, a - n1, a + b - n1 - n2, a + b + c - n1 - n2 - n3]
-        # token_indices: [0, 1, ..., a - n1 - 1,
-        #                 a, a + 1, ..., a + b - n2 - 1,
-        #                 a + b, a + b + 1, ..., a + b + c - n3 - 1]
-
-        # [0, a, a + b, a + b + c] -> [a, b, c]
+        """
+        Prepare the inputs for speculative decoding. Drops each request's
+        rejected tokens from the query and returns the cumulative token
+        counts together with the indices of the tokens that should be fed
+        to the speculator.
+        """
+        # E.g.
+        #  eagle_attn_metadata.query_start_loc:
+        #       [0, q1, q1 + q2, q1 + q2 + q3]
+        #  num_rejected_tokens: [n1, n2, n3]
+        # This function computes the intermediate values:
+        #  num_tokens_per_req: [q1 - n1, q2 - n2, q3 - n3]
+        # And returns:
+        #  cu_num_tokens:
+        #       [0, q1 - n1, q1 + q2 - n1 - n2, q1 + q2 + q3 - n1 - n2 - n3]
+        #  token_indices: [0, 1, ..., q1 - n1 - 1,
+        #                  q1, q1 + 1, ..., q1 + q2 - n2 - 1,
+        #                  q1 + q2, q1 + q2 + 1, ..., q1 + q2 + q3 - n3 - 1]
+        num_rejected_tokens_cpu = num_rejected_tokens.to("cpu")
+        cu_target_query_lens = eagle_attn_metadata.query_start_loc
+        device = cu_target_query_lens.device
+        query_start_loc_cpu = cu_target_query_lens.to("cpu")
+
+        # [0, q1, q1 + q2, q1 + q2 + q3] -> [q1, q2, q3]
+        new_query_len_per_req = (query_start_loc_cpu[1:] -
+                                 query_start_loc_cpu[:-1])
+        # [q1, q2, q3] -> [q1 - n1, q2 - n2, q3 - n3]
+        new_num_tokens_per_req = (new_query_len_per_req -
+                                  num_rejected_tokens_cpu)
+        new_num_tokens_per_req_np = new_num_tokens_per_req.numpy()
+
+        # [q1 - n1, q2 - n2, q3 - n3] ->
+        # [0, q1 - n1, q1 + q2 - n1 - n2, q1 + q2 + q3 - n1 - n2 - n3]
+        new_query_start_loc_cpu = torch.zeros(
+            query_start_loc_cpu.shape,
+            dtype=torch.int32,
+            pin_memory=is_pin_memory_available())
+        new_query_start_loc_np = new_query_start_loc_cpu.numpy()
+        np.cumsum(new_num_tokens_per_req_np, out=new_query_start_loc_np[1:])
+
+        total_num_tokens = new_query_start_loc_np[-1]
+        # Example assuming new_num_tokens_per_req_np = [2, 4, 3]:
+        # new_query_start_loc_np is [0, 2, 6, 9], which expands to
+        # [0, 0, 2, 2, 2, 2, 6, 6, 6]
+        #  _r1_  ____r2____  ___r3__
+        new_query_start_locs_expanded = np.repeat(new_query_start_loc_np[:-1],
+                                                  new_num_tokens_per_req_np)
+        # [0, 1, 2, 3, 4, 5, 6, 7, 8] ->
+        # [0, 1, 0, 1, 2, 3, 0, 1, 2]
+        #  _r1_  ____r2____  ___r3__
+        token_offsets = (self.token_arange_np[:total_num_tokens] -
+                         new_query_start_locs_expanded)
+
+        # Expand the original (pre-rejection) start positions to the same
+        # per-token pattern:
+        # [0, q1, q1 + q2] ->
+        # [0, 0, q1, q1, q1, q1, q1 + q2, q1 + q2, q1 + q2]
+        #  _r1_  _____r2_______  ___________r3____________
+        old_query_start_locs_expanded = np.repeat(
+            query_start_loc_cpu[:-1].numpy(), new_num_tokens_per_req_np)
+        # The final token indices are:
+        # [0, 1,                                    // req 1
+        #  q1 + 0, q1 + 1, q1 + 2, q1 + 3,          // req 2
+        #  q1 + q2 + 0, q1 + q2 + 1, q1 + q2 + 2]   // req 3
+        token_indices_np = token_offsets + old_query_start_locs_expanded
+        token_indices = torch.from_numpy(token_indices_np).to(
+            device, non_blocking=True)
+
+        # cu_num_tokens is consumed on the device later, so recompute the
+        # cumulative counts with NPU tensors.
         query_len_per_req = (cu_target_query_lens[1:] -
                              cu_target_query_lens[:-1])
-        # [a, b, c] -> [a - n1, b - n2, c - n3]
         num_tokens_per_req = query_len_per_req - num_rejected_tokens
 
         # [a - n1, b - n2, c - n3] ->
         # [0, a - n1, a + b - n1 - n2, a + b + c - n1 - n2 - n3]
         cu_num_tokens = torch.zeros_like(cu_target_query_lens)
         torch.cumsum(num_tokens_per_req, dim=0, out=cu_num_tokens[1:])
-        token_indices = torch.empty(
-            num_tokens,
-            dtype=torch.int32,
-            device=cu_target_query_lens.device,
-        )
-        BLOCK_SIZE = 1024
-        self._prepare_eagle_input_sequential(
-            token_indices,
-            cu_target_query_lens,
-            cu_num_tokens,
-            block_size=BLOCK_SIZE,
-        )
-        return cu_num_tokens, token_indices
-
-    def _prepare_eagle_input_sequential(self, out_tensor: torch.Tensor,
-                                        cu_query_lens: torch.Tensor,
-                                        cu_num_tokens: torch.Tensor,
-                                        block_size: int):
-        num_programs = len(cu_num_tokens) - 1
-        for pid in range(num_programs):
-            start_pos = cu_num_tokens[pid].item()
-            end_pos = cu_num_tokens[pid + 1].item()
-            num_tokens = end_pos - start_pos
-            index_start = cu_query_lens[pid].item()
-            num_blocks = int(
-                torch.ceil(torch.tensor(num_tokens / block_size)).item())
-
-            for i in range(num_blocks):
-                offset_tensor = torch.arange(0,
-                                             block_size,
-                                             dtype=torch.int32,
-                                             device=out_tensor.device)
-                global_start_offset = i * block_size
-                target_indices = torch.tensor(
-                    start_pos + global_start_offset,
-                    dtype=torch.int32,
-                    device=out_tensor.device) + offset_tensor
-                values_to_store = torch.tensor(
-                    index_start + global_start_offset,
-                    dtype=torch.int32,
-                    device=out_tensor.device) + offset_tensor
-                mask = (target_indices >= start_pos) & \
-                       (target_indices < end_pos) & \
-                       (offset_tensor < num_tokens)
-                out_tensor[target_indices[mask]] = values_to_store[mask]
+        return cu_num_tokens, token_indices
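The heart of this change is replacing the per-request sequential copy with a host-side cumsum/repeat/arange computation. The snippet below is a minimal, self-contained NumPy sketch of that pattern, worked on concrete example values; it is illustrative only (names shortened, not part of the patch):

```python
import numpy as np

# Worked example: per-request query start offsets and rejected-token counts.
query_start_loc = np.array([0, 3, 7, 12])  # [0, q1, q1+q2, q1+q2+q3]
num_rejected = np.array([1, 2, 0])         # [n1, n2, n3]

# [q1, q2, q3] -> [q1 - n1, q2 - n2, q3 - n3]
num_tokens_per_req = np.diff(query_start_loc) - num_rejected  # [2, 2, 5]

# Cumulative start of each request after dropping rejected tokens.
new_start = np.zeros_like(query_start_loc)
np.cumsum(num_tokens_per_req, out=new_start[1:])              # [0, 2, 4, 9]
total = new_start[-1]

# Per-token offset within its own request: a flat arange minus the
# repeated new starts, e.g. [0, 1, 0, 1, 0, 1, 2, 3, 4].
offsets = np.arange(total) - np.repeat(new_start[:-1], num_tokens_per_req)

# Shift each offset by its request's ORIGINAL start position.
token_indices = offsets + np.repeat(query_start_loc[:-1], num_tokens_per_req)

print(token_indices)  # [ 0  1  3  4  7  8  9 10 11]
```

Request 1 keeps tokens 0-1 (its third token was rejected), request 2 keeps 3-4, and request 3 keeps all of 7-11, matching the `token_indices` comment in the diff.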
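As a sanity check of the rewrite, here is a small property test (hypothetical, not in the repository) comparing the vectorized mapping against a plain per-request loop with the same semantics as the removed `_prepare_eagle_input_sequential`:

```python
import numpy as np

def naive_token_indices(query_start_loc, num_rejected):
    # Per-request loop mirroring the removed sequential helper: for each
    # request, emit the indices of its kept (non-rejected) tokens.
    out = []
    for req in range(len(num_rejected)):
        start = query_start_loc[req]
        keep = query_start_loc[req + 1] - start - num_rejected[req]
        out.extend(range(start, start + keep))
    return np.array(out)

def vectorized_token_indices(query_start_loc, num_rejected):
    # Same cumsum/repeat/arange pattern as _prepare_inputs in the diff.
    num_tokens_per_req = np.diff(query_start_loc) - num_rejected
    new_start = np.zeros_like(query_start_loc)
    np.cumsum(num_tokens_per_req, out=new_start[1:])
    offsets = np.arange(new_start[-1]) - np.repeat(new_start[:-1],
                                                   num_tokens_per_req)
    return offsets + np.repeat(query_start_loc[:-1], num_tokens_per_req)

rng = np.random.default_rng(0)
for _ in range(100):
    lens = rng.integers(1, 8, size=4)
    rej = np.array([rng.integers(0, n) for n in lens])
    qsl = np.concatenate(([0], np.cumsum(lens)))
    assert (naive_token_indices(qsl, rej) ==
            vectorized_token_indices(qsl, rej)).all()
print("sequential and vectorized index computations match")
```

The vectorized form does the same work in a handful of NumPy ops on the host instead of a Python loop with repeated `.item()` device synchronizations.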
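One more detail worth calling out: `new_query_start_loc_cpu` is allocated with `pin_memory=is_pin_memory_available()` so the eventual `.to(device, non_blocking=True)` copy can be issued asynchronously from pinned host memory. A generic sketch of that staging pattern follows (CUDA stands in for the NPU device here; `is_pin_memory_available` is the vLLM helper imported in the diff):

```python
import numpy as np
import torch

counts = np.array([2, 4, 3], dtype=np.int32)

# Allocate the CPU staging tensor pinned when an accelerator is present,
# mirroring the is_pin_memory_available() guard in the patch.
pin = torch.cuda.is_available()
cu = torch.zeros(len(counts) + 1, dtype=torch.int32, pin_memory=pin)

# Fill the tensor in place through its shared-memory NumPy view.
np.cumsum(counts, out=cu.numpy()[1:])

if torch.cuda.is_available():
    # From pinned memory this host-to-device copy is asynchronous and can
    # overlap with the remaining host-side index computation.
    cu_dev = cu.to("cuda", non_blocking=True)
```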