-
Notifications
You must be signed in to change notification settings - Fork 248
[Distributed] fix pp=1 case; clean up #1149
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchchat/1149
Note: Links to docs will display an error until the docs builds have been completed. ✅ You can merge normally! (1 Unrelated Failure)As of commit 8189a4e with merge base a645f8e ( FLAKY - The following job failed but was likely due to flakiness present on trunk:
This comment was automatically generated by Dr. CI and updates every 15 minutes. |
| if pp_rank == last_pp_rank and pp_rank != first_pp_rank: | ||
| dist.send( | ||
| new_token, | ||
| dst=dist.get_global_rank(pp_group, first_pp_rank), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
dst and src shouldn't be getting recreated every iter, since they don't change on a per iter basis.
This is why I had moved them out of the loop previously.
Not sure how expensive the dist.get_global_rank is but no need imo to be calling it over and over here.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Makes sense. Will upload a PR to improve it. Thanks!
| elif pp_rank == first_pp_rank and pp_rank != last_pp_rank: | ||
| dist.recv( | ||
| new_token, | ||
| src=dist.get_global_rank(pp_group, last_pp_rank), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
same as above w/dst, why do we call this api over and over to get src within the loop instead of once out of the loop?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
looks good!
Flagged one item re: why dst and scr got moved from out of loop, to in loop... seems no reason to recreate dst and src in every iter vs set once and keep them out of the loop. But, it's minor and prefer to land this now as planning to work on batch decoding today.
When there is only 1 PP stage, we can skip the sendrecv.
Plus some cleanup and renaming.