Commit 0d1b80d

[EP] remove token split overhead from DTensor in TokenReorderer pre hook (#1587)
Due to the d2h sync in EP, training is sometimes CPU bound, so we need to be more careful about DTensor overhead. See the screenshots below for profiler traces; numerics are verified to be the same. (A small standalone sketch of the equivalence check follows the diff below.)

**before**
<img width="864" height="778" alt="image" src="https://github.com/user-attachments/assets/7c6f5133-7747-474b-8951-3e392ed92b28" />

**after**
<img width="441" height="743" alt="image" src="https://github.com/user-attachments/assets/0c9a274e-84b7-4c2b-aeff-395b06d01410" />
1 parent 9233d83 commit 0d1b80d

File tree

1 file changed (+13 −12)


torchtitan/distributed/expert_parallel.py

Lines changed: 13 additions & 12 deletions
@@ -371,27 +371,28 @@ def __init__(self):

     def _prepare_inputput_fn(self, mod, inputs, device_mesh):
         top_scores, selected_experts_indices = inputs
-
-        top_scores = DTensor.from_local(top_scores, device_mesh, (Replicate(),))
-        selected_experts_indices = DTensor.from_local(
-            selected_experts_indices, device_mesh, (Replicate(),)
-        )
         self.num_tokens = top_scores.shape[0]

-        # TODO: If needed, we can pad tokens in case bs*slen is not divisible by TP degree
+        # NOTE: If needed, we can pad tokens in case bs*slen is not divisible by TP degree
         # if top_scores.shape[0] % device_mesh.size() != 0:
         #     num_tokens = top_scores.shape[0]
         #     tp_size = device_mesh.size()
         #     n_pad = (num_tokens // tp_size + 1) * tp_size - num_tokens
         #     selected_experts_indices = F.pad(selected_experts_indices, [0, 0, 0, n_pad])
         #     top_scores = F.pad(top_scores, [0, 0, 0, n_pad])
-        assert self.num_tokens % device_mesh.size() == 0

-        # split on the bs*slen dimension
-        top_scores = top_scores.redistribute(device_mesh, (Shard(0),)).to_local()
-        selected_experts_indices = selected_experts_indices.redistribute(
-            device_mesh, (Shard(0),)
-        ).to_local()
+        def _split_along_first_dim(x: torch.Tensor) -> torch.Tensor:
+            assert x.is_contiguous()
+            assert self.num_tokens % device_mesh.size() == 0
+            local_num_tokens = self.num_tokens // device_mesh.size()
+            local_rank = device_mesh.get_local_rank()
+            offset = local_rank * local_num_tokens
+            output = x[offset : offset + local_num_tokens]
+
+            return output
+
+        top_scores = _split_along_first_dim(top_scores)
+        selected_experts_indices = _split_along_first_dim(selected_experts_indices)

         return top_scores, selected_experts_indices
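The commit message notes that numerics were verified to be the same. Below is a minimal standalone sketch, not part of the commit, of how one could check that the new plain-slice split matches the old Replicate → Shard(0) redistribute on a 1-D mesh. The gloo/CPU process group, mesh setup, tensor shape, same-seed trick, and file name are assumptions for illustration; the slicing logic mirrors `_split_along_first_dim` from the diff above.

```python
# Hypothetical standalone check (not from the commit); run under torchrun, e.g.:
#   torchrun --nproc-per-node 2 check_split.py
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import DTensor, Replicate, Shard


def dtensor_split(x: torch.Tensor, mesh) -> torch.Tensor:
    # Old path: wrap the replicated tensor as a DTensor, redistribute to Shard(0), unwrap.
    dt = DTensor.from_local(x, mesh, (Replicate(),))
    return dt.redistribute(mesh, (Shard(0),)).to_local()


def manual_split(x: torch.Tensor, mesh) -> torch.Tensor:
    # New path: slice this rank's chunk directly out of the already-replicated local tensor.
    assert x.is_contiguous()
    assert x.shape[0] % mesh.size() == 0
    local_num_tokens = x.shape[0] // mesh.size()
    offset = mesh.get_local_rank() * local_num_tokens
    return x[offset : offset + local_num_tokens]


if __name__ == "__main__":
    dist.init_process_group("gloo")
    mesh = init_device_mesh("cpu", (dist.get_world_size(),))
    torch.manual_seed(0)  # same seed on every rank, so the input is effectively replicated
    top_scores = torch.randn(16, 8)
    torch.testing.assert_close(manual_split(top_scores, mesh), dtensor_split(top_scores, mesh))
    dist.destroy_process_group()
```

When the token dimension divides evenly across the mesh, both paths return the same local chunk per rank; the manual slice simply skips constructing DTensors and dispatching the redistribute, which is the CPU-side overhead the profiler screenshots above contrast.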
