Skip to content
This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit 5f71825

Browse files
authored
Merge branch 'main' into lessw2020/batch_decoding
2 parents 5abb444 + 8b45633 commit 5f71825

File tree

1 file changed

+5
-0
lines changed

1 file changed

+5
-0
lines changed

dist_run.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -390,6 +390,11 @@ def main(args):
390390
schedule = ScheduleGPipe(stage, mbs)
391391

392392
# with CUDATrackTime() as timer:
393+
first_pp_rank = 0
394+
last_pp_rank = pp_group_size - 1
395+
# Need these global ids due to the API definition of dist.send and recv
396+
first_pp_rank_global_id = dist.get_global_rank(pp_group, first_pp_rank)
397+
last_pp_rank_global_id = dist.get_global_rank(pp_group, last_pp_rank)
393398

394399
# New token generated each iteration
395400
total_prompts = len(prompt_lengths)

0 commit comments

Comments
 (0)