Skip to content
This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit 8b45633

Browse files
authored
[Distributed] Move global id calculation (#1150)
1 parent 7708646 commit 8b45633

File tree

1 file changed

+5
-2
lines changed

1 file changed

+5
-2
lines changed

dist_run.py

Lines changed: 5 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -384,6 +384,9 @@ def main(args):
384384
# with CUDATrackTime() as timer:
385385
first_pp_rank = 0
386386
last_pp_rank = pp_group_size - 1
387+
# Need these global ids due to the API definition of dist.send and recv
388+
first_pp_rank_global_id = dist.get_global_rank(pp_group, first_pp_rank)
389+
last_pp_rank_global_id = dist.get_global_rank(pp_group, last_pp_rank)
387390

388391
# New token generated each iteration
389392
new_token = torch.zeros(1, device=device, dtype=torch.int64)
@@ -420,13 +423,13 @@ def main(args):
420423
if pp_rank == last_pp_rank and pp_rank != first_pp_rank:
421424
dist.send(
422425
new_token,
423-
dst=dist.get_global_rank(pp_group, first_pp_rank),
426+
dst=first_pp_rank_global_id,
424427
group=pp_group,
425428
)
426429
elif pp_rank == first_pp_rank and pp_rank != last_pp_rank:
427430
dist.recv(
428431
new_token,
429-
src=dist.get_global_rank(pp_group, last_pp_rank),
432+
src=last_pp_rank_global_id,
430433
group=pp_group,
431434
)
432435

0 commit comments

Comments
 (0)