
Commit 55d0790

[2/N][Refactor] Refactor V1 attention for better extensibility (#1995)
### What this PR does / why we need it?

Refactor V1 attention for better extensibility (in preparation for the torchair attention refactor).

**Main changes:**

- Move each kind of forward pass into its own method, e.g. `_forward_prefill_no_cache()`, `_forward_prefill_cache_hit()`, `_forward_decode_only()`, `_forward_v1_style()`, so that `forward()` only dispatches on the attention state (see the sketch below).

### Does this PR introduce _any_ user-facing change?

No.

- vLLM version: v0.10.0
- vLLM main: vllm-project/vllm@14a5d90

Signed-off-by: shen-shanshan <[email protected]>
1 parent 8914d5a commit 55d0790
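Below is a minimal, self-contained sketch of the dispatch structure this refactor introduces. It is an illustration only: the state and method names mirror the diff further down, but the NPU kernel calls (`torch_npu._npu_flash_attention`, `torch_npu._npu_paged_attention`, etc.), KV-cache handling, and metadata plumbing are stubbed out, so this is not the actual `AscendAttentionBackendImpl`.

```python
from enum import Enum
from typing import Optional

import torch


class AscendAttentionState(Enum):
    """Scheduler situations distinguished by the backend (as in the diff)."""
    PrefillNoCache = 0
    PrefillCacheHit = 1
    DecodeOnly = 2
    ChunkedPrefill = 3


class AttentionImplSketch:
    """Toy stand-in: forward() only routes to one helper per state."""

    def _forward_prefill_no_cache(self, query, output):
        # Real code calls torch_npu._npu_flash_attention and trims the
        # result to the actual number of tokens.
        output.copy_(query)
        return output

    def _forward_prefill_cache_hit(self, query, output):
        # Real code calls torch_npu._npu_flash_attention_qlens on the KV cache.
        output.copy_(query)
        return output

    def _forward_decode_only(self, query, output):
        # Real code calls torch_npu._npu_paged_attention.
        output.copy_(query)
        return output

    def _forward_v1_style(self, query, output):
        # Real code calls torch_npu._npu_paged_attention_splitfuse, or
        # vanilla_chunked_prefill when head_size == 192.
        output.copy_(query)
        return output

    def forward(self, query: torch.Tensor,
                attn_state: AscendAttentionState,
                output: Optional[torch.Tensor] = None) -> torch.Tensor:
        output = torch.empty_like(query) if output is None else output
        # One scheduler situation maps to exactly one helper, so a future
        # backend (e.g. the planned torchair attention) can override a
        # single branch instead of one monolithic forward().
        if attn_state == AscendAttentionState.PrefillNoCache:
            return self._forward_prefill_no_cache(query, output)
        elif attn_state == AscendAttentionState.PrefillCacheHit:
            return self._forward_prefill_cache_hit(query, output)
        elif attn_state == AscendAttentionState.DecodeOnly:
            return self._forward_decode_only(query, output)
        return self._forward_v1_style(query, output)


# Usage: route a dummy "decode" batch through the sketch.
impl = AttentionImplSketch()
out = impl.forward(torch.randn(4, 8), AscendAttentionState.DecodeOnly)
```

The helper signatures in the real diff additionally take `key`, `value`, and `attn_metadata`; this sketch keeps only the routing logic, which is the part the refactor changes.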

File tree

1 file changed: +150 additions, -102 deletions


vllm_ascend/attention/attention_v1.py

Lines changed: 150 additions & 102 deletions
@@ -120,7 +120,7 @@ class AscendAttentionState(Enum):
 @dataclass
 class AscendMetadata:
 
-    # **************************** Basic Properties ****************************
+    # **************************** Basic Properties ************************** #
     attn_mask: Optional[torch.Tensor] = None
     # Current state of this attention run.
     attn_state: AscendAttentionState = AscendAttentionState.ChunkedPrefill
@@ -138,7 +138,7 @@ class AscendMetadata:
     # Maximum query length in the batch (None for decoding).
     max_query_len: Optional[int] = None
 
-    # ********************** KV Cache Related Properties ***********************
+    # ********************** KV Cache Related Properties ********************* #
     # Block addresses per sequence (Seq id -> list of physical block).
     # (batch_size, max_blocks_per_seq)
     block_tables: torch.Tensor = None
@@ -150,6 +150,7 @@ class AscendMetadata:
     # (num_tokens,)
     slot_mapping: torch.Tensor = None
 
+    # *************************** Other Properties *************************** #
     enable_dbo_across_dp: bool = False
     is_only_prefill: bool = False
 
@@ -245,6 +246,144 @@ def __init__(
         self.key_cache = None
         self.value_cache = None
 
+    def _forward_prefill_no_cache(
+        self,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        attn_metadata: AscendMetadata,
+        output: Optional[torch.Tensor] = None,
+        num_tokens=0,
+    ) -> torch.Tensor:
+        assert attn_metadata is not None
+        assert attn_metadata.attn_mask is not None
+
+        mask = attn_metadata.attn_mask
+
+        if is_310p():
+            # align q k v output tensors
+            query = aligned_16(query)
+            key = aligned_16(key)
+            value = aligned_16(value)
+            output = aligned_16(output)
+            # do reformat in case of broadcasted tensors
+            mask = mask.repeat(attn_metadata.seq_lens.size(0), 1, 1, 1)
+            mask = torch_npu.npu_format_cast(mask.contiguous(),
+                                             ACL_FORMAT_FRACTAL_NZ)
+
+        torch_npu._npu_flash_attention(query=query,
+                                       key=key,
+                                       value=value,
+                                       mask=mask,
+                                       seq_len=attn_metadata.seq_lens,
+                                       scale_value=self.scale,
+                                       num_heads=self.num_heads,
+                                       num_kv_heads=self.num_kv_heads,
+                                       out=output)
+        assert output is not None
+        return output[:num_tokens, :, :]
+
+    def _forward_prefill_cache_hit(
+        self,
+        query: torch.Tensor,
+        attn_metadata: AscendMetadata,
+        output: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        assert attn_metadata is not None
+        assert attn_metadata.attn_mask is not None
+
+        compress_mask = attn_metadata.attn_mask
+        batch_size = attn_metadata.query_lens.shape[0]
+        block_table = attn_metadata.block_tables[:batch_size, :]
+
+        torch_npu._npu_flash_attention_qlens(
+            query=query,
+            key_cache=self.key_cache,
+            value_cache=self.value_cache,
+            block_table=block_table,
+            mask=compress_mask,
+            seq_len=attn_metadata.query_lens,
+            context_lens=attn_metadata.seq_lens,
+            num_kv_heads=self.num_kv_heads,
+            num_heads=self.num_heads,
+            scale_value=self.scale,
+            out=output)
+        return output
+
+    def _forward_decode_only(
+        self,
+        query: torch.Tensor,
+        attn_metadata: AscendMetadata,
+        output: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        if is_310p():
+            # seq_lens_tensor needs to be transferred to the device for 310P.
+            attn_metadata.seq_lens = \
+                attn_metadata.seq_lens.to(device=query.device)
+
+        torch_npu._npu_paged_attention(query=query,
+                                       key_cache=self.key_cache,
+                                       value_cache=self.value_cache,
+                                       num_kv_heads=self.num_kv_heads,
+                                       num_heads=self.num_heads,
+                                       scale_value=self.scale,
+                                       block_table=attn_metadata.block_tables,
+                                       context_lens=attn_metadata.seq_lens,
+                                       out=output)
+        return output
+
+    def _forward_v1_style(
+        self,
+        query: torch.Tensor,
+        attn_metadata: AscendMetadata,
+        output: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        # Use chunked prefill for head size 192 scenario, like deepseek
+        # paged_attention_splitfuse maybe crash at such scenario.
+        # TODO: vanilla path will be removed after the kernel support
+        # head_size 192 scenario.
+        if self.head_size == 192:
+            cu_seqlen_q = [0] + attn_metadata.query_lens.tolist()
+            cu_seqlen_k = [0] + attn_metadata.seq_lens.tolist()
+            cu_seqlen_q = torch.tensor(cu_seqlen_q, device=query.device)
+            cu_seqlen_k = torch.tensor(cu_seqlen_k, device=query.device)
+            cu_seqlen_q = torch.cumsum(cu_seqlen_q, dim=0)
+            cu_seqlen_k = torch.cumsum(cu_seqlen_k, dim=0)
+            max_seqlen_q = torch.max(attn_metadata.query_lens)
+            max_seqlen_k = torch.max(attn_metadata.seq_lens)
+            vanilla_chunked_prefill(output, query, self.key_cache,
+                                    self.value_cache,
+                                    attn_metadata.block_tables, cu_seqlen_q,
+                                    cu_seqlen_k, max_seqlen_q, max_seqlen_k,
+                                    self.scale, None, True)
+            return output
+
+        # Use paged attention.
+        assert attn_metadata is not None
+        assert attn_metadata.attn_mask is not None
+
+        if is_310p():
+            # Do reformat in case of broadcasted tensors.
+            attn_metadata.attn_mask = \
+                torch_npu.npu_format_cast(attn_metadata.attn_mask.contiguous(),
+                                          ACL_FORMAT_FRACTAL_NZ)
+            attn_metadata.seq_lens = \
+                attn_metadata.seq_lens.to(device=query.device)
+
+        torch_npu._npu_paged_attention_splitfuse(
+            query=query,
+            key_cache=self.key_cache,
+            value_cache=self.value_cache,
+            mask=attn_metadata.attn_mask,
+            block_table=attn_metadata.block_tables,
+            seq_len=attn_metadata.query_lens,
+            context_lens=attn_metadata.seq_lens,
+            num_kv_heads=self.num_kv_heads,
+            num_heads=self.num_heads,
+            scale_value=self.scale,
+            out=output)
+        return output
+
     def forward(
         self,
         layer: AttentionLayer,
@@ -325,109 +464,18 @@ def forward(
 
             # V0-Style scheduler situation.
             if attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
-                assert attn_metadata is not None
-                assert attn_metadata.attn_mask is not None
-                mask = attn_metadata.attn_mask
-                if is_310p():
-                    # align q k v output tensors
-                    query = aligned_16(query)
-                    key = aligned_16(key)
-                    value = aligned_16(value)
-                    output = aligned_16(output)
-
-                    # do reformat in case of broadcasted tensors
-                    mask = mask.repeat(attn_metadata.seq_lens.size(0), 1, 1, 1)
-                    mask = torch_npu.npu_format_cast(mask.contiguous(),
-                                                     ACL_FORMAT_FRACTAL_NZ)
-
-                torch_npu._npu_flash_attention(query=query,
-                                               key=key,
-                                               value=value,
-                                               mask=mask,
-                                               seq_len=attn_metadata.seq_lens,
-                                               scale_value=self.scale,
-                                               num_heads=self.num_heads,
-                                               num_kv_heads=self.num_kv_heads,
-                                               out=output)
-                output = output[:num_tokens, :, :]
-            elif attn_metadata.attn_state == AscendAttentionState.PrefillCacheHit:
-                assert attn_metadata is not None
-                assert attn_metadata.attn_mask is not None
-                compress_mask = attn_metadata.attn_mask
-                batch_size = attn_metadata.query_lens.shape[0]
-                block_table = attn_metadata.block_tables[:batch_size, :]
-                torch_npu._npu_flash_attention_qlens(
-                    query=query,
-                    key_cache=self.key_cache,
-                    value_cache=self.value_cache,
-                    block_table=block_table,
-                    mask=compress_mask,
-                    seq_len=attn_metadata.query_lens,
-                    context_lens=attn_metadata.seq_lens,
-                    num_kv_heads=self.num_kv_heads,
-                    num_heads=self.num_heads,
-                    scale_value=self.scale,
-                    out=output)
+                output = self._forward_prefill_no_cache(
+                    query, key, value, attn_metadata, output, num_tokens)
+            elif attn_metadata.attn_state == \
+                    AscendAttentionState.PrefillCacheHit:
+                output = self._forward_prefill_cache_hit(
+                    query, attn_metadata, output)
             elif attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
-                if is_310p():
-                    # # seq_lens_tensor needs to be transferred to the device for 310P
-                    attn_metadata.seq_lens = \
-                        attn_metadata.seq_lens.to(device=query.device)
-                torch_npu._npu_paged_attention(
-                    query=query,
-                    key_cache=self.key_cache,
-                    value_cache=self.value_cache,
-                    num_kv_heads=self.num_kv_heads,
-                    num_heads=self.num_heads,
-                    scale_value=self.scale,
-                    block_table=attn_metadata.block_tables,
-                    context_lens=attn_metadata.seq_lens,
-                    out=output)
+                output = self._forward_decode_only(query, attn_metadata,
+                                                   output)
             # Normal V1 situation.
             else:
-                # use chunked prefill for head size 192 scenario, like deepseek
-                # paged_attention_splitfuse maybe crash at such scenario
-                # TODO: vanilla path will be removed after the kernel support
-                # head_size 192 scenario
-                if self.head_size == 192:
-                    cu_seqlen_q = [0] + attn_metadata.query_lens.tolist()
-                    cu_seqlen_k = [0] + attn_metadata.seq_lens.tolist()
-                    cu_seqlen_q = torch.tensor(cu_seqlen_q,
-                                               device=query.device)
-                    cu_seqlen_k = torch.tensor(cu_seqlen_k,
-                                               device=query.device)
-                    cu_seqlen_q = torch.cumsum(cu_seqlen_q, dim=0)
-                    cu_seqlen_k = torch.cumsum(cu_seqlen_k, dim=0)
-                    max_seqlen_q = torch.max(attn_metadata.query_lens)
-                    max_seqlen_k = torch.max(attn_metadata.seq_lens)
-                    vanilla_chunked_prefill(output, query, self.key_cache,
-                                            self.value_cache,
-                                            attn_metadata.block_tables,
-                                            cu_seqlen_q, cu_seqlen_k,
-                                            max_seqlen_q, max_seqlen_k,
-                                            self.scale, None, True)
-                else:
-                    # use paged attention
-                    assert attn_metadata is not None
-                    assert attn_metadata.attn_mask is not None
-                    if is_310p():
-                        # do reformat in case of broadcasted tensors
-                        attn_metadata.attn_mask = \
-                            torch_npu.npu_format_cast(attn_metadata.attn_mask.contiguous(), ACL_FORMAT_FRACTAL_NZ)
-                        attn_metadata.seq_lens = \
-                            attn_metadata.seq_lens.to(device=query.device)
-                    torch_npu._npu_paged_attention_splitfuse(
-                        query=query,
-                        key_cache=self.key_cache,
-                        value_cache=self.value_cache,
-                        mask=attn_metadata.attn_mask,
-                        block_table=attn_metadata.block_tables,
-                        seq_len=attn_metadata.query_lens,
-                        context_lens=attn_metadata.seq_lens,
-                        num_kv_heads=self.num_kv_heads,
-                        num_heads=self.num_heads,
-                        scale_value=self.scale,
-                        out=output)
+                output = self._forward_v1_style(query, attn_metadata, output)
 
             # to make in-place change to the output tensor
             if hasattr(layer, 'quant_method') and use_kv_cache_int8:
