93 changes: 70 additions & 23 deletions vllm_ascend/attention/attention_v1.py
@@ -158,6 +158,18 @@ class AscendMetadata:
is_only_prefill: bool = False


@dataclass
class AscendAttentionMetadataBuildInfo:
num_actual_tokens: int
block_table: torch.Tensor
query_start_loc: torch.Tensor
query_lens: torch.Tensor
seq_lens: torch.Tensor
slot_mapping: torch.Tensor
attn_mask: torch.Tensor
attn_state: AscendAttentionState


class AscendAttentionMetadataBuilder:

def __init__(
@@ -175,9 +187,60 @@ def reorder_batch(self, input_batch: "InputBatch",
scheduler_output: "SchedulerOutput") -> bool:
return False

def _assemble_build_info(
self,
num_actual_tokens,
block_table,
query_start_loc,
query_lens,
seq_lens,
slot_mapping,
attn_mask,
attn_state: "AscendAttentionState",
) -> "AscendAttentionMetadataBuildInfo":
if is_310p():
if attn_state == AscendAttentionState.PrefillNoCache:
mask_nz = nd_to_nz_2d(attn_mask)
attn_mask = torch_npu.npu_format_cast(mask_nz.contiguous(),
ACL_FORMAT_FRACTAL_NZ)
elif attn_state == AscendAttentionState.ChunkedPrefill:
mask_nz = nd_to_nz_spec(attn_mask)
attn_mask = torch_npu.npu_format_cast(mask_nz.contiguous(),
ACL_FORMAT_FRACTAL_NZ)

build_info = AscendAttentionMetadataBuildInfo(
num_actual_tokens=num_actual_tokens,
block_table=block_table,
query_start_loc=query_start_loc,
query_lens=query_lens,
seq_lens=seq_lens,
slot_mapping=slot_mapping,
attn_mask=attn_mask,
attn_state=attn_state)
return build_info

def _assemble_attn_metadata(
self,
build_info: "AscendAttentionMetadataBuildInfo",
common_attn_metadata: "AscendCommonAttentionMetadata",
) -> "AscendMetadata":
attn_metadata = AscendMetadata(
num_actual_tokens=build_info.num_actual_tokens,
block_tables=build_info.block_table,
query_start_loc=build_info.query_start_loc,
query_lens=build_info.query_lens,
seq_lens=build_info.seq_lens,
max_query_len=common_attn_metadata.max_query_len,
slot_mapping=build_info.slot_mapping,
attn_mask=build_info.attn_mask,
attn_state=build_info.attn_state,
enable_dbo_across_dp=common_attn_metadata.enable_dbo_across_dp,
is_only_prefill=common_attn_metadata.is_only_prefill)
return attn_metadata

def build(
self,
common_attn_metadata: AscendCommonAttentionMetadata,
common_attn_metadata: "AscendCommonAttentionMetadata",
model: nn.Module,
):
num_reqs = common_attn_metadata.num_reqs
@@ -205,28 +268,12 @@ def build(
query_start_loc = query_start_loc_cpu.to(self.device,
non_blocking=True)

if is_310p():
if attn_state == AscendAttentionState.PrefillNoCache:
mask_nz = nd_to_nz_2d(attn_mask)
attn_mask = torch_npu.npu_format_cast(mask_nz.contiguous(),
ACL_FORMAT_FRACTAL_NZ)
elif attn_state == AscendAttentionState.ChunkedPrefill:
mask_nz = nd_to_nz_spec(attn_mask)
attn_mask = torch_npu.npu_format_cast(mask_nz.contiguous(),
ACL_FORMAT_FRACTAL_NZ)

attn_metadata = AscendMetadata(
num_actual_tokens=num_actual_tokens,
block_tables=block_table,
query_start_loc=query_start_loc,
query_lens=query_lens,
seq_lens=seq_lens,
max_query_len=common_attn_metadata.max_query_len,
slot_mapping=slot_mapping,
attn_mask=attn_mask,
attn_state=attn_state,
enable_dbo_across_dp=common_attn_metadata.enable_dbo_across_dp,
is_only_prefill=common_attn_metadata.is_only_prefill)
build_info = self._assemble_build_info(num_actual_tokens, block_table,
query_start_loc, query_lens,
seq_lens, slot_mapping,
attn_mask, attn_state)
attn_metadata = self._assemble_attn_metadata(build_info,
common_attn_metadata)
return attn_metadata
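
Note on the refactor: `build` now follows a template-method split. The shared driver gathers the tensors, `_assemble_build_info` applies any platform-specific (e.g. 310P NZ-format) mask handling, and `_assemble_attn_metadata` packs the final `AscendMetadata`, so a subclass such as the torchair builder only overrides the hooks. Below is a minimal, self-contained sketch of that pattern; `BuildInfo`, `BaseBuilder`, and `TorchairLikeBuilder` are illustrative stand-ins, not the real vllm-ascend classes.

```python
from dataclasses import dataclass


# Illustrative stand-ins for the real builder classes: the point is the
# template-method split, not the actual tensor plumbing.
@dataclass
class BuildInfo:
    tokens: int
    mask: str


class BaseBuilder:
    def _assemble_build_info(self, tokens, mask) -> BuildInfo:
        # Hook 1: platform-specific preprocessing (pass-through by default).
        return BuildInfo(tokens=tokens, mask=mask)

    def _assemble_attn_metadata(self, info: BuildInfo) -> dict:
        # Hook 2: turn the intermediate info into the final metadata object.
        return {"tokens": info.tokens, "mask": info.mask}

    def build(self, tokens, mask) -> dict:
        # Shared driver: subclasses customize the two hooks, not build() itself.
        info = self._assemble_build_info(tokens, mask)
        return self._assemble_attn_metadata(info)


class TorchairLikeBuilder(BaseBuilder):
    def _assemble_build_info(self, tokens, mask) -> BuildInfo:
        # Subclass hook: apply a different mask transformation.
        return BuildInfo(tokens=tokens, mask=f"nz({mask})")


if __name__ == "__main__":
    print(BaseBuilder().build(4, "causal"))          # {'tokens': 4, 'mask': 'causal'}
    print(TorchairLikeBuilder().build(4, "causal"))  # {'tokens': 4, 'mask': 'nz(causal)'}
```
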


76 changes: 41 additions & 35 deletions vllm_ascend/torchair/torchair_attention.py
@@ -20,18 +16,

import numpy as np
import torch
import torch.nn as nn
import torch_npu
from vllm.attention.backends.abstract import (AttentionImpl, AttentionLayer,
AttentionType)
from vllm.attention.backends.utils import PAD_SLOT_ID
from vllm.config import VllmConfig
from vllm.utils import cdiv

from vllm_ascend.attention.attention_v1 import (AscendAttentionBackend,
AscendAttentionMetadataBuilder,
AscendAttentionState,
AscendMetadata)
from vllm_ascend.attention.attention_v1 import (
AscendAttentionBackend, AscendAttentionMetadataBuilder,
AscendAttentionMetadataBuildInfo, AscendAttentionState, AscendMetadata)
from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
from vllm_ascend.torchair.utils import TorchairCommonAttentionMetadata
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, aligned_16, is_310p,
@@ -169,44 +167,52 @@ def build_torchair_graph_dummy(
decode=decode_metadata)
return attn_metadata

def build(
def _assemble_build_info(
self,
common_attn_metadata: AscendCommonAttentionMetadata,
model: nn.Module,
):
num_reqs = common_attn_metadata.num_reqs
num_actual_tokens = common_attn_metadata.num_actual_tokens

block_table = common_attn_metadata.block_table_tensor
block_table[:num_reqs, :self.max_num_blocks_per_req] = (
block_table[:num_reqs])

seq_lens = common_attn_metadata.seq_lens_cpu[:num_reqs]
slot_mapping = common_attn_metadata.slot_mapping_cpu[:
num_actual_tokens].to(
self.device,
non_blocking=
True)
attn_mask = common_attn_metadata.attn_mask

attn_state = common_attn_metadata.attn_state
num_actual_tokens,
block_table,
query_start_loc,
query_lens,
seq_lens,
slot_mapping,
attn_mask,
attn_state: "AscendAttentionState",
) -> "AscendAttentionMetadataBuildInfo":
if is_310p() and attn_state == AscendAttentionState.PrefillNoCache:
mask_nz = nd_to_nz_2d(attn_mask)
attn_mask = torch_npu.npu_format_cast(mask_nz.contiguous(), 29)

query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu[:
num_reqs
+ 1]
query_start_loc = query_start_loc_cpu.to(self.device,
non_blocking=True)
query_lens = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
build_info = AscendAttentionMetadataBuildInfo(
num_actual_tokens=num_actual_tokens,
block_table=block_table,
query_start_loc=query_start_loc,
query_lens=query_lens,
seq_lens=seq_lens,
slot_mapping=slot_mapping,
attn_mask=attn_mask,
attn_state=attn_state)
return build_info

def _assemble_attn_metadata(
self,
build_info: "AscendAttentionMetadataBuildInfo",
common_attn_metadata: "AscendCommonAttentionMetadata",
) -> "AscendMetadata":
num_actual_tokens = build_info.num_actual_tokens
block_table = build_info.block_table
seq_lens = build_info.seq_lens
slot_mapping = build_info.slot_mapping
attn_state = build_info.attn_state

num_reqs = common_attn_metadata.num_reqs
input_positions = common_attn_metadata.positions[:
num_actual_tokens].long(
)
graph_pad_size = common_attn_metadata.graph_pad_size

decode_metadata = None
graph_pad_size = common_attn_metadata.graph_pad_size
use_torchair_graph = graph_pad_size > -1

if common_attn_metadata.attn_state in [
AscendAttentionState.DecodeOnly,
]:
@@ -259,12 +265,12 @@ def build(
decode=decode_metadata,
num_actual_tokens=num_actual_tokens,
block_tables=block_table,
query_start_loc=query_start_loc,
query_lens=query_lens,
query_start_loc=build_info.query_start_loc,
query_lens=build_info.query_lens,
seq_lens=seq_lens,
max_query_len=common_attn_metadata.max_query_len,
slot_mapping=slot_mapping,
attn_mask=attn_mask,
attn_mask=build_info.attn_mask,
attn_state=attn_state,
enable_dbo_across_dp=common_attn_metadata.enable_dbo_across_dp)
return attn_metadata
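
Compared with the base builder, the torchair `_assemble_build_info` applies the 310P NZ-format cast only for `PrefillNoCache`, not for `ChunkedPrefill`. A hedged sketch of expressing that per-state mask handling as a dispatch table follows; `AttnState`, `to_nz_2d`, and `to_nz_spec` are placeholders standing in for `AscendAttentionState` and the real `nd_to_nz_*` + `npu_format_cast` calls, so this runs without NPU hardware.

```python
from enum import Enum, auto
from typing import Callable, Dict


class AttnState(Enum):
    # Mirrors the AscendAttentionState members referenced in the diff.
    PrefillNoCache = auto()
    ChunkedPrefill = auto()
    DecodeOnly = auto()


def to_nz_2d(mask: str) -> str:
    # Placeholder for nd_to_nz_2d + npu_format_cast on real hardware.
    return f"nz2d({mask})"


def to_nz_spec(mask: str) -> str:
    # Placeholder for nd_to_nz_spec + npu_format_cast on real hardware.
    return f"nzspec({mask})"


# Base builder: two attention states need the NZ cast on 310P.
BASE_310P_MASK_FORMATTERS: Dict[AttnState, Callable[[str], str]] = {
    AttnState.PrefillNoCache: to_nz_2d,
    AttnState.ChunkedPrefill: to_nz_spec,
}

# Torchair builder: only PrefillNoCache is special-cased.
TORCHAIR_310P_MASK_FORMATTERS: Dict[AttnState, Callable[[str], str]] = {
    AttnState.PrefillNoCache: to_nz_2d,
}


def format_mask(mask: str, state: AttnState,
                formatters: Dict[AttnState, Callable[[str], str]]) -> str:
    # Apply the per-state formatter when one is registered, else pass through.
    return formatters.get(state, lambda m: m)(mask)


if __name__ == "__main__":
    print(format_mask("causal", AttnState.ChunkedPrefill, BASE_310P_MASK_FORMATTERS))      # nzspec(causal)
    print(format_mask("causal", AttnState.ChunkedPrefill, TORCHAIR_310P_MASK_FORMATTERS))  # causal
```
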