Commit a5ca6a5

[0.9.1][BUGFIX] FIX FIA input when mtp is enabled in pd Disaggregation scenario (#2509)
### What this PR does / why we need it?
This bug can be triggered when a single decode node receives more than 16 requests at one time from the prefill node, since `torch_npu.npu_fused_infer_attention_score` can only accept a query sequence length of 16 in one batch.

### How was this patch tested?
4P1D:

P:
```
vllm serve /mnt/nfs/levis/DeepSeek-R1_w8a8_vllm \
  --host 0.0.0.0 \
  --port 20002 \
  --data-parallel-size 2 \
  --data-parallel-size-local 2 \
  --data-parallel-address 141.61.39.149 \
  --data-parallel-rpc-port 13348 \
  --tensor-parallel-size 8 \
  --max-num-seqs 512 \
  --seed 1024 \
  --served-model-name ds_r1 \
  --max-model-len 17000 \
  --max-num-batched-tokens 16384 \
  --trust-remote-code \
  --gpu-memory-utilization 0.9 \
  --quantization ascend \
  --speculative-config '{"num_speculative_tokens": 1, "method":"deepseek_mtp"}' \
  --enable-expert-parallel \
  --enforce-eager \
  --kv-transfer-config \
  '{"kv_connector": "LLMDataDistCMgrConnector", "kv_buffer_device": "npu", "kv_role": "kv_producer", "kv_parallel_size": 1, "kv_port": "20001", "engine_id": "0", "kv_connector_module_path": "vllm_ascend.distributed.llmdatadist_c_mgr_connector" }' \
  --additional-config \
  '{"ascend_scheduler_config":{"enabled":false}, "torchair_graph_config":{"enabled":false,"enable_multistream_shared_expert":false},"chunked_prefill_for_mla":true,"enable_weight_nz_layout":true,"enable_prefill_optimizations":true}'
```

D:
```
vllm serve /mnt/nfs/levis/DeepSeek-R1_w8a8_vllm \
  --host 0.0.0.0 \
  --port 20002 \
  --data-parallel-size 64 \
  --data-parallel-size-local 16 \
  --data-parallel-address 141.61.39.165 \
  --data-parallel-rpc-port 13348 \
  --tensor-parallel-size 1 \
  --seed 1024 \
  --served-model-name ds_r1 \
  --max-model-len 17000 \
  --max-num-batched-tokens 256 \
  --max-num-seqs 28 \
  --quantization ascend \
  --trust-remote-code \
  --speculative-config '{"num_speculative_tokens": 1, "method":"deepseek_mtp"}' \
  --gpu-memory-utilization 0.9 \
  --enable-expert-parallel \
  --kv-transfer-config \
  '{"kv_connector": "LLMDataDistCMgrConnector", "kv_buffer_device": "npu", "kv_role": "kv_consumer", "kv_parallel_size": 1, "kv_port": "20001", "engine_id": "0", "kv_connector_module_path": "vllm_ascend.distributed.llmdatadist_c_mgr_connector" }' \
  --additional-config \
  '{"ascend_scheduler_config":{"enabled":false},"torchair_graph_config":{"enabled":true,"enable_multistream_mla":true,"enable_multistream_moe":true,"graph_batch_sizes":[28], "enable_super_kernel":true, "use_cached_graph":true},"chunked_prefill_for_mla":true,"enable_weight_nz_layout":true}'
```

Signed-off-by: xuyexiong <[email protected]>
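The patch keeps the final padded value of `actual_seq_lengths_q` (so the last entry still matches the padded batch size) but ramps up to it gradually instead of appending the runner's precomputed tail verbatim, which is what produced the oversized jump. The snippet below is a minimal standalone sketch of that interpolation with made-up numbers; `runner_padded_tail` is only an illustrative stand-in for `self.runner.actual_seq_lengths_q[num_reqs:num_reqs + num_reqs_pad_size]`, not the patched code path itself.

```python
import numpy as np

# Illustrative values only: cumulative query token counts for the real decode
# requests in this batch, followed by the runner's precomputed padded tail.
actual_seq_lengths_q = [2, 4, 8]            # last real request ends at token 8
runner_padded_tail = [48, 50, 52, 54, 56]   # first padded entry jumps 40 tokens ahead

# The 8 -> 48 jump exceeds the 16-token step the fused op tolerates, so the fix
# spreads the gap evenly across the padded entries instead of copying them.
start_val, end_val = actual_seq_lengths_q[-1], runner_padded_tail[-1]
num_step = len(runner_padded_tail)

interpolated = np.round(
    np.linspace(start_val, end_val, num_step + 1)[1:]).astype(int).tolist()

print(interpolated)   # [18, 27, 37, 46, 56] -> same final value, smaller steps
assert interpolated[-1] == end_val
assert len(interpolated) == len(runner_padded_tail)
```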
1 parent 8aadcb7 commit a5ca6a5

File tree

1 file changed: +27 -2 lines changed

vllm_ascend/attention/mla_v1.py

Lines changed: 27 additions & 2 deletions
```diff
@@ -543,8 +543,9 @@ def build(
                     device=input_positions.device)
                 input_positions = torch.cat(
                     [input_positions, position_padding])
-                actual_seq_lengths_q = actual_seq_lengths_q + self.runner.actual_seq_lengths_q[
-                    num_reqs:num_reqs + num_reqs_pad_size]
+
+                actual_seq_lengths_q = self.pad_actual_seq_len_q(
+                    num_reqs_pad_size, num_reqs, actual_seq_lengths_q)
             else:
                 seq_lens_list = seq_lens.tolist()
                 # mtp torchair + PD scenario, last element of actual_seq_lengths_q must equal to batch_size(num_tokens)
@@ -588,6 +589,30 @@ def build(
             enable_dbo_across_dp=common_attn_metadata.enable_dbo_across_dp,
         )
 
+    def pad_actual_seq_len_q(self, num_reqs_pad_size, num_reqs,
+                             actual_seq_lengths_q):
+        need_padding = num_reqs_pad_size != 0 and \
+            len(self.runner.actual_seq_lengths_q) > num_reqs and \
+            self.runner.actual_seq_lengths_q[num_reqs] - actual_seq_lengths_q[-1] > 16
+        if need_padding:
+            padding_seq_len_q = self.runner.actual_seq_lengths_q[
+                num_reqs:num_reqs + num_reqs_pad_size]
+            start_val = actual_seq_lengths_q[-1]
+            end_val = padding_seq_len_q[-1]
+
+            num_step = len(padding_seq_len_q)
+            interpolated = np.round(
+                np.linspace(start_val, end_val,
+                            num_step + 1)[1:]).astype(int).tolist()
+            assert interpolated[-1] == end_val
+            assert len(interpolated) == len(padding_seq_len_q)
+            actual_seq_lengths_q = actual_seq_lengths_q + interpolated
+        else:
+            actual_seq_lengths_q = actual_seq_lengths_q + self.runner.actual_seq_lengths_q[
+                num_reqs:num_reqs + num_reqs_pad_size]
+
+        return actual_seq_lengths_q
+
 
 class AscendMLAImpl(MLAAttentionImpl):
     """
```
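For readers who want to exercise the new helper's decision logic outside the model runner, here is a hedged standalone sketch: the free function below mirrors the patched `pad_actual_seq_len_q`, with `runner_seq_lengths_q` passed in explicitly as a stand-in for `self.runner.actual_seq_lengths_q`, and the inputs are made-up values rather than anything from the test deployment above.

```python
import numpy as np


def pad_actual_seq_len_q(runner_seq_lengths_q, num_reqs_pad_size, num_reqs,
                         actual_seq_lengths_q):
    # Mirror of the patched method; runner_seq_lengths_q stands in for
    # self.runner.actual_seq_lengths_q.
    need_padding = num_reqs_pad_size != 0 and \
        len(runner_seq_lengths_q) > num_reqs and \
        runner_seq_lengths_q[num_reqs] - actual_seq_lengths_q[-1] > 16
    padding_seq_len_q = runner_seq_lengths_q[num_reqs:num_reqs + num_reqs_pad_size]
    if need_padding:
        # Gap too large for the fused op: ramp up to the same final value.
        start_val, end_val = actual_seq_lengths_q[-1], padding_seq_len_q[-1]
        interpolated = np.round(
            np.linspace(start_val, end_val,
                        len(padding_seq_len_q) + 1)[1:]).astype(int).tolist()
        return actual_seq_lengths_q + interpolated
    # Small gap: append the runner's precomputed tail verbatim, as before.
    return actual_seq_lengths_q + padding_seq_len_q


# Made-up cumulative lengths for a captured graph batch of 24 requests, 2 tokens each.
runner_lengths = [2 * (i + 1) for i in range(24)]

# 20 real requests with 1 query token each: gap 42 - 20 = 22 > 16, so interpolate.
print(pad_actual_seq_len_q(runner_lengths, 4, 20, list(range(1, 21))))

# 20 real requests with 2 query tokens each: gap 42 - 40 = 2 <= 16, verbatim copy.
print(pad_actual_seq_len_q(runner_lengths, 4, 20, [2 * (i + 1) for i in range(20)]))
```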