
Commit 9dc23b6

[V0.9.1][BugFix] Fix bugs and refactor cached mask generation logic (#2326)
### What this PR does / why we need it?
This PR fixes bugs and refactors the cached mask generation logic. The cached mask is now pre-constructed and used on the CPU instead of on the NPU device.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
CI passed with new and existing tests.

Signed-off-by: rjg-lyh <[email protected]>
1 parent 6d9e5f6 commit 9dc23b6
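
The core idea, as described above, is that the full-size attention mask is built once on the CPU and only the slice needed for the current step is moved to the NPU. Below is a minimal, self-contained sketch of that flow; build_causal_mask is a hypothetical stand-in for vllm_ascend's generate_attn_mask helper (not shown in this diff), and plain "cpu" stands in for the NPU device.

import torch

# Hypothetical stand-in for generate_attn_mask (not part of this diff):
# a 0/1 upper-triangular mask where 1 marks positions to be masked out.
def build_causal_mask(max_seq_len: int, dtype: torch.dtype) -> torch.Tensor:
    return torch.triu(torch.ones(max_seq_len, max_seq_len, dtype=dtype),
                      diagonal=1)

# Pre-construct the full-size mask once, on the CPU, at init time.
cached_mask = build_causal_mask(max_seq_len=4096, dtype=torch.float16)

# Per step, slice the cached CPU mask and transfer only the slice to the
# device ("cpu" here stands in for the NPU).
device = torch.device("cpu")
max_seq_len = 128
step_mask = cached_mask[:max_seq_len, :max_seq_len].contiguous().to(device)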

File tree

2 files changed: +38 additions, −68 deletions

vllm_ascend/attention/attention.py

Lines changed: 26 additions & 45 deletions
@@ -65,7 +65,6 @@ class AttentionMaskBuilder:
     def __init__(self, attn_mask: torch.Tensor):
         self._seq_len_cached = attn_mask.shape[0]
         self.attn_mask_cache = attn_mask
-        self.splitfuse_mask_value = -10000
 
     @classmethod
     def initialize_from_len(cls,
@@ -74,18 +73,25 @@ def initialize_from_len(cls,
                             mask_value: Optional[int] = None):
         return cls(generate_attn_mask(max_seq_len, dtype, mask_value))
 
-    def update_attn_cache(self, seqlen: int, dtype: torch.dtype,
-                          device: torch.device):
-        if seqlen > self._seq_len_cached or self.attn_mask_cache.dtype != dtype:
+    @staticmethod
+    def get_mask_scale_factor(dtype: torch.dtype = torch.float16):
+        mask_scale_factor = 1
+        if dtype == torch.bfloat16:
+            mask_scale_factor = -10000
+        return mask_scale_factor
+
+    def update_attn_cache(self, seqlen: int, dtype: torch.dtype):
+        if seqlen > self._seq_len_cached:
             self._seq_len_cached = seqlen
             self.attn_mask_cache = generate_attn_mask(seqlen, dtype)
-        if self.attn_mask_cache.device != device:
-            self.attn_mask_cache = self.attn_mask_cache.to(device)
+        if self.attn_mask_cache.dtype != dtype:
+            self.attn_mask_cache = self.attn_mask_cache.to(dtype)
 
     def get_attn_mask(self, max_seq_len: int, dtype: torch.dtype,
                       device: torch.device):
-        self.update_attn_cache(max_seq_len, dtype, device)
-        return self.attn_mask_cache[:max_seq_len, :max_seq_len].contiguous()
+        self.update_attn_cache(max_seq_len, dtype)
+        return self.attn_mask_cache[:max_seq_len, :max_seq_len].contiguous(
+        ).to(device)
 
     def get_decode_attn_mask(
         self,
@@ -94,53 +100,28 @@ def get_decode_attn_mask(
         dtype: torch.dtype,
         device: torch.device,
     ):
-        self.update_attn_cache(max_s, dtype, device)
+        self.update_attn_cache(max_s, dtype)
         return (self.attn_mask_cache.index_select(
-            0, input_lengths)[:, :max_s].view(-1, 1, max_s).contiguous())
+            0, input_lengths)[:, :max_s].view(-1, 1,
+                                              max_s).contiguous().to(device))
 
     def get_splitfuse_attn_mask(
         self,
         seq_lens,
-        query_lens,
         position,
         dtype,
         device,
     ) -> torch.Tensor:
         max_seq_len = max(seq_lens, default=0)
-        if max_seq_len <= self._seq_len_cached:
-            self.update_attn_cache(max_seq_len, dtype, device)
-            # FIXME: Currently the mask value of chunked-prefill situation and Prefill-Only situation
-            # is not the same. Fix this in the future when kernel is ready.
-            if self.attn_mask_cache.numel(
-            ) > 1 and self.attn_mask_cache[0][1] > 0:
-                attn_mask = self.get_attn_mask(  # type: ignore
-                    max_seq_len, dtype, device)
-                attn_mask *= -10000
-            else:
-                attn_mask = self.attn_mask_cache
-            return torch.index_select(attn_mask, dim=0,
-                                      index=position)[:, :max_seq_len]
-        total_q_len = sum(query_lens)
-        attn_mask = torch.zeros((total_q_len, max_seq_len),
-                                dtype=dtype,
-                                device="cpu")
-
-        current_row = 0
-        for i in range(len(query_lens)):
-            seq_len = seq_lens[i]
-            q_len = query_lens[i]
-            context_len = seq_len - q_len
-
-            assert context_len >= 0
-            attn_mask[current_row:current_row + q_len,
-                      context_len:] = self.splitfuse_mask_value
-            right_tensor = attn_mask[current_row:current_row + q_len,
-                                     context_len:seq_len]
-            right_tensor.masked_fill_(
-                right_tensor.tril() == self.splitfuse_mask_value, 0)
-            current_row += q_len
-
-        return attn_mask.to(device, non_blocking=True)
+        self.update_attn_cache(max_seq_len, dtype)
+        # FIXME: Currently the mask value of chunked-prefill situation and Prefill-Only situation
+        # is not the same. Fix this in the future when kernel is ready.
+        mask_scale_factor = AttentionMaskBuilder.get_mask_scale_factor(dtype)
+        attn_mask = torch.index_select(self.attn_mask_cache,
                                        dim=0,
                                        index=position)[:, :max_seq_len]
+        attn_mask *= mask_scale_factor
+        return attn_mask.contiguous().to(device, non_blocking=True)
 
 
 class AscendAttentionBackend(AttentionBackend):
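
To illustrate the refactored chunked-prefill path above: the mask is now always taken from the CPU cache, selected by token position, scaled per dtype, and only then moved to the device. A rough, self-contained sketch follows; the toy 0/1 mask and sizes are assumptions, since the real cache comes from generate_attn_mask, which this diff does not touch.

import torch

# Toy cached mask on the CPU: 1 marks masked positions (an assumption about
# what generate_attn_mask produces).
max_model_len = 16
cached_mask = torch.triu(
    torch.ones(max_model_len, max_model_len, dtype=torch.bfloat16), diagonal=1)

def get_mask_scale_factor(dtype: torch.dtype) -> int:
    # Mirrors the new static method: -10000 for bfloat16, otherwise 1.
    return -10000 if dtype == torch.bfloat16 else 1

position = torch.tensor([0, 1, 2, 5, 6])  # toy per-token positions
max_seq_len = 7
device = torch.device("cpu")  # stand-in for the NPU device

# Select the rows for the current tokens, scale, then move to the device.
attn_mask = torch.index_select(cached_mask, dim=0,
                               index=position)[:, :max_seq_len]
attn_mask = attn_mask * get_mask_scale_factor(torch.bfloat16)
attn_mask = attn_mask.contiguous().to(device, non_blocking=True)
print(attn_mask.shape)  # torch.Size([5, 7])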

vllm_ascend/worker/model_runner_v1.py

Lines changed: 12 additions & 23 deletions
@@ -20,7 +20,6 @@
 import copy
 import gc
 import math
-import os
 import time
 import types
 import weakref
@@ -349,19 +348,10 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
             reversed(
                 self.vllm_config.compilation_config.cudagraph_capture_sizes))
 
-        # NOTE: Pre-construct a mask matrix to improve the efficiency of
+        # NOTE: Pre-construct a mask matrix on cpu to improve the efficiency of
         # attention mask construction during inference.
-        # Note that the length of the matrix needs to be carefully balanced: a
-        # matrix that is too large will consume excessive VRAM, while a matrix
-        # that is too small will require dynamic concatenation during inference,
-        # leading to performance degradation.
-        # Therefore, an environment variable is added here to dynamically set
-        # the size of the pre-constructed mask matrix based on requirements.
-        mask_len = os.getenv("PAGED_ATTENTION_MASK_LEN", 10000)
-        self.attn_mask_len = min(self.model_config.max_model_len,
-                                 int(mask_len))
         self.attn_mask_builder = AttentionMaskBuilder.initialize_from_len(
-            self.attn_mask_len, self.dtype)
+            self.model_config.max_model_len, self.dtype)
 
         self.sampler = Sampler()
         self.new_kv_cache_bytes = -1
@@ -703,12 +693,12 @@ def _check_dbo_is_valid(self, query_lens: torch.Tensor,
     def get_model(self) -> nn.Module:
         return self.model
 
-    def _make_attention_mask(self, seq_lens, query_lens, position,
+    def _make_attention_mask(self, seq_lens, position,
                              attn_state) -> torch.Tensor:
         # Chunk Prefill situation.
         if attn_state == AscendAttentionState.ChunkedPrefill:
             return self.attn_mask_builder.get_splitfuse_attn_mask(
-                seq_lens, query_lens, position, self.dtype, self.device)
+                seq_lens, position, self.dtype, self.device)
         # Prefill without cache situation.
         elif attn_state == AscendAttentionState.PrefillNoCache:
             max_seq_len = max(seq_lens, default=0)
@@ -956,16 +946,17 @@ def _process_reqs(
             self.mrope_positions_cpu[:, :total_num_scheduled_tokens],
             non_blocking=True)
 
-        self.positions[total_num_scheduled_tokens:num_input_tokens].zero_()
-        self.positions[:total_num_scheduled_tokens].copy_(
-            self.positions_cpu[:total_num_scheduled_tokens], non_blocking=True)
+        self.positions_cpu[total_num_scheduled_tokens:num_input_tokens].zero_()
+        self.positions[:num_input_tokens].copy_(
+            self.positions_cpu[:num_input_tokens], non_blocking=True)
+        positions_cpu = self.positions_cpu[:num_input_tokens]
         positions = self.positions[:num_input_tokens]
         self.query_lens = torch.from_numpy(num_scheduled_tokens)
 
         self.seq_lens_np[:num_reqs] = (
            self.input_batch.num_computed_tokens_cpu[:num_reqs] +
            num_scheduled_tokens)
-        seq_lens = self.seq_lens_cpu[:num_reqs]
+        seq_lens_cpu = self.seq_lens_cpu[:num_reqs]
 
         block_table_indices = (req_indices * self.max_num_blocks_per_req +
                                positions_np // self.block_size)
@@ -999,11 +990,9 @@ def _process_reqs(
 
         # NOTE: when use ring_mla, attn_mask don't need to generate here.
         if not self.vllm_config.model_config.use_mla:
-            attn_mask = self._make_attention_mask(
-                seq_lens=seq_lens,
-                query_lens=num_scheduled_tokens,
-                position=positions,
-                attn_state=attn_state)
+            attn_mask = self._make_attention_mask(seq_lens=seq_lens_cpu,
+                                                  position=positions_cpu,
+                                                  attn_state=attn_state)
             self.attn_mask = attn_mask
             self.attn_state = attn_state  # type: ignore
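The runner-side changes above also shift the positions bookkeeping to the CPU buffer: the padding tail is zeroed on the CPU copy, the padded slice is transferred to the device in one non-blocking copy, and CPU views (positions_cpu, seq_lens_cpu) are kept so the attention mask can be built from CPU tensors. A rough sketch with illustrative names follows; positions_cpu_buf and positions_dev_buf are stand-ins for the runner's self.positions_cpu / self.positions, and "cpu" again stands in for the NPU device.

import torch

max_num_tokens = 8
total_num_scheduled_tokens = 5   # real tokens this step
num_input_tokens = 6             # padded length actually fed to the model

positions_cpu_buf = torch.arange(max_num_tokens, dtype=torch.int64)
positions_dev_buf = torch.empty(max_num_tokens, dtype=torch.int64, device="cpu")

# Zero the padding region on the CPU buffer, then copy the whole padded slice
# to the device in one non-blocking transfer.
positions_cpu_buf[total_num_scheduled_tokens:num_input_tokens].zero_()
positions_dev_buf[:num_input_tokens].copy_(
    positions_cpu_buf[:num_input_tokens], non_blocking=True)

# Keep a CPU view around so the attention mask can be built from CPU tensors.
positions_cpu = positions_cpu_buf[:num_input_tokens]
positions = positions_dev_buf[:num_input_tokens]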
