Commit 67aad1f

[BugFix][P/D] fix padding error on FullGraph mode && fix layerwise connector mamba accuracy (#7506)
### What this PR does / why we need it?

1. In FullGraph mode, the branches of the Triton operator are compiled and frozen during graph capture. This invalidates the guard in the `fused_recurrent_gated_delta_rule` operator that checks `ssm_state_indices >= 0` before writing to the SSM cache: the write now happens unconditionally. After the vLLM GDN backend pads the indices with -1, the operator computes an address offset from that -1 and writes to the SSM cache at the resulting location. Because the conv cache and the SSM cache in the vLLM Ascend implementation are two halves of a single contiguous tensor, the stray write overwrites cache data and produces NaN values. This PR covers the two places in the GDN metadata builder that pad with -1 and replaces that padding with 0; since block 0 is a reserved block, writing to it is harmless.
2. Fix a layerwise-connector bug in mamba cache sending under heterogeneous TP.

- vLLM version: v0.17.0
- vLLM main: vllm-project/vllm@8b63257

Signed-off-by: nwpu-zxr <zhouxuerong2@huawei.com>
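The corruption mode described in point 1 can be illustrated with a toy, numpy-only sketch. All sizes and names here are hypothetical; the real caches are NPU tensors written by a Triton kernel:

```python
import numpy as np

# One contiguous buffer split into a conv-cache region followed by an
# SSM-cache region (hypothetical sizes; names are illustrative only).
NUM_BLOCKS, BLOCK_SIZE = 4, 3
buf = np.zeros(2 * NUM_BLOCKS * BLOCK_SIZE)
ssm_base = NUM_BLOCKS * BLOCK_SIZE  # SSM region starts after the conv region

def write_ssm_block(block_idx: int, value: float) -> None:
    # Pointer-arithmetic style write, like a kernel whose `block_idx >= 0`
    # guard was compiled away during full-graph capture.
    start = ssm_base + block_idx * BLOCK_SIZE
    buf[start:start + BLOCK_SIZE] = value

write_ssm_block(-1, 99.0)              # padded index -1 ...
conv_corrupted = bool(buf[:ssm_base].any())  # ... lands in the conv-cache region

buf[:] = 0.0
write_ssm_block(0, 99.0)               # padding with 0 stays in reserved block 0
conv_intact = not buf[:ssm_base].any()
```

With index -1 the write lands one block *before* the SSM region, i.e. in the last conv-cache block; with index 0 it stays inside the reserved SSM block 0.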
Parent: 475b4b0 · Commit: 67aad1f

File tree: 2 files changed, +11 −4 lines


vllm_ascend/distributed/kv_transfer/kv_p2p/mooncake_layerwise_connector.py

Lines changed: 5 additions & 4 deletions
```diff
@@ -317,7 +317,8 @@ def get_transfer_meta(self, send_task: SendTask, req_id: str, req_meta: ReqMeta,
         for local_conv_offset, local_conv_size in zip(local_conv_offsets, local_conv_sizes):
             local_addr_offset = (i * conv_shape[1] + local_conv_offset) * get_dtype_size(conv_dtype)
             remote_addr_offset = (
-                (i * conv_shape[1] * tp_ratio) + (self.tp_rank % tp_ratio) * local_conv_size
+                (i * conv_shape[1] + local_conv_offset) * tp_ratio
+                + (self.tp_rank % tp_ratio) * local_conv_size
             ) * get_dtype_size(conv_dtype)
             src_list.append(local_conv_addr + local_block_ids[0] * local_conv_len + local_addr_offset)
             dst_list.append(remote_conv_addr + remote_block_ids[0] * remote_conv_len + remote_addr_offset)
```
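The corrected remote-offset arithmetic above can be checked with a toy sketch (element offsets rather than bytes, `get_dtype_size` and the surrounding transfer logic omitted, all values hypothetical):

```python
def old_remote_offset(i, conv_dim, tp_ratio, tp_rank, local_conv_offset, local_conv_size):
    # Pre-fix formula: local_conv_offset never enters the computation.
    return (i * conv_dim * tp_ratio) + (tp_rank % tp_ratio) * local_conv_size

def new_remote_offset(i, conv_dim, tp_ratio, tp_rank, local_conv_offset, local_conv_size):
    # Post-fix formula: each local slice maps to its own remote slice.
    return (i * conv_dim + local_conv_offset) * tp_ratio + (tp_rank % tp_ratio) * local_conv_size

# Two different local slices of one row (i=0) on rank 0 with tp_ratio=2:
args = dict(i=0, conv_dim=16, tp_ratio=2, tp_rank=0, local_conv_size=8)
old = [old_remote_offset(local_conv_offset=o, **args) for o in (0, 8)]
new = [new_remote_offset(local_conv_offset=o, **args) for o in (0, 8)]
# old -> both slices target the same remote offset (one overwrites the other)
# new -> the slices map to distinct remote offsets
```

Because the old formula dropped `local_conv_offset`, distinct local conv slices collided at one remote address on heterogeneous TP; the fix makes the mapping injective.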
```diff
@@ -1508,9 +1509,9 @@ def save_kv_layer(
         # get reshape and cache event
         if layer_name == "":
             layer_name = self.index_to_name[self.current_layer][0]
-        if (
-            type(attn_metadata) is dict and not getattr(attn_metadata[layer_name], "reshape_cache_event", None)
-        ) or (not getattr(attn_metadata, "reshape_cache_event", None)):
+        if (self.use_mla and not hasattr(attn_metadata[layer_name], "reshape_cache_event")) or (
+            not self.use_mla and not hasattr(attn_metadata, "reshape_cache_event")
+        ):
             reshape_cache_event = torch.npu.Event()
             reshape_cache_event.record()
         elif self.use_mla:
```
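Why the old condition misfired can be reproduced in plain Python. Here a `SimpleNamespace` stands in for per-layer attention metadata, and the assumption that `use_mla` corresponds to dict-shaped metadata mirrors the fixed code:

```python
from types import SimpleNamespace

# Per-layer metadata that already carries a recorded event (stand-in object).
layer_md = SimpleNamespace(reshape_cache_event="already-recorded")
attn_metadata = {"layer0": layer_md}
layer_name = "layer0"
use_mla = True  # in this sketch, MLA corresponds to the dict-shaped metadata

# Old condition: the second clause probes the dict object itself, which never
# has a `reshape_cache_event` attribute, so it is always True for dict metadata.
old = (
    type(attn_metadata) is dict
    and not getattr(attn_metadata[layer_name], "reshape_cache_event", None)
) or (not getattr(attn_metadata, "reshape_cache_event", None))

# Fixed condition: each metadata shape is probed by exactly one clause.
new = (use_mla and not hasattr(attn_metadata[layer_name], "reshape_cache_event")) or (
    not use_mla and not hasattr(attn_metadata, "reshape_cache_event")
)
# old is True (a redundant event would be recorded); new is False (reuse event)
```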

vllm_ascend/worker/model_runner_v1.py

Lines changed: 6 additions & 0 deletions
```diff
@@ -2165,6 +2165,12 @@ def _build_attn_group_metadata(
                 common_attn_metadata=common_attn_metadata,
                 **extra_attn_metadata_args,
             )
+            # NOTE(zxr): Due to the Triton operator does not deal with -1 padding in FullGraph mode,
+            # the padding needs to be changed from -1 to 0 to avoid writing invalid mamba block.
+            if self.vllm_config.compilation_config.cudagraph_mode.has_full_cudagraphs() \
+                and isinstance(builder, GDNAttentionMetadataBuilder) and attn_metadata_i.num_prefills == 0:
+                if attn_metadata_i.num_decodes == 0 and attn_metadata_i.num_spec_decodes > 0:
+                    attn_metadata_i.spec_state_indices_tensor[attn_metadata_i.num_spec_decodes:].fill_(0)

             if ubid is None:
                 assert isinstance(attn_metadata, dict)
```
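The effect of the new runner code on the speculative-decode index tensor, sketched with numpy in place of a torch tensor (shapes and values hypothetical):

```python
import numpy as np

num_spec_decodes = 3
# (padded_num_reqs, num_spec_steps) state-index tensor; the metadata builder
# pads unused trailing rows with -1, which a full-graph-captured kernel
# would treat as real address offsets.
spec_state_indices = np.array([[5, 6],
                               [7, 8],
                               [2, 3],
                               [-1, -1],   # padding rows
                               [-1, -1]])

# Mirror of `spec_state_indices_tensor[num_spec_decodes:].fill_(0)`:
# redirect padded rows to block 0, which is reserved and safe to write.
spec_state_indices[num_spec_decodes:] = 0
# padded rows now point at reserved block 0; real rows are untouched
```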
