Commit 9233d83

[EP] bug fixes (#1586)
Fixes a bug introduced in #1555.
1 parent 72b16b1 commit 9233d83

4 files changed: +12 −4 lines changed

torchtitan/distributed/expert_parallel.py

Lines changed: 9 additions & 3 deletions
@@ -365,13 +365,18 @@ def wrapper(
 # This class is to support Sequence Parallel for ETP=1
 # when EP borrows from all TP and part of DP
 class ReordererSequenceParallel(ParallelStyle):
+    def __init__(self):
+        super().__init__()
+        self.num_tokens = None
+
     def _prepare_inputput_fn(self, mod, inputs, device_mesh):
         top_scores, selected_experts_indices = inputs
 
         top_scores = DTensor.from_local(top_scores, device_mesh, (Replicate(),))
         selected_experts_indices = DTensor.from_local(
             selected_experts_indices, device_mesh, (Replicate(),)
         )
+        self.num_tokens = top_scores.shape[0]
 
         # TODO: If needed, we can pad tokens in case bs*slen is not divisible by TP degree
         # if top_scores.shape[0] % device_mesh.size() != 0:
@@ -380,7 +385,7 @@ def _prepare_inputput_fn(self, mod, inputs, device_mesh):
         #     n_pad = (num_tokens // tp_size + 1) * tp_size - num_tokens
         #     selected_experts_indices = F.pad(selected_experts_indices, [0, 0, 0, n_pad])
         #     top_scores = F.pad(top_scores, [0, 0, 0, n_pad])
-        assert top_scores.shape[0] % device_mesh.size() == 0
+        assert self.num_tokens % device_mesh.size() == 0
 
         # split on the bs*slen dimension
         top_scores = top_scores.redistribute(device_mesh, (Shard(0),)).to_local()
@@ -395,9 +400,10 @@ def _prepare_output_fn(self, mod, outputs, device_mesh):
 
         # NOTE: As we shard routed tokens along bs*slen dim across the TP ranks,
         # the MoE gather and scatter still require global token indices.
-        num_tokens = top_scores.shape[0]
         local_rank = device_mesh.get_local_rank()
-        token_indices_experts_sorted += num_tokens // device_mesh.size() * local_rank
+        token_indices_experts_sorted += (
+            self.num_tokens // device_mesh.size() * local_rank
+        )
 
         return top_scores, token_indices_experts_sorted, num_tokens_per_expert
 
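
The substantive change above is that the global token count along the folded bs*slen dimension is captured once in the input hook (before the redistribute to Shard(0)) and cached on self.num_tokens, so the output hook no longer derives it from top_scores.shape[0], which by that point reflects a locally sharded tensor. Below is a minimal, standalone sketch of the offset arithmetic involved; the numbers and the global_token_offset helper are illustrative stand-ins, not torchtitan code.

def global_token_offset(global_num_tokens: int, tp_degree: int, local_rank: int) -> int:
    # Each TP rank owns a contiguous slice of the folded bs*slen dimension,
    # so rank r's slice starts at r * (global_num_tokens // tp_degree).
    assert global_num_tokens % tp_degree == 0
    return global_num_tokens // tp_degree * local_rank


if __name__ == "__main__":
    global_num_tokens = 16  # bs * slen before sharding (what self.num_tokens caches)
    tp_degree = 4           # TP ranks borrowed for EP when ETP=1
    for rank in range(tp_degree):
        # Token indices produced after the shard are relative to the rank's slice;
        # adding this offset maps them back to global positions for the MoE
        # gather/scatter. Deriving the offset from a post-shard shape instead of
        # the cached global count would misplace every rank's tokens.
        print(f"rank {rank}: offset = {global_token_offset(global_num_tokens, tp_degree, rank)}")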

torchtitan/experiments/llama4/infra/parallelize.py

Lines changed: 1 addition & 1 deletion
@@ -401,7 +401,7 @@ def apply_moe_ep_tp(
         # replicate computation for the router
         "moe.router.gate": NoParallel(),
     }
-    if not etp_enabled:
+    if ep_mesh is not None and not etp_enabled:
         # If TP is borrowed for EP, then split the tokens across TP ranks so that
         # the reorderer, the all-to-all comms, and routed experts computation
         # are effectively running Sequence Parallel (split along the folded bs*slen dim)
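
The added ep_mesh is not None check means the Sequence Parallel reorderer plan is only applied when an EP mesh actually exists, instead of whenever ETP happens to be disabled. A small sketch of the guard's truth table, using hypothetical stand-in names (maybe_apply_reorderer_sp, ep_mesh, etp_enabled) rather than the real apply_moe_ep_tp signature:

from typing import Optional


def maybe_apply_reorderer_sp(ep_mesh: Optional[object], etp_enabled: bool) -> bool:
    # Shard the reorderer along the folded bs*slen dim only when EP is in use
    # (an EP mesh exists) and expert tensor parallelism is not, i.e. ETP=1
    # where EP borrows the TP ranks.
    return ep_mesh is not None and not etp_enabled


# No EP mesh: skip, even though ETP is disabled (the case this guard handles).
assert not maybe_apply_reorderer_sp(None, etp_enabled=False)
# EP in use and ETP disabled: apply the ReordererSequenceParallel plan.
assert maybe_apply_reorderer_sp(object(), etp_enabled=False)
# ETP enabled: skip (this path is only for ETP=1).
assert not maybe_apply_reorderer_sp(object(), etp_enabled=True)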

torchtitan/experiments/llama4/train_configs/debug_model.toml

Lines changed: 1 addition & 0 deletions
@@ -51,6 +51,7 @@ fsdp_reshard_after_forward = "default" # default / never / always
 tensor_parallel_degree = 1
 enable_async_tensor_parallel = false
 pipeline_parallel_degree = 1
+pipeline_parallel_schedule = "1F1B"
 context_parallel_degree = 1
 expert_parallel_degree = 1
 expert_tensor_parallel_degree = 1

torchtitan/models/deepseek_v3/train_configs/debug_model.toml

Lines changed: 1 addition & 0 deletions
@@ -52,6 +52,7 @@ tensor_parallel_degree = 1
 enable_async_tensor_parallel = false
 pipeline_parallel_degree = 1
 pipeline_parallel_schedule = "1F1B"
+context_parallel_degree = 1
 expert_parallel_degree = 1
 expert_tensor_parallel_degree = 1
 