Skip to content
This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit dfe6fe2

Browse files
committed
use Ke's variable names for send/rcv
1 parent 8b41ae2 commit dfe6fe2

File tree

1 file changed

+5
-4
lines changed

1 file changed

+5
-4
lines changed

dist_run.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -373,8 +373,9 @@ def main(args):
373373
first_pp_rank = 0
374374
last_pp_rank = pp_group_size - 1
375375

376-
send_destination = dist.get_global_rank(pp_group, first_pp_rank)
377-
recv_source = dist.get_global_rank(pp_group, last_pp_rank)
376+
# Need these global ids due to the API definition of dist.send and recv
377+
first_pp_rank_global_id = dist.get_global_rank(pp_group, first_pp_rank)
378+
last_pp_rank_global_id = dist.get_global_rank(pp_group, last_pp_rank)
378379

379380
# encode the prompt
380381
input_ids = _encode_strings(
@@ -435,13 +436,13 @@ def main(args):
435436
if pp_rank == last_pp_rank and pp_rank != first_pp_rank:
436437
dist.send(
437438
new_token,
438-
dst=send_destination,
439+
dst=first_pp_rank_global_id,
439440
group=pp_group,
440441
)
441442
elif pp_rank == first_pp_rank and pp_rank != last_pp_rank:
442443
dist.recv(
443444
new_token,
444-
src=recv_source,
445+
src=last_pp_rank_global_id,
445446
group=pp_group,
446447
)
447448

0 commit comments

Comments
 (0)