
Commit 128c120

[0.9.1][bugfix] Address abnormal VRAM increase in quantized models with floating-point MTP (#2554)
### **Problem & Cause**

VRAM usage increased abnormally during mixed-precision inference with quantized models and a floating-point MTP layer. The cause was `dist.all_to_all_single` creating extra HCCL communicators, each of which allocated unnecessary buffers that consumed additional device memory.

### **Solution**

This commit passes the `group` parameter to `dist.all_to_all_single`. By reusing the existing communicator from the `vllm-ascend` framework, all communication operations share a unified communication domain, which prevents the creation of extra buffers and resolves the VRAM issue.

### **Collaborators**

@kunpengW-code

cc @farawayboat @MengqingCao

Signed-off-by: SlightwindSec <[email protected]>
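For reference, a minimal sketch of the pattern this fix applies (the tensor name and the `ep_group` handle here are illustrative, not taken from the patch): passing `group=` makes `torch.distributed.all_to_all_single` run on an existing communicator instead of the default process group, which the commit identifies as the source of the extra HCCL communicators and buffers.

```python
# Sketch only: exchange_token_counts, local_counts and ep_group are
# hypothetical names chosen for illustration.
import torch
import torch.distributed as dist

def exchange_token_counts(local_counts: torch.Tensor,
                          ep_group: dist.ProcessGroup) -> torch.Tensor:
    # One slot per peer rank for the counts we will receive.
    recv_counts = local_counts.new_empty(local_counts.shape[0])
    # group= routes the collective through the caller's existing
    # communicator; omitting it falls back to the default group, which
    # this commit identifies as what spawned extra HCCL buffers.
    dist.all_to_all_single(recv_counts, local_counts, group=ep_group)
    return recv_counts
```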
1 parent 60c2df2 commit 128c120

File tree

1 file changed: +18 -7 lines changed

vllm_ascend/quantization/w8a8_dynamic.py

Lines changed: 18 additions & 7 deletions
```diff
@@ -526,7 +526,9 @@ def fused_experts_with_all2all(hidden_states: torch.Tensor,
 
         gather_sizes = global_expert_tokens.new_empty(
             global_expert_tokens.shape[0])
-        dist.all_to_all_single(gather_sizes, global_expert_tokens)
+        dist.all_to_all_single(gather_sizes,
+                               global_expert_tokens,
+                               group=ep_group.device_group)
 
         token_counts_combined = torch.stack(
             [gather_sizes, global_expert_tokens], dim=0)
@@ -542,10 +544,16 @@ def fused_experts_with_all2all(hidden_states: torch.Tensor,
         gather_size_list = token_counts_combined_cpu[1]
         scatter_size_list = token_counts_combined_cpu[0]
 
-        dist.all_to_all_single(gathered_tokens, quantized_tokens,
-                               scatter_size_list, gather_size_list)
-        dist.all_to_all_single(dynamic_scale, token_scales, scatter_size_list,
-                               gather_size_list)
+        dist.all_to_all_single(gathered_tokens,
+                               quantized_tokens,
+                               scatter_size_list,
+                               gather_size_list,
+                               group=ep_group.device_group)
+        dist.all_to_all_single(dynamic_scale,
+                               token_scales,
+                               scatter_size_list,
+                               gather_size_list,
+                               group=ep_group.device_group)
 
         hidden_states, dynamic_scale, inverse_indices, expert_tokens = torch_npu.npu_moe_re_routing(
             gathered_tokens,
@@ -593,8 +601,11 @@ def fused_experts_with_all2all(hidden_states: torch.Tensor,
             index=inverse_indices.to(torch.float32).argsort().to(torch.int32))
 
         hidden_states = reordered_outputs.new_empty(*quantized_tokens.shape)
-        dist.all_to_all_single(hidden_states, reordered_outputs,
-                               gather_size_list, scatter_size_list)
+        dist.all_to_all_single(hidden_states,
+                               reordered_outputs,
+                               gather_size_list,
+                               scatter_size_list,
+                               group=ep_group.device_group)
 
         final_hidden_states = torch_npu.npu_moe_finalize_routing(
             hidden_states,
```
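As a usage note, the patched calls follow PyTorch's signature `all_to_all_single(output, input, output_split_sizes=None, input_split_sizes=None, group=None, ...)`. A self-contained sketch of the variable-length form used in the second and third hunks, with illustrative names not taken from the patch:

```python
# Sketch only: dispatch_tokens and its parameter names are hypothetical.
import torch
import torch.distributed as dist

def dispatch_tokens(tokens: torch.Tensor,
                    send_sizes: list[int],
                    recv_sizes: list[int],
                    group: dist.ProcessGroup) -> torch.Tensor:
    # Output buffer sized for everything this rank will receive.
    received = tokens.new_empty(sum(recv_sizes), tokens.shape[1])
    # Variable-length exchange; group= keeps the collective on the
    # caller's existing communicator rather than the default group.
    dist.all_to_all_single(received,
                           tokens,
                           output_split_sizes=recv_sizes,
                           input_split_sizes=send_sizes,
                           group=group)
    return received
```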
