This repository was archived by the owner on Sep 10, 2025. It is now read-only.
File tree Expand file tree Collapse file tree 1 file changed +14
-15
lines changed Expand file tree Collapse file tree 1 file changed +14
-15
lines changed Original file line number Diff line number Diff line change @@ -1484,23 +1484,22 @@ def decode_one_token(
14841484 # Decode the output
14851485 if self .pp_rank == self .last_pp_rank :
14861486 new_token , _ = self .sample (logits , need_probs = need_probs , ** sampling_kwargs )
1487+ if self .pp_rank != self .first_pp_rank :
1488+ dist .send (
1489+ new_token ,
1490+ dst = self .first_pp_rank_global_id ,
1491+ group = self .pp_group ,
1492+ )
14871493 else :
14881494 new_token = torch .zeros (1 , 1 , device = self .device , dtype = torch .int64 )
1489-
1490- if self .pp_rank == self .last_pp_rank and self .pp_rank != self .first_pp_rank :
1491- dist .send (
1492- new_token ,
1493- dst = self .first_pp_rank_global_id ,
1494- group = self .pp_group ,
1495- )
1496- elif self .pp_rank == self .first_pp_rank and self .pp_rank != self .last_pp_rank :
1497- dist .recv (
1498- new_token ,
1499- src = self .last_pp_rank_global_id ,
1500- group = self .pp_group ,
1501- )
1502- #TODO: Why do we get 2d tensor here?
1503- new_token = new_token [0 ]
1495+ if self .pp_rank == self .first_pp_rank :
1496+ dist .recv (
1497+ new_token ,
1498+ src = self .last_pp_rank_global_id ,
1499+ group = self .pp_group ,
1500+ )
1501+ #TODO: Why do we get 2d tensor here?
1502+ new_token = new_token [0 ]
15041503 return new_token , None
15051504
15061505 def sample (
You can’t perform that action at this time.
0 commit comments