@@ -21,18 +21,19 @@
import numpy as np
import torch
import torch_npu
- from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
-                                               AttentionLayer, AttentionType)
- from vllm.attention.backends.utils import PAD_SLOT_ID, CommonAttentionState
- from vllm.v1.core.sched.output import SchedulerOutput
-
- from vllm_ascend.attention.attention_v1 import AscendAttentionState
+ from vllm.attention.backends.abstract import (AttentionImpl, AttentionLayer,
+                                               AttentionType)
+ from vllm.attention.backends.utils import PAD_SLOT_ID
+
+ from vllm_ascend.attention.attention_v1 import (AscendAttentionBackend,
+                                                 AscendAttentionMetadataBuilder,
+                                                 AscendAttentionState,
+                                                 AscendMetadata)
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, aligned_16, is_310p,
                               nd_to_nz_2d)
- from vllm_ascend.worker.npu_input_batch import InputBatch


- class AscendAttentionTorchairBackend(AttentionBackend):
+ class AscendAttentionTorchairBackend(AscendAttentionBackend):
    accept_output_buffer: bool = True

    @staticmethod
@@ -47,10 +48,6 @@ def get_impl_cls() -> Type["AscendAttentionTorchairBackendImpl"]:
    def get_metadata_cls() -> Type["AscendTorchairMetadata"]:
        return AscendTorchairMetadata

-     @staticmethod
-     def get_state_cls() -> Type["CommonAttentionState"]:
-         return CommonAttentionState
-
    @staticmethod
    def get_builder_cls() -> type["AscendAttentionTorchairMetadataBuilder"]:
        return AscendAttentionTorchairMetadataBuilder
@@ -73,36 +70,6 @@ def get_bsh_kv_cache_shape(
    ) -> Tuple[int, ...]:
        return (2, num_blocks, block_size, num_kv_heads * head_size)

-     @staticmethod
-     def swap_blocks(
-         src_kv_cache: List[torch.Tensor],
-         dst_kv_cache: List[torch.Tensor],
-         src_to_dst: torch.Tensor,
-     ) -> None:
-         src_key_cache, src_value_cache = src_kv_cache[0], src_kv_cache[1]
-         dst_key_cache, dst_value_cache = dst_kv_cache[0], dst_kv_cache[1]
-         src_indices = src_to_dst[:, 0]
-         dst_indices = src_to_dst[:, 1]
-
-         dst_key_cache[dst_indices] = src_key_cache[src_indices].to(
-             dst_key_cache.device)
-         dst_value_cache[dst_indices] = src_value_cache[src_indices].to(
-             dst_key_cache.device)
-
-     @staticmethod
-     def copy_blocks(
-         kv_caches: List[torch.Tensor],
-         src_to_dists: torch.Tensor,
-     ) -> None:
-         src_indices = src_to_dists[:, 0]
-         dst_indices = src_to_dists[:, 1]
-
-         for kv_cache in kv_caches:
-             key_caches = kv_cache[0]
-             value_caches = kv_cache[1]
-             key_caches[dst_indices] = key_caches[src_indices]
-             value_caches[dst_indices] = value_caches[src_indices]
-

@dataclass
class AscendDecodeMetadata:
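Note on the two helpers removed above: with AscendAttentionTorchairBackend now subclassing AscendAttentionBackend, swap_blocks and copy_blocks are presumably provided by that base class (an assumption; the base implementation is not shown in this diff). For reference, the block-level copy they performed is plain advanced indexing on the cache tensors; a self-contained sketch with illustrative shapes, not the real kernels:

    import torch

    # Illustrative cache shape: (num_blocks, block_size, num_kv_heads, head_size).
    key_cache = torch.randn(8, 16, 4, 8)
    value_cache = torch.randn(8, 16, 4, 8)
    # Each row is a (src_block, dst_block) pair, like src_to_dists above.
    block_mapping = torch.tensor([[0, 5], [2, 6]])

    src_indices = block_mapping[:, 0]
    dst_indices = block_mapping[:, 1]
    # One vectorized assignment per cache copies whole blocks, which is what
    # the removed copy_blocks() did for every (key, value) cache in the list.
    key_cache[dst_indices] = key_cache[src_indices]
    value_cache[dst_indices] = value_cache[src_indices]

The removed swap_blocks only adds a .to(device) on each copied slice so blocks can move between devices.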
@@ -117,40 +84,15 @@ class AscendDecodeMetadata:


@dataclass
- class AscendTorchairMetadata:
-     num_actual_tokens: int  # Number of tokens excluding padding.
-     # (batch_size, max_blocks_per_seq).
-     # Block addresses per sequence. (Seq id -> list of physical block)
-     block_tables: torch.Tensor
-     # (batch_size,). The sequence length per sequence. Sequence length means
-     # the computed tokens + new tokens None if it is a decoding.
-     query_start_loc: torch.Tensor
-     query_lens: torch.Tensor
-     seq_lens: torch.Tensor
-     # Maximum query length in the batch. None for decoding.
-     max_query_len: Optional[int] = None
-     # (num_tokens,). The indices of the token slots that input tokens will be
-     # stored into. E.g., if `slot_mapping` is [35, 2, 17] and the block size
-     # is 16, the three tokens are stored in the 3rd slot in block 2, 2nd slot
-     # in block 0, and 1st slot in block 1, respectively.
-     slot_mapping: torch.Tensor = None
-     # Current state of this attention run.
-     attn_state: AscendAttentionState = AscendAttentionState.ChunkedPrefill
-     attn_mask: Optional[torch.Tensor] = None
+ class AscendTorchairMetadata(AscendMetadata):

    decode: Optional[AscendDecodeMetadata] = None

-     enable_dbo_across_dp: bool = False

-
- class AscendAttentionTorchairMetadataBuilder:
+ class AscendAttentionTorchairMetadataBuilder(AscendAttentionMetadataBuilder):

    def __init__(self, runner):
-         self.runner = runner
-
-     def reorder_batch(self, input_batch: "InputBatch",
-                       scheduler_output: "SchedulerOutput") -> bool:
-         return False
+         super().__init__(runner)

    def _get_graph_runner_block_tables(
            self, num_seqs: int, block_tables: torch.Tensor) -> torch.Tensor:
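The long field list is dropped because AscendTorchairMetadata now inherits everything except decode from AscendMetadata. A minimal sketch of that dataclass-inheritance pattern, using stand-in classes rather than the real ones:

    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class BaseMetadata:                     # stands in for AscendMetadata
        num_actual_tokens: int = 0
        attn_mask: Optional[object] = None

    @dataclass
    class TorchairMetadata(BaseMetadata):   # stands in for AscendTorchairMetadata
        # Only the Torchair-specific field is declared; the rest is inherited.
        decode: Optional[object] = None

    m = TorchairMetadata(num_actual_tokens=3)
    print(m.num_actual_tokens, m.attn_mask, m.decode)   # 3 None None

Because decode carries a default, it can safely be appended after the inherited fields; a new required field would raise a TypeError at class-definition time once any inherited field has a default.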
@@ -222,11 +164,16 @@ def build(self,
              num_reqs,
              num_actual_tokens,
              max_query_len,
-               graph_pad_size: int = -1,
              enable_dbo_across_dp: bool = False,
+               is_only_prefill: bool = False,
              *args,
              **kwargs):

+         if 'graph_pad_size' in kwargs:
+             graph_pad_size = kwargs['graph_pad_size']
+         else:
+             graph_pad_size = -1  # default value
+
        device = self.runner.device

        block_table = self.runner.input_batch.block_table[0].get_device_tensor(
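With graph_pad_size removed from the explicit build() signature, it now travels through **kwargs and falls back to -1 when absent. A standalone sketch of that plumbing (the function below is illustrative, not the real builder method):

    def build_sketch(num_reqs, num_actual_tokens, max_query_len,
                     enable_dbo_across_dp=False, is_only_prefill=False,
                     *args, **kwargs):
        # Same fallback as the diff's if/else, written with dict.get().
        graph_pad_size = kwargs.get('graph_pad_size', -1)
        return graph_pad_size

    print(build_sketch(1, 1, 1))                      # -1: no torchair graph padding
    print(build_sketch(1, 1, 1, graph_pad_size=512))  # 512: padding requested by caller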