diff --git a/dist_run.py b/dist_run.py index a46dd81f4..3a8f16a0b 100644 --- a/dist_run.py +++ b/dist_run.py @@ -384,6 +384,9 @@ def main(args): # with CUDATrackTime() as timer: first_pp_rank = 0 last_pp_rank = pp_group_size - 1 + # Need these global ids due to the API definition of dist.send and recv + first_pp_rank_global_id = dist.get_global_rank(pp_group, first_pp_rank) + last_pp_rank_global_id = dist.get_global_rank(pp_group, last_pp_rank) # New token generated each iteration new_token = torch.zeros(1, device=device, dtype=torch.int64) @@ -420,13 +423,13 @@ def main(args): if pp_rank == last_pp_rank and pp_rank != first_pp_rank: dist.send( new_token, - dst=dist.get_global_rank(pp_group, first_pp_rank), + dst=first_pp_rank_global_id, group=pp_group, ) elif pp_rank == first_pp_rank and pp_rank != last_pp_rank: dist.recv( new_token, - src=dist.get_global_rank(pp_group, last_pp_rank), + src=last_pp_rank_global_id, group=pp_group, )