143 changes: 81 additions & 62 deletions vllm_ascend/spec_decode/eagle_proposer.py
@@ -12,13 +12,15 @@
from vllm.model_executor.model_loader import get_model
from vllm.model_executor.models import supports_multimodal
from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
from vllm.utils import is_pin_memory_available
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata

from vllm_ascend.ascend_forward_context import set_ascend_forward_context
from vllm_ascend.attention.attention_mask import AttentionMaskBuilder
from vllm_ascend.attention.attention_v1 import AscendAttentionState
from vllm_ascend.attention.attention_v1 import (AscendAttentionState,
AscendMetadata)
from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
from vllm_ascend.spec_decode.interface import Proposer, SpecDcodeType
from vllm_ascend.utils import vllm_version_is
@@ -65,6 +67,9 @@ def __init__(self,
self.hidden_size),
dtype=self.vllm_config.model_config.dtype,
device=device)
self.max_num_tokens = (
vllm_config.scheduler_config.max_num_batched_tokens)
self.token_arange_np = np.arange(self.max_num_tokens)
# We need +1 here because the arange is used to set query_start_loc,
# which has one more element than batch_size.
self.arange = torch.arange(vllm_config.scheduler_config.max_num_seqs +
@@ -182,10 +187,8 @@ def generate_token_ids(self,
dtype=torch.int32,
device=self.device,
)
num_tokens = num_scheduled_tokens - sum(num_rejected_tokens)
cu_num_tokens, token_indices = self._prepare_inputs(
eagle_attn_metadata.query_start_loc, num_rejected_tokens,
num_tokens)
cu_num_tokens, token_indices =\
self._prepare_inputs(eagle_attn_metadata, num_rejected_tokens)
target_token_ids = self.runner.input_ids[token_indices]
target_positions = positions[token_indices]
if self.name == SpecDcodeType.EAGLE3:
@@ -603,72 +606,88 @@ def _propose(

def _prepare_inputs(
self,
# [batch_size + 1]
cu_target_query_lens: torch.Tensor,
eagle_attn_metadata: AscendMetadata,
# [batch_size]
num_rejected_tokens: torch.Tensor,
num_tokens: int,
) -> tuple[torch.Tensor, torch.Tensor]:
# cu_target_query_lens: [0, a, a + b, a + b + c]
# num_rejected_tokens: [n1, n2, n3]
# num_tokens_per_req: [a - n1, b - n2, c - n3]
# cu_num_tokens: [0, a - n1, a + b - n1 - n2, a + b + c - n1 - n2 - n3]
# token_indices: [0, 1, ..., a - n1 - 1,
# a, a + 1, ..., a + b - n2 - 1,
# a + b, a + b + 1, ..., a + b + c - n3 - 1]

# [0, a, a + b, a + b + c] -> [a, b, c]
"""
Prepare the inputs for speculative decoding.
It updates the attention metadata to account for the rejected
tokens (and newly sampled tokens). It also returns the token indices
of the tokens that should be fed to the speculator.
"""
# E.g.
# common_attn_metadata.query_start_loc{_cpu}:
# [0, q1, q1 + q2, q1 + q2 + q3]
# common_attn_metadata.seq_lens{_cpu}: [s1, s2, s3]
# num_rejected_tokens: [n1, n2, n3]
# This function computes the intermediate values:
# num_tokens_per_req: [q1 - n1, q2 - n2, q3 - n3]
# And returns:
# common_attn_metadata.query_start_loc{_cpu}:
# [0, q1 - n1, q1 + q2 - n1 - n2, q1 + q2 + q3 - n1 - n2 - n3]
# common_attn_metadata.seq_lens{_cpu}:
# [s1 - n1 + 1, s2 - n2 + 1, s3 - n3 + 1]
# token_indices: [0, 1, ..., q1 - n1 - 1,
# q1, q1 + 1, ..., q1 + q2 - n2 - 1,
# q1 + q2, q1 + q2 + 1, ..., q1 + q2 + q3 - n3 - 1]
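# Do the index bookkeeping on the host: copy the rejection counts
# and the target query offsets to CPU tensors first.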
num_rejected_tokens_cpu = num_rejected_tokens.to("cpu")
cu_target_query_lens = eagle_attn_metadata.query_start_loc
device = eagle_attn_metadata.query_start_loc.device
query_start_loc_cpu = cu_target_query_lens.to("cpu")

# [0, q1, q1 + q2, q1 + q2 + q3] -> [q1, q2, q3]
new_query_len_per_req = (query_start_loc_cpu[1:] -
query_start_loc_cpu[:-1])
# [q1, q2, q3] -> [q1 - n1, q2 - n2, q3 - n3]
new_num_tokens_per_req = new_query_len_per_req - num_rejected_tokens_cpu
new_num_tokens_per_req_np = new_num_tokens_per_req.numpy()

# [q1 - n1, q2 - n2, q3 - n3] ->
# [0, q1 - n1, q1 + q2 - n1 - n2, q1 + q2 + q3 - n1 - n2 - n3]
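# Allocate in pinned host memory so the copy back to the device
# can be issued asynchronously (non_blocking=True).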
new_query_start_loc_cpu = torch.zeros(
query_start_loc_cpu.shape,
dtype=torch.int32,
pin_memory=is_pin_memory_available())
new_query_start_loc_np = new_query_start_loc_cpu.numpy()
np.cumsum(new_num_tokens_per_req_np, out=new_query_start_loc_np[1:])

total_num_tokens = new_query_start_loc_np[-1]
# Example assuming new_num_tokens_per_req_np = [2, 4, 3]
# this implies that `new_query_start_loc_np` is:
# [0, 2, 6, 9] ->
# [0, 0, 2, 2, 2, 2, 6, 6, 6]
# _r1_ ____r2____ ___r3__
new_query_start_locs_expanded = np.repeat(new_query_start_loc_np[:-1],
new_num_tokens_per_req_np)
# [0, 1, 2, 3, 4, 5, 6, 7, 8] ->
# [0, 1, 0, 1, 2, 3, 0, 1, 2]
# _r1_ ____r2____ ___r3__
token_offsets = self.token_arange_np[:total_num_tokens] \
- new_query_start_locs_expanded

# Expand starting positions to match token pattern
# [0, q1, q1 + q2] ->
# [0, 0, q1, q1, q1, q1, q1 + q2, q1 + q2, q1 + q2]
# _r1_ _____r2_______ ___________r3____________
old_query_start_locs_expanded = np.repeat(
query_start_loc_cpu[:-1].numpy(), new_num_tokens_per_req_np)
# Final token indices are:
# [0, 1, // req 1
# q1 + 0, q1 + 1, q1 + 2, q1 + 3, // req 2
# q1 + q2 + 0, q1 + q2 + 1, q1 + q2 + 2] // req 3
token_indices_np = token_offsets + old_query_start_locs_expanded
token_indices = torch.from_numpy(token_indices_np).to(
device, non_blocking=True)

# needs to be computed on the NPU
query_len_per_req = (cu_target_query_lens[1:] -
cu_target_query_lens[:-1])
# [a, b, c] -> [a - n1, b - n2, c - n3]
num_tokens_per_req = query_len_per_req - num_rejected_tokens

# [a - n1, b - n2, c - n3] ->
# [0, a - n1, a + b - n1 - n2, a + b + c - n1 - n2 - n3]
cu_num_tokens = torch.zeros_like(cu_target_query_lens)
torch.cumsum(num_tokens_per_req, dim=0, out=cu_num_tokens[1:])
Comment on lines +683 to 691
Contributor

high

The calculation of cu_num_tokens on the device is redundant. The same value has already been computed on the CPU as new_query_start_loc_cpu. Instead of re-calculating it on the device, you can simply transfer the CPU tensor to the device. This avoids unnecessary computation and better aligns with the goal of optimizing performance.

Suggested change
# needs to be computed on the NPU
query_len_per_req = (cu_target_query_lens[1:] -
cu_target_query_lens[:-1])
# [a, b, c] -> [a - n1, b - n2, c - n3]
num_tokens_per_req = query_len_per_req - num_rejected_tokens
# [a - n1, b - n2, c - n3] ->
# [0, a - n1, a + b - n1 - n2, a + b + c - n1 - n2 - n3]
cu_num_tokens = torch.zeros_like(cu_target_query_lens)
torch.cumsum(num_tokens_per_req, dim=0, out=cu_num_tokens[1:])
cu_num_tokens = new_query_start_loc_cpu.to(device, non_blocking=True)
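
For context on why a single transfer suffices here: new_query_start_loc_cpu is already allocated with pin_memory, so the copy can be issued with non_blocking=True. Below is a minimal sketch of that pattern, assuming an accelerator-enabled build where pinned memory is available (cumsum_to_device and its arguments are made-up names, not this PR's code):

import torch

def cumsum_to_device(query_lens_cpu: torch.Tensor,
                     device: torch.device) -> torch.Tensor:
    # Compute the cumulative offsets once, on the host, into pinned memory.
    out = torch.zeros(query_lens_cpu.numel() + 1,
                      dtype=torch.int32,
                      pin_memory=True)
    torch.cumsum(query_lens_cpu.to(torch.int32), dim=0, out=out[1:])
    # One asynchronous host-to-device copy replaces the device-side cumsum.
    return out.to(device, non_blocking=True)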

token_indices = torch.empty(
num_tokens,
dtype=torch.int32,
device=cu_target_query_lens.device,
)
BLOCK_SIZE = 1024
self._prepare_eagle_input_sequential(
token_indices,
cu_target_query_lens,
cu_num_tokens,
block_size=BLOCK_SIZE,
)
return cu_num_tokens, token_indices

def _prepare_eagle_input_sequential(self, out_tensor: torch.Tensor,
cu_query_lens: torch.Tensor,
cu_num_tokens: torch.Tensor,
block_size: int):
num_programs = len(cu_num_tokens) - 1
for pid in range(num_programs):
start_pos = cu_num_tokens[pid].item()
end_pos = cu_num_tokens[pid + 1].item()
num_tokens = end_pos - start_pos
index_start = cu_query_lens[pid].item()
num_blocks = int(
torch.ceil(torch.tensor(num_tokens / block_size)).item())

for i in range(num_blocks):
offset_tensor = torch.arange(0,
block_size,
dtype=torch.int32,
device=out_tensor.device)
global_start_offset = i * block_size
target_indices = torch.tensor(
start_pos + global_start_offset,
dtype=torch.int32,
device=out_tensor.device) + offset_tensor
values_to_store = torch.tensor(
index_start + global_start_offset,
dtype=torch.int32,
device=out_tensor.device) + offset_tensor
mask = (target_indices >= start_pos) & \
(target_indices < end_pos) & \
(offset_tensor < num_tokens)
out_tensor[target_indices[mask]] = values_to_store[mask]
return cu_num_tokens, token_indices
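
For reference, the vectorized index construction that replaces the sequential block loop above can be reproduced standalone in NumPy. A toy sketch with illustrative values (not from the PR):

import numpy as np

# Toy per-request layout: three requests with query lengths [5, 4, 5].
query_start_loc = np.array([0, 5, 9, 14])
num_rejected = np.array([1, 2, 0])

query_lens = np.diff(query_start_loc)    # [5, 4, 5]
kept = query_lens - num_rejected         # [4, 2, 5] tokens kept per request

new_start = np.zeros_like(query_start_loc)
np.cumsum(kept, out=new_start[1:])       # [0, 4, 6, 11]
total = new_start[-1]                    # 11 tokens overall

# Offset of each kept token within its own request: [0..3, 0..1, 0..4].
offsets = np.arange(total) - np.repeat(new_start[:-1], kept)
# Shift each offset by its request's start in the original layout.
token_indices = offsets + np.repeat(query_start_loc[:-1], kept)
print(token_indices)                     # [ 0  1  2  3  5  6  9 10 11 12 13]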