Commit 83e0f41

[3/N][Refactor] Move torchair_attention to torchair dir (#2017)
### What this PR does / why we need it?
1. Move `torchair_attention` to the `torchair` dir.
2. Make `AscendAttentionTorchairBackend` extend `AscendAttentionBackend` to reduce duplicate methods.
3. Make `AscendTorchairMetadata` extend `AscendMetadata` to reduce duplicate properties.

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?
- vLLM version: v0.10.0
- vLLM main: vllm-project/vllm@0933f9d

Signed-off-by: shen-shanshan <[email protected]>
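A minimal Python sketch of the class relationships this refactor establishes (class names are from the PR; all bodies are elided, so this is a reading aid rather than the actual module contents):

```python
# Sketch only: the real classes live in vllm_ascend/attention/attention_v1.py
# and vllm_ascend/torchair/torchair_attention.py after this PR.
class AscendAttentionBackend: ...          # shared backend methods
class AscendMetadata: ...                  # shared metadata fields
class AscendAttentionMetadataBuilder: ...  # shared builder logic

class AscendAttentionTorchairBackend(AscendAttentionBackend):
    """Inherits the shared backend methods instead of duplicating them."""

class AscendTorchairMetadata(AscendMetadata):
    """Keeps only the torchair-specific `decode` field on top of the base."""

class AscendAttentionTorchairMetadataBuilder(AscendAttentionMetadataBuilder):
    """Overrides only the torchair-specific parts of build()."""
```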
1 parent 2a763b8 commit 83e0f41

File tree

5 files changed: +24 −75 lines


tests/ut/test_platform.py

Lines changed: 1 addition & 1 deletion
@@ -444,7 +444,7 @@ def test_get_attn_backend_cls_use_v1_and_torchair(self,
         )
         self.assertEqual(
             result,
-            "vllm_ascend.attention.attention_v1_torchair.AscendAttentionTorchairBackend"
+            "vllm_ascend.torchair.torchair_attention.AscendAttentionTorchairBackend"
         )
 
     @patch('vllm_ascend.platform.get_ascend_config')

vllm_ascend/attention/attention_v1.py

Lines changed: 3 additions & 1 deletion
@@ -169,7 +169,9 @@ def build(self,
               num_actual_tokens,
               max_query_len,
               enable_dbo_across_dp: bool = False,
-              is_only_prefill: bool = False):
+              is_only_prefill: bool = False,
+              *args,
+              **kwargs):
 
         block_table = self.runner.input_batch.block_table[0].get_device_tensor(
         )
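The `*args, **kwargs` tail is the usual trick for letting one `build()` signature serve both the base and derived builders: extra keywords such as the torchair-only `graph_pad_size` pass through without the base class having to declare them. A minimal sketch of the pattern, using hypothetical standalone classes rather than the actual vllm_ascend ones:

```python
# Hypothetical builders illustrating the *args/**kwargs pass-through.
class BaseBuilder:
    def build(self, num_reqs, num_actual_tokens, max_query_len,
              enable_dbo_across_dp=False, is_only_prefill=False,
              *args, **kwargs):
        # The base builder simply ignores keywords it does not know about.
        return {"num_reqs": num_reqs, "max_query_len": max_query_len}

class TorchairBuilder(BaseBuilder):
    def build(self, num_reqs, num_actual_tokens, max_query_len,
              enable_dbo_across_dp=False, is_only_prefill=False,
              *args, **kwargs):
        # The torchair-only knob arrives via kwargs, defaulting when absent.
        graph_pad_size = kwargs.get("graph_pad_size", -1)
        meta = super().build(num_reqs, num_actual_tokens, max_query_len,
                             enable_dbo_across_dp, is_only_prefill)
        meta["graph_pad_size"] = graph_pad_size
        return meta

# A torchair-aware caller can pass the extra keyword; generic callers omit it.
print(TorchairBuilder().build(4, 128, 32, graph_pad_size=512))
```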

vllm_ascend/platform.py

Lines changed: 1 addition & 1 deletion
@@ -218,7 +218,7 @@ def get_attn_backend_cls(cls,
         if use_mla:
             return "vllm_ascend.attention.mla_v1.AscendMLABackend"
         elif use_torchair:
-            return "vllm_ascend.attention.attention_v1_torchair.AscendAttentionTorchairBackend"
+            return "vllm_ascend.torchair.torchair_attention.AscendAttentionTorchairBackend"
         else:
             return "vllm_ascend.attention.attention_v1.AscendAttentionBackend"
 
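Note that `get_attn_backend_cls` returns a dotted class path rather than a class object, so the backend module is only imported when it is actually selected. A hedged sketch of how such a path can be resolved (vLLM has its own helper for this; `importlib` below is the standard-library equivalent):

```python
import importlib

def resolve_cls(qualname: str):
    # Split "pkg.module.Class" into module path and class name, then import.
    module_name, _, cls_name = qualname.rpartition(".")
    return getattr(importlib.import_module(module_name), cls_name)

# e.g. resolve_cls(
#     "vllm_ascend.torchair.torchair_attention.AscendAttentionTorchairBackend")
```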

vllm_ascend/attention/attention_v1_torchair.py renamed to vllm_ascend/torchair/torchair_attention.py

Lines changed: 18 additions & 71 deletions
@@ -21,18 +21,19 @@
 import numpy as np
 import torch
 import torch_npu
-from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
-                                              AttentionLayer, AttentionType)
-from vllm.attention.backends.utils import PAD_SLOT_ID, CommonAttentionState
-from vllm.v1.core.sched.output import SchedulerOutput
-
-from vllm_ascend.attention.attention_v1 import AscendAttentionState
+from vllm.attention.backends.abstract import (AttentionImpl, AttentionLayer,
+                                              AttentionType)
+from vllm.attention.backends.utils import PAD_SLOT_ID
+
+from vllm_ascend.attention.attention_v1 import (AscendAttentionBackend,
+                                                AscendAttentionMetadataBuilder,
+                                                AscendAttentionState,
+                                                AscendMetadata)
 from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, aligned_16, is_310p,
                                nd_to_nz_2d)
-from vllm_ascend.worker.npu_input_batch import InputBatch
 
 
-class AscendAttentionTorchairBackend(AttentionBackend):
+class AscendAttentionTorchairBackend(AscendAttentionBackend):
     accept_output_buffer: bool = True
 
     @staticmethod
@@ -47,10 +48,6 @@ def get_impl_cls() -> Type["AscendAttentionTorchairBackendImpl"]:
     def get_metadata_cls() -> Type["AscendTorchairMetadata"]:
         return AscendTorchairMetadata
 
-    @staticmethod
-    def get_state_cls() -> Type["CommonAttentionState"]:
-        return CommonAttentionState
-
     @staticmethod
     def get_builder_cls() -> type["AscendAttentionTorchairMetadataBuilder"]:
         return AscendAttentionTorchairMetadataBuilder
@@ -73,36 +70,6 @@ def get_bsh_kv_cache_shape(
     ) -> Tuple[int, ...]:
         return (2, num_blocks, block_size, num_kv_heads * head_size)
 
-    @staticmethod
-    def swap_blocks(
-        src_kv_cache: List[torch.Tensor],
-        dst_kv_cache: List[torch.Tensor],
-        src_to_dst: torch.Tensor,
-    ) -> None:
-        src_key_cache, src_value_cache = src_kv_cache[0], src_kv_cache[1]
-        dst_key_cache, dst_value_cache = dst_kv_cache[0], dst_kv_cache[1]
-        src_indices = src_to_dst[:, 0]
-        dst_indices = src_to_dst[:, 1]
-
-        dst_key_cache[dst_indices] = src_key_cache[src_indices].to(
-            dst_key_cache.device)
-        dst_value_cache[dst_indices] = src_value_cache[src_indices].to(
-            dst_key_cache.device)
-
-    @staticmethod
-    def copy_blocks(
-        kv_caches: List[torch.Tensor],
-        src_to_dists: torch.Tensor,
-    ) -> None:
-        src_indices = src_to_dists[:, 0]
-        dst_indices = src_to_dists[:, 1]
-
-        for kv_cache in kv_caches:
-            key_caches = kv_cache[0]
-            value_caches = kv_cache[1]
-            key_caches[dst_indices] = key_caches[src_indices]
-            value_caches[dst_indices] = value_caches[src_indices]
-
 
 @dataclass
 class AscendDecodeMetadata:
@@ -117,40 +84,15 @@ class AscendDecodeMetadata:
 
 
 @dataclass
-class AscendTorchairMetadata:
-    num_actual_tokens: int  # Number of tokens excluding padding.
-    # (batch_size, max_blocks_per_seq).
-    # Block addresses per sequence. (Seq id -> list of physical block)
-    block_tables: torch.Tensor
-    # (batch_size,). The sequence length per sequence. Sequence length means
-    # the computed tokens + new tokens None if it is a decoding.
-    query_start_loc: torch.Tensor
-    query_lens: torch.Tensor
-    seq_lens: torch.Tensor
-    # Maximum query length in the batch. None for decoding.
-    max_query_len: Optional[int] = None
-    # (num_tokens,). The indices of the token slots that input tokens will be
-    # stored into. E.g., if `slot_mapping` is [35, 2, 17] and the block size
-    # is 16, the three tokens are stored in the 3rd slot in block 2, 2nd slot
-    # in block 0, and 1st slot in block 1, respectively.
-    slot_mapping: torch.Tensor = None
-    # Current state of this attention run.
-    attn_state: AscendAttentionState = AscendAttentionState.ChunkedPrefill
-    attn_mask: Optional[torch.Tensor] = None
+class AscendTorchairMetadata(AscendMetadata):
 
     decode: Optional[AscendDecodeMetadata] = None
 
-    enable_dbo_across_dp: bool = False
 
-
-class AscendAttentionTorchairMetadataBuilder:
+class AscendAttentionTorchairMetadataBuilder(AscendAttentionMetadataBuilder):
 
     def __init__(self, runner):
-        self.runner = runner
-
-    def reorder_batch(self, input_batch: "InputBatch",
-                      scheduler_output: "SchedulerOutput") -> bool:
-        return False
+        super().__init__(runner)
 
     def _get_graph_runner_block_tables(
             self, num_seqs: int, block_tables: torch.Tensor) -> torch.Tensor:
@@ -222,11 +164,16 @@ def build(self,
               num_reqs,
               num_actual_tokens,
               max_query_len,
-              graph_pad_size: int = -1,
               enable_dbo_across_dp: bool = False,
+              is_only_prefill: bool = False,
               *args,
               **kwargs):
 
+        if 'graph_pad_size' in kwargs:
+            graph_pad_size = kwargs['graph_pad_size']
+        else:
+            graph_pad_size = -1  # default value
+
         device = self.runner.device
 
         block_table = self.runner.input_batch.block_table[0].get_device_tensor(
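As an aside, the `if 'graph_pad_size' in kwargs` branch in the last hunk could be collapsed to a single `dict.get` with the same default; a behavior-preserving alternative, not part of the diff:

```python
# Equivalent to the if/else above: -1 when the caller omits graph_pad_size.
graph_pad_size = kwargs.get("graph_pad_size", -1)
```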

vllm_ascend/worker/model_runner_v1.py

Lines changed: 1 addition & 1 deletion
@@ -78,14 +78,14 @@
 from vllm_ascend.attention.attention_mask import AttentionMaskBuilder
 from vllm_ascend.attention.attention_v1 import (AscendAttentionState,
                                                 AscendMetadata)
-from vllm_ascend.attention.attention_v1_torchair import AscendTorchairMetadata
 from vllm_ascend.attention.mla_v1 import AscendMLAMetadata
 from vllm_ascend.distributed.moe_comm_method import (AllGatherCommImpl,
                                                      DummyCommImpl,
                                                      MoECommMethod)
 from vllm_ascend.multistream.ms_split import compute_split_seq_index
 from vllm_ascend.platform import NPUPlatform
 from vllm_ascend.sample.rejection_sampler import AscendRejectionSampler
+from vllm_ascend.torchair.torchair_attention import AscendTorchairMetadata
 from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ,
                                ProfileExecuteDuration, is_310p,
                                maybe_converting_weight_acl_format)
