
Commit 2069bef

[v0.11.0-dev][bugfix] Fix a bug in wrongly set npu_stream (#4106)

### What this PR does / why we need it?
This PR fixes a bug introduced in #3985, which set the wrong npu_stream (possibly a mistake made during cherry-picking). It corrects the stream handling and makes `update_attn_params` consistent with the main branch.

### Does this PR introduce _any_ user-facing change?
No.

Signed-off-by: Angazenn <[email protected]>
1 parent c5fe179 commit 2069bef

1 file changed: +18 −20 lines


vllm_ascend/compilation/acl_graph.py

```diff
@@ -213,26 +213,24 @@ def update_attn_params(update_stream, forward_context, runtime_shape):
         ) = param
         seq_lens = forward_context.attn_metadata[key].seq_lens
 
-        # When using FULL_DECODE_ONLY, there are some rare bugs for FULL_DECODE_ONLY
-        # mode with GQA. This is triggered by getting workspace for _npu_paged_attention
-        # in torch_npu. On some rare cases, _npu_paged_attention with smaller seq_lens
-        # might encounter a bigger workspace, while currently we use max_model_len to
-        # calculate max workspace in capturing. So additional get_workspace is added
-        # here to avoid such bugs.
-        # TODO(Angazenn): we will remove this once _npu_paged_attention is fully
-        # replaced by npu_fused_infer_attention_score which does not contain such bugs.
-        workspace = torch_npu._npu_paged_attention_get_workspace(
-            query=query,
-            key_cache=key_cache,
-            value_cache=value_cache,
-            num_kv_heads=num_kv_heads,
-            num_heads=num_heads,
-            scale_value=scale,
-            block_table=block_table,
-            context_lens=seq_lens,
-            out=output)
-
-        with torch.npu.stream(update_stream):
+        # When using FULL_DECODE_ONLY, there are some rare bugs for FULL_DECODE_ONLY
+        # mode with GQA. This is triggered by getting workspace for _npu_paged_attention
+        # in torch_npu. On some rare cases, _npu_paged_attention with smaller seq_lens
+        # might encounter a bigger workspace, while currently we use max_model_len to
+        # calculate max workspace in capturing. So additional get_workspace is added
+        # here to avoid such bugs.
+        # TODO(Angazenn): we will remove this once _npu_paged_attention is fully
+        # replaced by npu_fused_infer_attention_score which does not contain such bugs.
+        workspace = torch_npu._npu_paged_attention_get_workspace(
+            query=query,
+            key_cache=key_cache,
+            value_cache=value_cache,
+            num_kv_heads=num_kv_heads,
+            num_heads=num_heads,
+            scale_value=scale,
+            block_table=block_table,
+            context_lens=seq_lens,
+            out=output)
         torch.npu.graph_task_update_begin(update_stream, handle)
         torch_npu._npu_paged_attention(query=query,
                                        key_cache=key_cache,
```
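
For orientation, below is a rough sketch of the per-layer update path this hunk lands in. It is reconstructed from the hunk above rather than copied from the file: the stream scoping around these calls is exactly what this PR adjusts and is not fully visible here, the helper name `update_one_layer` is hypothetical, and the `workspace=` argument and the `graph_task_update_end` call are assumed from context because the hunk is cut off before them.

```python
# Sketch only: reconstructed from the hunk above, not verbatim file content.
# Pieces not visible in the diff (how the per-layer parameters arrive, the
# workspace= argument, graph_task_update_end) are assumptions.
import torch
import torch_npu


def update_one_layer(update_stream, handle, query, key_cache, value_cache,
                     num_kv_heads, num_heads, scale, block_table, seq_lens,
                     output):
    # Re-query the workspace for the actual seq_lens: as the in-code comment
    # explains, a smaller seq_lens can occasionally need a larger workspace
    # than the one sized from max_model_len at capture time.
    workspace = torch_npu._npu_paged_attention_get_workspace(
        query=query,
        key_cache=key_cache,
        value_cache=value_cache,
        num_kv_heads=num_kv_heads,
        num_heads=num_heads,
        scale_value=scale,
        block_table=block_table,
        context_lens=seq_lens,
        out=output)

    # Update the captured attention task for this layer on the dedicated
    # update stream, re-issuing the kernel with the fresh arguments.
    torch.npu.graph_task_update_begin(update_stream, handle)
    torch_npu._npu_paged_attention(query=query,
                                   key_cache=key_cache,
                                   value_cache=value_cache,
                                   num_kv_heads=num_kv_heads,
                                   num_heads=num_heads,
                                   scale_value=scale,
                                   block_table=block_table,
                                   context_lens=seq_lens,
                                   workspace=workspace,  # assumed; cut off in the hunk
                                   out=output)
    torch.npu.graph_task_update_end(update_stream)  # assumed counterpart call
```

The extra `_npu_paged_attention_get_workspace` call exists for the reason stated in the in-code comment: a workspace sized from `max_model_len` at capture time can occasionally be too small for the `seq_lens` seen at replay, so the workspace is re-queried before each graph task update.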
