96 changes: 60 additions & 36 deletions vllm_ascend/attention/attention_v1.py
@@ -24,14 +24,13 @@
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                              AttentionLayer, AttentionType)
from vllm.attention.backends.utils import CommonAttentionState
-from vllm.config import get_current_vllm_config
+from vllm.config import VllmConfig, get_current_vllm_config
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.utils import direct_register_custom_op
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.worker.gpu_input_batch import InputBatch

-from vllm_ascend.attention.utils import \
-    AscendCommonAttentionMetadata as CommonAttentionMetadata
+from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig
from vllm_ascend.ops.attention import vanilla_chunked_prefill
from vllm_ascend.utils import get_graph_params
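For orientation, here is a condensed sketch of the fields this diff reads off `AscendCommonAttentionMetadata`. It is not the actual definition from `vllm_ascend/attention/utils.py`: the field names are taken from the usages below, while the types and defaults are assumptions.

```python
# Sketch only: fields of AscendCommonAttentionMetadata as used by this diff.
# Types and defaults are assumptions inferred from the call sites below.
from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class AscendCommonAttentionMetadataSketch:
    num_reqs: int                         # requests scheduled this step
    num_actual_tokens: int                # total tokens across all requests
    max_query_len: int
    query_start_loc_cpu: torch.Tensor     # prefix sums of query lengths (CPU)
    seq_lens_cpu: torch.Tensor            # per-request sequence lengths (CPU)
    slot_mapping_cpu: torch.Tensor        # token -> KV-cache slot (CPU)
    block_table_tensor: torch.Tensor      # per-request block tables (device)
    max_num_blocks_per_req: int
    attn_mask: Optional[torch.Tensor] = None
    attn_state: Optional["AscendAttentionState"] = None
    seq_lens_list: Optional[list] = None
    enable_dbo_across_dp: bool = False
    is_only_prefill: bool = False
```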
@@ -156,39 +155,49 @@ def split_metadata_for_multistream(

class AscendAttentionMetadataBuilder:

-    def __init__(self, runner):
+    def __init__(self, vllm_config: VllmConfig, device: torch.device, runner):
+        self.vllm_config = vllm_config
+        self.model_config = vllm_config.model_config
+        self.device = device
        self.runner = runner
[Review comment, Contributor]: remove runner
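Under the new signature, constructing the builder might look like the sketch below. The call site, device id, and variable names are assumptions for illustration, not code from this PR.

```python
# Hypothetical call site; `vllm_config` and `runner` are assumed to exist.
# The builder now receives the config and device explicitly instead of
# pulling them off the runner, and the review note above suggests the
# remaining `runner` dependency should eventually be dropped too.
import torch

builder = AscendAttentionMetadataBuilder(
    vllm_config=vllm_config,       # engine-wide VllmConfig
    device=torch.device("npu:0"),  # target Ascend device (id assumed)
    runner=runner,                 # legacy dependency, slated for removal
)
```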


    def reorder_batch(self, input_batch: "InputBatch",
                      scheduler_output: "SchedulerOutput") -> bool:
        return False

-    def build(self,
-              num_reqs,
-              num_actual_tokens,
-              max_query_len,
-              common_attn_metadata: CommonAttentionMetadata,
-              enable_dbo_across_dp: bool = False,
-              is_only_prefill: bool = False,
-              *args,
-              **kwargs):
-
-        block_table = self.runner.input_batch.block_table[0].get_device_tensor(
-        )
-        block_table[:num_reqs, :self.runner.max_num_blocks_per_req] = (
-            block_table[:num_reqs])
-
-        query_start_loc = common_attn_metadata.query_start_loc
-        seq_lens = common_attn_metadata.seq_lens
+    def build(
+        self,
+        common_attn_metadata: AscendCommonAttentionMetadata,
+    ):
+        num_reqs = common_attn_metadata.num_reqs
+        num_actual_tokens = common_attn_metadata.num_actual_tokens
+        query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu[
+            :num_reqs + 1]
+
+        block_table = common_attn_metadata.block_table_tensor
+        block_table[:num_reqs, :common_attn_metadata.max_num_blocks_per_req] = (
+            block_table[:num_reqs])
+
+        seq_lens = common_attn_metadata.seq_lens_cpu[:num_reqs]
        # TODO: Refactor these two param to common metadata in runners,
        # preparing for the hybrid KV groups feature
-        query_lens = common_attn_metadata.query_lens or self.runner.query_lens
+        query_lens = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
        # Since FIA for GQA is not active now, we temporarily silence it
        seq_lens_list = common_attn_metadata.seq_lens_list

-        slot_mapping = self.runner.slot_mapping[:num_actual_tokens]
-        attn_mask = self.runner.attn_mask
-        attn_state = self.runner.attn_state
+        slot_mapping = common_attn_metadata.slot_mapping_cpu[
+            :num_actual_tokens].to(self.device, non_blocking=True)
+        attn_mask = common_attn_metadata.attn_mask
+        attn_state = common_attn_metadata.attn_state
+        query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu[
+            :num_reqs + 1]
+        query_start_loc = query_start_loc_cpu.to(self.device,
+                                                 non_blocking=True)

        attn_metadata = AscendMetadata(
            num_actual_tokens=num_actual_tokens,
@@ -197,34 +206,49 @@ def build(self,
            query_lens=query_lens,
            seq_lens=seq_lens,
            seq_lens_list=seq_lens_list,
-            max_query_len=max_query_len,
+            max_query_len=common_attn_metadata.max_query_len,
            slot_mapping=slot_mapping,
            attn_mask=attn_mask,
            attn_state=attn_state,
-            enable_dbo_across_dp=enable_dbo_across_dp,
-            is_only_prefill=is_only_prefill)
+            enable_dbo_across_dp=common_attn_metadata.enable_dbo_across_dp,
+            is_only_prefill=common_attn_metadata.is_only_prefill)
        return attn_metadata
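With this change, `build()` derives everything from its single `AscendCommonAttentionMetadata` argument. A minimal sketch of how a caller might drive it for a two-request decode step follows; the field values are illustrative assumptions, and the keyword-argument construction assumes the metadata type is a plain dataclass.

```python
# Illustrative only: metadata for two decode requests (one new token each).
# All values are made up; slot indices assume the two requests write into
# different KV-cache blocks. `block_table` and `builder` are assumed to exist.
import torch

common = AscendCommonAttentionMetadata(
    num_reqs=2,
    num_actual_tokens=2,
    max_query_len=1,
    query_start_loc_cpu=torch.tensor([0, 1, 2], dtype=torch.int32),
    seq_lens_cpu=torch.tensor([17, 33], dtype=torch.int32),
    slot_mapping_cpu=torch.tensor([16, 160], dtype=torch.int64),
    block_table_tensor=block_table,   # (num_reqs, max_num_blocks_per_req)
    max_num_blocks_per_req=block_table.shape[1],
    attn_state=AscendAttentionState.DecodeOnly,
)
attn_metadata = builder.build(common)
```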

    def build_dummy_metadata(self, num_actual_tokens, num_reqs,
                             num_scheduled_tokens, attn_state):
        if attn_state == AscendAttentionState.DecodeOnly:
            # NOTE: We only need to pay attention to seq_lens_list and block_table here
-            common_attn_metadata = CommonAttentionMetadata(
-                seq_lens=torch.empty_like(self.runner.seq_lens_cpu).fill_(2))

-            block_table = self.runner.input_batch.block_table[0].block_table
-            block_table[:num_reqs, 0] = torch.arange(1,
-                                                     num_reqs + 1,
-                                                     device=block_table.device,
-                                                     dtype=block_table.dtype)
+            block_table = self.runner.input_batch.block_table[
+                0].get_device_tensor()
+            block_table[:num_reqs, :self.runner.max_num_blocks_per_req] = (
+                block_table[:num_reqs])

-            attn_metadata = self.build(
-                num_reqs=num_reqs,
+            query_start_loc = None
+            seq_lens = torch.empty_like(self.runner.seq_lens_cpu).fill_(2)
+            query_lens = self.runner.query_lens
+            seq_lens_list = None
+
+            slot_mapping = self.runner.slot_mapping[:num_actual_tokens]
+            attn_mask = self.runner.attn_mask
+
+            attn_metadata = AscendMetadata(
                num_actual_tokens=num_actual_tokens,
+                block_tables=block_table,
+                query_start_loc=query_start_loc,
+                query_lens=query_lens,
+                seq_lens=seq_lens,
+                seq_lens_list=seq_lens_list,
                max_query_len=num_scheduled_tokens.max(),
-                common_prefix_len=0,
-                common_attn_metadata=common_attn_metadata,
-            )
+                slot_mapping=slot_mapping,
+                attn_mask=attn_mask,
+                attn_state=attn_state,
+                enable_dbo_across_dp=False,
+                is_only_prefill=False)
        else:
            raise NotImplementedError(
                "Currently we only support building dummy metadata for DecodeOnly state"