Commit 4641ee2

Refactor mla
1 parent 6d9e5f6 commit 4641ee2

5 files changed: +295 -182 lines changed

vllm_ascend/attention/attention_v1.py

Lines changed: 61 additions & 30 deletions
@@ -21,6 +21,8 @@
 
 import torch
 import torch_npu
+import torch.nn as nn
+from vllm.config import VllmConfig
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionLayer, AttentionType)
 from vllm.attention.backends.utils import CommonAttentionState
@@ -35,6 +37,7 @@
 from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig
 from vllm_ascend.ops.attention import vanilla_chunked_prefill
 from vllm_ascend.utils import get_graph_params
+from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
 
 
 class AscendAttentionBackend(AttentionBackend):
@@ -156,39 +159,48 @@ def split_metadata_for_multistream(
 
 class AscendAttentionMetadataBuilder:
 
-    def __init__(self, runner):
+    def __init__(
+        self,
+        vllm_config: VllmConfig,
+        device: torch.device,
+        runner
+    ):
+        self.vllm_config = vllm_config
+        self.model_config = vllm_config.model_config
+        self.device = device
         self.runner = runner
 
     def reorder_batch(self, input_batch: "InputBatch",
                       scheduler_output: "SchedulerOutput") -> bool:
         return False
 
-    def build(self,
-              num_reqs,
-              num_actual_tokens,
-              max_query_len,
-              common_attn_metadata: CommonAttentionMetadata,
-              enable_dbo_across_dp: bool = False,
-              is_only_prefill: bool = False,
-              *args,
-              **kwargs):
-
-        block_table = self.runner.input_batch.block_table[0].get_device_tensor(
-        )
-        block_table[:num_reqs, :self.runner.max_num_blocks_per_req] = (
-            block_table[:num_reqs])
-
-        query_start_loc = common_attn_metadata.query_start_loc
-        seq_lens = common_attn_metadata.seq_lens
+    def build(
+        self,
+        common_attn_metadata: AscendCommonAttentionMetadata,
+    ):
+        num_reqs = common_attn_metadata.num_reqs
+        num_actual_tokens = common_attn_metadata.num_actual_tokens
+        query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu[:
+                                                                       num_reqs
+                                                                       + 1]
+
+        block_table = common_attn_metadata.block_table_tensor
+        block_table[:num_reqs, :common_attn_metadata.
+                    max_num_blocks_per_req] = (block_table[:num_reqs])
+
+        seq_lens = common_attn_metadata.seq_lens_cpu[:num_reqs]
         # TODO: Refactor these two param to common metadata in runners,
         # preparing for the hybrid KV groups feature
-        query_lens = common_attn_metadata.query_lens or self.runner.query_lens
+        query_lens = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
         # Since FIA for GQA is not active now, we temporarily silence it
         seq_lens_list = common_attn_metadata.seq_lens_list
 
-        slot_mapping = self.runner.slot_mapping[:num_actual_tokens]
-        attn_mask = self.runner.attn_mask
-        attn_state = self.runner.attn_state
+        slot_mapping = common_attn_metadata.slot_mapping_cpu[:num_actual_tokens].to(
+            self.device, non_blocking=True)
+        attn_mask = common_attn_metadata.attn_mask
+        attn_state = common_attn_metadata.attn_state
+        query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu[:num_reqs + 1]
+        query_start_loc = query_start_loc_cpu.to(self.device, non_blocking=True)
 
         attn_metadata = AscendMetadata(
             num_actual_tokens=num_actual_tokens,
@@ -197,12 +209,12 @@ def build(self,
             query_lens=query_lens,
             seq_lens=seq_lens,
             seq_lens_list=seq_lens_list,
-            max_query_len=max_query_len,
+            max_query_len=common_attn_metadata.max_query_len,
             slot_mapping=slot_mapping,
             attn_mask=attn_mask,
             attn_state=attn_state,
-            enable_dbo_across_dp=enable_dbo_across_dp,
-            is_only_prefill=is_only_prefill)
+            enable_dbo_across_dp=common_attn_metadata.enable_dbo_across_dp,
+            is_only_prefill=common_attn_metadata.is_only_prefill)
         return attn_metadata
 
     def build_dummy_metadata(self, num_actual_tokens, num_reqs,
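
One detail worth calling out in the new build(): per-request query lengths are no longer read from the runner but are derived from the cumulative query_start_loc_cpu offsets. A minimal sketch of that arithmetic, with made-up tensor values purely for illustration:

import torch

# query_start_loc_cpu holds cumulative token offsets per request, so adjacent
# differences recover each request's query length (values here are illustrative).
query_start_loc_cpu = torch.tensor([0, 3, 7, 12])   # 3 requests
query_lens = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
print(query_lens)                                    # tensor([3, 4, 5])
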
@@ -217,14 +229,33 @@ def build_dummy_metadata(self, num_actual_tokens, num_reqs,
                 num_reqs + 1,
                 device=block_table.device,
                 dtype=block_table.dtype)
+            block_table = self.runner.input_batch.block_table[0].get_device_tensor(
+            )
+            block_table[:num_reqs, :self.runner.max_num_blocks_per_req] = (
+                block_table[:num_reqs])
+
+            query_start_loc = common_attn_metadata.query_start_loc
+            seq_lens = common_attn_metadata.seq_lens
+            query_lens = self.runner.query_lens
+            seq_lens_list = None
 
-            attn_metadata = self.build(
-                num_reqs=num_reqs,
+            slot_mapping = self.runner.slot_mapping[:num_actual_tokens]
+            attn_mask = self.runner.attn_mask
+            attn_state = self.runner.attn_state
+
+            attn_metadata = AscendMetadata(
                 num_actual_tokens=num_actual_tokens,
+                block_tables=block_table,
+                query_start_loc=query_start_loc,
+                query_lens=query_lens,
+                seq_lens=seq_lens,
+                seq_lens_list=seq_lens_list,
                 max_query_len=num_scheduled_tokens.max(),
-                common_prefix_len=0,
-                common_attn_metadata=common_attn_metadata,
-            )
+                slot_mapping=slot_mapping,
+                attn_mask=attn_mask,
+                attn_state=attn_state,
+                enable_dbo_across_dp=False,
+                is_only_prefill=False)
         else:
             raise NotImplementedError(
                 "Currently we only support building dummy metadata for DecodeOnly state"

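For orientation, here is a sketch of how the refactored entry point is driven after this commit. The dataclass below is only a stand-in that mirrors the fields this diff reads off AscendCommonAttentionMetadata (the real definition lives in vllm_ascend.attention.utils and is not shown here); the field names come from the diff, but the types, defaults, and any extra state are assumptions.

from dataclasses import dataclass
from typing import Any, Optional

import torch


@dataclass
class CommonAttentionMetadataSketch:
    # Stand-in for AscendCommonAttentionMetadata: only the fields that
    # AscendAttentionMetadataBuilder.build() reads in this diff.
    num_reqs: int
    num_actual_tokens: int
    max_query_len: int
    query_start_loc_cpu: torch.Tensor   # (num_reqs + 1,) cumulative token offsets
    seq_lens_cpu: torch.Tensor          # (num_reqs,)
    seq_lens_list: Optional[list]
    slot_mapping_cpu: torch.Tensor      # (num_actual_tokens,)
    block_table_tensor: torch.Tensor    # (num_reqs, max_num_blocks_per_req)
    max_num_blocks_per_req: int
    attn_mask: Optional[torch.Tensor]
    attn_state: Any
    enable_dbo_across_dp: bool = False
    is_only_prefill: bool = False


# With the refactor, the builder takes its config and device up front and the
# per-step state rides in on the metadata object (hypothetical call site):
#
#     builder = AscendAttentionMetadataBuilder(vllm_config, device, runner)
#     attn_metadata = builder.build(common_attn_metadata)

The visible effect at the call site is that build() no longer reaches into self.runner for slot_mapping, attn_mask, or attn_state; everything it needs arrives on the metadata object, and CPU-side tensors are copied to the builder's device inside build() with non_blocking=True.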