Commit f3b50c5

[main][Prefill Perf] Optimize Quantized MoE Performance by Reducing All2All Communication (#2195)
This PR optimizes quantized Mixture of Experts (MoE) layers by reordering the quantization and communication steps. Previously, the `all2all` operation was performed on unquantized `hidden_states` (FP16/BF16) *before* quantization, which incurred substantial communication overhead. Quantizing on each EP rank **first** and then sending the much smaller INT8 data reduces the communication volume by nearly 50%.

This PR also includes a minor optimization that casts `int` inputs to `float` for the `argsort` operation, so it runs on the faster AI Core instead of the AICPU.

Together these changes yield a clear performance gain in quantized MoE scenarios.

- vLLM version: v0.10.0
- vLLM main: vllm-project/vllm@7175817

---------

Signed-off-by: SlightwindSec <[email protected]>
1 parent 292fb8f commit f3b50c5
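For a rough sense of why quantize-then-communicate helps, the back-of-the-envelope sketch below compares the all2all payload for BF16 rows versus INT8 rows with one per-row dynamic-quant scale. It is not part of the commit; the token count, hidden size, and top-k values are illustrative assumptions.

    # Illustrative payload comparison (not from the PR): BF16 rows vs. INT8 rows
    # plus one FP32 dynamic-quant scale per row (per-token quantization).
    num_tokens, hidden_size, top_k = 1024, 7168, 8   # assumed values
    expanded_rows = num_tokens * top_k               # rows after top-k routing expansion

    bf16_bytes = expanded_rows * hidden_size * 2     # old path: 2 bytes per element
    int8_bytes = expanded_rows * (hidden_size + 4)   # new path: 1 byte per element + 4-byte scale

    print(f"BF16 all2all payload: {bf16_bytes / 2**20:.1f} MiB")
    print(f"INT8 all2all payload: {int8_bytes / 2**20:.1f} MiB")
    print(f"Reduction: {1 - int8_bytes / bf16_bytes:.1%}")   # ~50% for large hidden sizes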

2 files changed: +156 −48 lines
Lines changed: 75 additions & 0 deletions
@@ -0,0 +1,75 @@
from unittest.mock import MagicMock, patch

import torch

from tests.ut.base import TestBase
from vllm_ascend.quantization.w8a8_dynamic import fused_experts_with_all2all


class TestAscendW8A8FusedMoEMethod(TestBase):

    def setUp(self):
        self.hidden_size = 128
        self.num_tokens = 128
        self.placeholder = torch.randn(self.num_tokens,
                                       self.hidden_size,
                                       dtype=torch.bfloat16)

    @patch("torch.distributed.all_to_all_single")
    @patch("torch_npu.npu_moe_re_routing")
    @patch("torch_npu.npu_grouped_matmul")
    @patch("torch_npu.npu_swiglu")
    @patch("torch_npu.npu_dynamic_quant")
    @patch("torch_npu.npu_moe_finalize_routing")
    @patch("torch_npu.npu_moe_init_routing")
    def test_fused_experts_with_all2all(self, mock_moe_init_routing,
                                        mock_moe_finalize_routing,
                                        mock_dynamic_quant, mock_swiglu,
                                        mock_grouped_matmul,
                                        mock_moe_re_routing,
                                        mock_all_to_all_single):
        expert_map = MagicMock()
        ep_group = MagicMock()
        placeholder_int8 = torch.randint(0,
                                         100,
                                         (self.num_tokens, self.hidden_size),
                                         dtype=torch.int8)
        placeholder_ones = torch.ones(self.num_tokens, dtype=torch.int32)
        mock_all_to_all_single.side_effect = lambda output, input, *args, **kwargs: output.copy_(
            input)
        mock_moe_init_routing.return_value = (
            placeholder_int8,
            placeholder_ones,
            placeholder_ones,
        )
        mock_moe_re_routing.return_value = (placeholder_int8, self.placeholder,
                                            torch.randint(0,
                                                          100,
                                                          (self.num_tokens, ),
                                                          dtype=torch.int32),
                                            self.placeholder)
        mock_grouped_matmul.return_value = self.placeholder
        mock_swiglu.return_value = self.placeholder
        mock_dynamic_quant.return_value = (
            placeholder_int8,
            torch.randn(self.num_tokens),
        )
        mock_moe_finalize_routing.return_value = self.placeholder

        result = fused_experts_with_all2all(
            hidden_states=self.placeholder,
            w1=self.placeholder,
            w1_scale=self.placeholder,
            w2=self.placeholder,
            w2_scale=self.placeholder,
            topk_weights=self.placeholder,
            topk_ids=self.placeholder,
            top_k=8,
            expert_map=expert_map,
            ep_group=ep_group,
            log2phy=None,
            global_redundant_expert_num=256,
        )
        self.assertIsNotNone(result)
        self.assertEqual(result.dtype, torch.bfloat16)
        self.assertEqual(result.shape, (128, 128))

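The key trick in this test is the `side_effect` lambda that stands in for `torch.distributed.all_to_all_single`: with a single mocked rank, the "exchange" degenerates to copying the input into the caller's preallocated output buffer. A minimal standalone illustration of that mocking pattern (names here are illustrative, not from the test):

    import torch
    import torch.distributed as dist
    from unittest.mock import patch

    def fake_all_to_all_single(output, input, *args, **kwargs):
        # Single-rank stand-in: the collective reduces to copying input -> output.
        return output.copy_(input)

    with patch("torch.distributed.all_to_all_single",
               side_effect=fake_all_to_all_single):
        src = torch.arange(8, dtype=torch.float32)
        dst = torch.empty_like(src)
        dist.all_to_all_single(dst, src)   # no process group needed; the mock runs instead
        assert torch.equal(dst, src)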
vllm_ascend/quantization/w8a8_dynamic.py

Lines changed: 81 additions & 48 deletions
@@ -334,6 +334,29 @@ def fused_experts_with_mc2(
     return hidden_states, shared_output
 
 
+def init_routing_quant(hidden_states, top_k, topk_ids, global_num_experts):
+    num_tokens, _ = hidden_states.shape
+    row_idx_len = num_tokens * top_k
+    row_idx = (torch.arange(0,
+                            row_idx_len,
+                            dtype=torch.int32,
+                            device=hidden_states.device).view(
+                                top_k, -1).permute(1, 0).contiguous())
+    hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing(
+        hidden_states,
+        row_idx=row_idx,
+        expert_idx=topk_ids,
+        active_num=num_tokens)
+
+    expanded_row_idx = (expanded_row_idx.view(top_k, -1).permute(
+        1, 0).contiguous().view(-1))
+    global_expert_tokens = torch.bincount(expanded_expert_idx,
+                                          minlength=global_num_experts)
+    global_expert_tokens = global_expert_tokens.to(torch.int32)
+    quantized_tokens, token_scales = torch_npu.npu_dynamic_quant(hidden_states)
+    return quantized_tokens, expanded_row_idx, global_expert_tokens, token_scales
+
+
 # currently expert parallelism implemented with all2all
 # is under-optimized.
 def fused_experts_with_all2all(
@@ -358,50 +381,54 @@ def fused_experts_with_all2all(
 
     num_tokens, _ = hidden_states.shape
     num_experts = w1.shape[0]
-    device = hidden_states.device
 
     if expert_map is not None:
         global_num_experts = len(expert_map) + global_redundant_expert_num
-        local_num_experts = global_num_experts // ep_group.world_size
-        row_idx_len = num_tokens * top_k
-        row_idx = (torch.arange(0,
-                                row_idx_len,
-                                dtype=torch.int32,
-                                device=device).view(top_k, -1).permute(
-                                    1, 0).contiguous())
-        hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing(
-            hidden_states,
-            row_idx=row_idx,
-            expert_idx=topk_ids,
-            active_num=num_tokens)
-
-        global_expert_tokens = torch.bincount(expanded_expert_idx,
-                                              minlength=global_num_experts)
-        scatter_sizes = global_expert_tokens.view(ep_group.world_size,
-                                                  -1).sum(-1)
-
-        gather_sizes = torch.empty_like(scatter_sizes)
-        dist.all_to_all_single(gather_sizes,
-                               scatter_sizes,
-                               group=ep_group.device_group)
-        scatter_size_list = scatter_sizes.cpu().tolist()
-        gather_size_list = gather_sizes.cpu().tolist()
-
-        expanded_expert_idx = expanded_expert_idx % local_num_experts
-        hidden_states = ep_group.all_to_all(hidden_states, 0, 0,
-                                            scatter_size_list,
-                                            gather_size_list)
-        local_expert_idx = ep_group.all_to_all(expanded_expert_idx, 0, 0,
-                                               scatter_size_list,
-                                               gather_size_list)
-
-        sorted_local_expert_idx, sorted_idx = torch.sort(local_expert_idx)
-
-        expert_tokens = torch_npu.npu_moe_compute_expert_tokens(
-            sorted_local_expert_idx, local_num_experts).to(torch.int64)
-
-        hidden_states = hidden_states[sorted_idx]
-        group_list_type = 0
+        if hasattr(torch_npu, "npu_moe_init_routing_quant"):
+            quantized_tokens, expanded_row_idx, global_expert_tokens, _, token_scales = torch_npu.npu_moe_init_routing_quant(
+                hidden_states,
+                expert_idx=topk_ids.to(torch.int32),
+                active_num=0,
+                expert_capacity=0,
+                expert_num=global_num_experts,
+                drop_pad_mode=0,
+                expert_tokens_num_mode=2,
+                expert_tokens_before_capacity_flag=False,
+                quant_mode=1,
+            )
+        else:
+            quantized_tokens, expanded_row_idx, global_expert_tokens, token_scales = init_routing_quant(
+                hidden_states, top_k, topk_ids, global_num_experts)
+
+        gather_sizes = global_expert_tokens.new_empty(
+            global_expert_tokens.shape[0])
+        dist.all_to_all_single(gather_sizes, global_expert_tokens)
+
+        token_counts_combined = torch.stack(
+            [gather_sizes, global_expert_tokens], dim=0)
+        token_counts_combined = token_counts_combined.view(
+            2, ep_group.world_size, -1).sum(dim=2)
+        token_counts_combined_cpu = token_counts_combined.to(
+            torch.device("cpu"), non_blocking=True).numpy()
+        all_tokens = gather_sizes.sum()
+
+        gathered_tokens = quantized_tokens.new_empty(all_tokens.item(),
+                                                     quantized_tokens.shape[1])
+        dynamic_scale = token_scales.new_empty(gathered_tokens.shape[0])
+        gather_size_list = token_counts_combined_cpu[1]
+        scatter_size_list = token_counts_combined_cpu[0]
+
+        dist.all_to_all_single(gathered_tokens, quantized_tokens,
+                               scatter_size_list, gather_size_list)
+        dist.all_to_all_single(dynamic_scale, token_scales, scatter_size_list,
+                               gather_size_list)
+
+        hidden_states, dynamic_scale, inverse_indices, expert_tokens = torch_npu.npu_moe_re_routing(
+            gathered_tokens,
+            gather_sizes.view(ep_group.world_size, -1),
+            per_token_scales=dynamic_scale)
+        expert_tokens = expert_tokens.to(torch.int64)
+        group_list_type = 1
     else:
         row_idx_len = num_tokens * top_k
         row_idx = torch.arange(0,
@@ -419,6 +446,7 @@ def fused_experts_with_all2all(
             expanded_expert_idx, num_experts)
         expert_tokens = expert_tokens.to(torch.int64)
         group_list_type = 0
+        dynamic_scale = None
 
     # `hidden_states` will be disposed in the `apply_mlp` function
     hidden_states = apply_mlp(
@@ -428,14 +456,19 @@ def fused_experts_with_all2all(
         w2,
         w2_scale,
         expert_tokens,  #16
+        dynamic_scale=dynamic_scale,
         group_list_type=group_list_type)
 
     if expert_map is not None:
-        resorted_idx = torch.argsort(sorted_idx)
-        hidden_states = hidden_states[resorted_idx]
-        hidden_states = ep_group.all_to_all(hidden_states, 0, 0,
-                                            gather_size_list,
-                                            scatter_size_list)
+        reordered_outputs = torch.index_select(
+            hidden_states,
+            dim=0,
+            # Workaround: Convert to float so that argsort runs on AI Core instead of slower AICPU
+            index=inverse_indices.to(torch.float32).argsort().to(torch.int32))
+
+        hidden_states = reordered_outputs.new_empty(*quantized_tokens.shape)
+        dist.all_to_all_single(hidden_states, reordered_outputs,
+                               gather_size_list, scatter_size_list)
 
         final_hidden_states = torch_npu.npu_moe_finalize_routing(
             hidden_states,
@@ -444,8 +477,8 @@ def fused_experts_with_all2all(
             bias=None,
             scales=topk_weights,
             expanded_src_to_dst_row=expanded_row_idx,
-            export_for_source_row=topk_ids,
-        )
+            export_for_source_row=None,
+            drop_pad_mode=2)
     else:
         # TODO: Reorder device memory 2 times here, replace the current
         # implementation here when suitable operators become available.

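The argsort workaround above is easy to sanity-check in isolation: because the inverse indices are distinct integers (small enough here to be exact in FP32), sorting a float-cast copy yields the same permutation as sorting the integers directly, and only the kernel the NPU dispatches differs. A minimal CPU-runnable sketch, illustrative and not part of the commit:

    import torch

    # Illustrative only: distinct int32 indices, like those produced by the re-routing step.
    inverse_indices = torch.randperm(1024, dtype=torch.int32)

    order_int = torch.argsort(inverse_indices)                                  # integer argsort
    order_float = inverse_indices.to(torch.float32).argsort().to(torch.int32)   # float-cast variant from the diff

    # Same permutation either way; on the NPU only the dispatched kernel differs.
    assert torch.equal(order_int.to(torch.int32), order_float)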