Commit b2c8ce5

Fix Flashinfer CUTLASS MOE Allgather (#21963)
Signed-off-by: Shu Wang <[email protected]>
1 parent a3b9c17 commit b2c8ce5

File tree

4 files changed: +71 -27 lines changed


vllm/distributed/device_communicators/cuda_communicator.py

Lines changed: 2 additions & 1 deletion
@@ -236,7 +236,8 @@ def _all_gather_single(input_: torch.Tensor,
             input_size = input_.size()
             if sizes is not None:
                 assert len(sizes) == world_size
-                assert input_.shape[dim] == sizes[self.rank_in_group]
+                assert input_.shape[dim] == sizes[self.rank_in_group], (
+                    f"{input_.shape[dim]} != {sizes[self.rank_in_group]}")
                 output_size = (sum(sizes), ) + input_size[1:]
             else:
                 output_size = (input_size[0] * world_size, ) + input_size[1:]

vllm/forward_context.py

Lines changed: 58 additions & 0 deletions
@@ -26,10 +26,26 @@
 batchsize_forward_time: defaultdict = defaultdict(list)
 
 
+def _compute_chunked_local_num_tokens(num_tokens_across_dp_cpu: list[int],
+                                      max_num_tokens: int,
+                                      chunk_idx: int) -> list[int]:
+    dp_size = len(num_tokens_across_dp_cpu)
+
+    local_size = [-1] * dp_size
+    for i in range(dp_size):
+        dp_tokens = num_tokens_across_dp_cpu[i]
+        local_size[i] = min(max_num_tokens,
+                            dp_tokens - (max_num_tokens * chunk_idx))
+        if local_size[i] <= 0:
+            local_size[i] = 1  # ensure lockstep even if done
+    return local_size
+
+
 @dataclass
 class DPMetadata:
     max_tokens_across_dp_cpu: torch.Tensor
     cu_tokens_across_dp_cpu: torch.Tensor
+    local_sizes: Optional[list[int]] = None
 
     @staticmethod
     def num_tokens_across_dp(num_tokens: int, dp_size: int,
@@ -78,6 +94,48 @@ def make(
         cu_tokens_across_dp_cpu = torch.cumsum(num_tokens_across_dp, dim=0)
         return DPMetadata(max_tokens_across_dp_cpu, cu_tokens_across_dp_cpu)
 
+    @contextmanager
+    def chunked_sizes(self, max_chunk_size_per_rank: int, chunk_idx: int):
+        """
+        Context manager to compute and temporarily set the per-rank local
+        token sizes for a specific chunk during chunked forward execution.
+
+        This is necessary to ensure each DP (data parallel) rank processes
+        its designated portion of tokens in lockstep with others, even when
+        the token counts are uneven or some ranks have completed their input
+        early.
+
+        For chunked execution, we break up the total tokens on each rank into
+        multiple chunks (of at most `max_chunk_size_per_rank`), and for a
+        given `chunk_idx`, this context manager sets `self.local_sizes` to
+        the number of tokens to process in that chunk on each rank.
+
+        It uses cumulative sizes (`cu_tokens_across_dp_cpu`) to derive the
+        number of tokens per rank, and calls
+        `_compute_chunked_local_num_tokens` to determine the chunk-wise
+        split.
+
+        `self.local_sizes` is only valid inside the context.
+
+        Args:
+            max_chunk_size_per_rank: The max number of tokens each rank is
+                allowed to process in this chunk.
+            chunk_idx: The index of the chunk to compute sizes for.
+        """
+        cu_sizes = self.cu_tokens_across_dp_cpu
+        num_tokens_across_dp_cpu = [
+            (cu_sizes[i] -
+             cu_sizes[i - 1]).item() if i > 0 else cu_sizes[0].item()
+            for i in range(len(cu_sizes))
+        ]
+        self.local_sizes = _compute_chunked_local_num_tokens(
+            num_tokens_across_dp_cpu, max_chunk_size_per_rank, chunk_idx)
+        try:
+            yield self.local_sizes
+        finally:
+            self.local_sizes = None
+
+    def get_chunk_sizes_across_dp_rank(self) -> Optional[list[int]]:
+        return self.local_sizes
+
 
 @dataclass
 class ForwardContext:
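
To make the chunking behaviour concrete, here is a small illustration (not part of the commit, and assuming `_compute_chunked_local_num_tokens` from the hunk above is in scope) for a hypothetical run with three DP ranks holding 5, 12 and 8 tokens and a chunk size of 4; ranks that run out of tokens are clamped to 1 so that every rank keeps issuing collectives in lockstep.

# Illustration only: expected per-rank chunk sizes for token counts [5, 12, 8]
# with max_num_tokens = 4.
#   chunk_idx 0 -> [4, 4, 4]
#   chunk_idx 1 -> [1, 4, 4]   (rank 0 has one token left)
#   chunk_idx 2 -> [1, 4, 1]   (ranks 0 and 2 are done; clamped to 1 for lockstep)
for chunk_idx in range(3):
    print(_compute_chunked_local_num_tokens([5, 12, 8], 4, chunk_idx))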

vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py

Lines changed: 4 additions & 20 deletions
@@ -4,7 +4,6 @@
 
 import torch
 
-import vllm.envs as envs
 import vllm.model_executor.layers.fused_moe.modular_kernel as mk
 from vllm.distributed import get_dp_group
 from vllm.forward_context import get_forward_context
@@ -14,20 +13,8 @@
 from vllm.utils.flashinfer import nvfp4_block_scale_interleave
 
 
-def get_local_sizes(local_tokens):
-    cu_sizes = get_forward_context().dp_metadata.cu_tokens_across_dp_cpu
-    sizes = [cu_sizes[0].item()]
-    for i in range(1, len(cu_sizes)):
-        sizes.append((cu_sizes[i] - cu_sizes[i - 1]).item())
-    max_num_tokens = envs.VLLM_MOE_DP_CHUNK_SIZE
-    sizes_chunked = [max_num_tokens] * len(sizes)
-    if local_tokens < max_num_tokens:
-        # When the number of local tokens is less than max_num_tokens, all other
-        # ranks will also have fewer than max_num_tokens. The remaining tokens
-        # are accounted for as residual.
-        sizes_chunked = [x % max_num_tokens for x in sizes]
-
-    return sizes_chunked
+def get_local_sizes():
+    return get_forward_context().dp_metadata.get_chunk_sizes_across_dp_rank()
 
 
 class FlashInferCutlassMoEPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
@@ -90,7 +77,7 @@ def prepare(
         topk_weights, topk_ids, a1q, a1q_scale = \
             get_dp_group().all_gatherv([topk_weights, topk_ids, a1q, a1q_scale],  # noqa: E501
                                        dim=0,
-                                       sizes=get_local_sizes(local_tokens))
+                                       sizes=get_local_sizes())
         a1_m, a1_n = a1q.shape
         a1q_scale = nvfp4_block_scale_interleave(a1q_scale)
 
@@ -107,8 +94,5 @@ def finalize(self, output: torch.Tensor, fused_expert_output: torch.Tensor,
                                              ['use_dp', 'local_tokens'])
         if use_dp:
             fused_expert_output = get_dp_group().reduce_scatterv(
-                fused_expert_output,
-                dim=0,
-                sizes=get_local_sizes(local_tokens),
-            )
+                fused_expert_output, dim=0, sizes=get_local_sizes())
         output.copy_(fused_expert_output)
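
A minimal usage sketch (not from the commit; it assumes a forward context whose `dp_metadata` is already populated, which normally only holds during a forward pass): the rewritten `get_local_sizes` is just an accessor, so it only returns meaningful sizes while the caller is inside `DPMetadata.chunked_sizes`, which is exactly how the MoE layer drives it in `layer.py` below.

# Sketch: get_local_sizes() now simply reads whatever DPMetadata.chunked_sizes()
# has published for the current chunk of the forward pass.
ctx = get_forward_context()
with ctx.dp_metadata.chunked_sizes(max_chunk_size_per_rank=4, chunk_idx=0) as sizes:
    assert get_local_sizes() == sizes   # valid only inside the context
assert get_local_sizes() is None        # cleared once the chunk is finished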

vllm/model_executor/layers/fused_moe/layer.py

Lines changed: 7 additions & 6 deletions
@@ -1570,18 +1570,19 @@ def process_chunk(chunk_start, chunk_end, skip_result_store=False):
         max_tokens_across_dp = ctx.dp_metadata.max_tokens_across_dp_cpu
         moe_dp_chunk_size_per_rank = self.moe_config.max_num_tokens
         num_tokens = full_hidden_states.size(0)
-        for chunk_start_ in range(0, max_tokens_across_dp,
-                                  moe_dp_chunk_size_per_rank):
+        for chunk_idx, chunk_start_ in enumerate(
+                range(0, max_tokens_across_dp, moe_dp_chunk_size_per_rank)):
             chunk_start = chunk_start_
             chunk_end = min(chunk_start + moe_dp_chunk_size_per_rank,
                             max_tokens_across_dp)
             # clamp start and end
             chunk_start = min(chunk_start, num_tokens - 1)
             chunk_end = min(chunk_end, num_tokens)
-
-            process_chunk(chunk_start,
-                          chunk_end,
-                          skip_result_store=chunk_start_ >= num_tokens)
+            with ctx.dp_metadata.chunked_sizes(moe_dp_chunk_size_per_rank,
+                                               chunk_idx):
+                process_chunk(chunk_start,
+                              chunk_end,
+                              skip_result_store=chunk_start_ >= num_tokens)
 
         return full_final_hidden_states
 
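Taken together, here is a standalone toy sketch of the control flow the commit establishes (hypothetical token counts, no distributed setup, and it reuses `_compute_chunked_local_num_tokens` from the `forward_context.py` hunk in place of the real `chunked_sizes` context manager): the layer enumerates chunks, each chunk publishes one per-rank size list, and those sizes are what `all_gatherv`/`reduce_scatterv` receive, so each rank's local slice matches the size the strengthened assertion in `cuda_communicator.py` now checks.

# Toy walk-through, illustration only: one forward pass split into lockstep
# chunks. Each rank slices its own tokens and checks the same invariant the
# communicator asserts: local slice length == sizes[rank_in_group].
num_tokens_across_dp = [5, 12, 8]          # tokens held by each DP rank
max_chunk = 4                              # chunk size per rank
num_chunks = -(-max(num_tokens_across_dp) // max_chunk)   # ceil division

for chunk_idx in range(num_chunks):
    sizes = _compute_chunked_local_num_tokens(num_tokens_across_dp,
                                              max_chunk, chunk_idx)
    for rank, total in enumerate(num_tokens_across_dp):
        start = min(chunk_idx * max_chunk, max(total - 1, 0))   # clamp start
        end = min(start + max_chunk, total)                     # clamp end
        local_rows = max(end - start, 1)   # done ranks still contribute 1 row
        assert local_rows == sizes[rank], f"{local_rows} != {sizes[rank]}"
    print(f"chunk {chunk_idx}: per-rank sizes {sizes}, gathered rows {sum(sizes)}")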