Skip to content
This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit 68eec0b

Browse files
committed
Refactor conditions in pp
1 parent 28d7836 commit 68eec0b

File tree

1 file changed

+14
-15
lines changed

1 file changed

+14
-15
lines changed

torchchat/generate.py

Lines changed: 14 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1484,23 +1484,22 @@ def decode_one_token(
14841484
# Decode the output
14851485
if self.pp_rank == self.last_pp_rank:
14861486
new_token, _ = self.sample(logits, need_probs=need_probs, **sampling_kwargs)
1487+
if self.pp_rank != self.first_pp_rank:
1488+
dist.send(
1489+
new_token,
1490+
dst=self.first_pp_rank_global_id,
1491+
group=self.pp_group,
1492+
)
14871493
else:
14881494
new_token = torch.zeros(1, 1, device=self.device, dtype=torch.int64)
1489-
1490-
if self.pp_rank == self.last_pp_rank and self.pp_rank != self.first_pp_rank:
1491-
dist.send(
1492-
new_token,
1493-
dst=self.first_pp_rank_global_id,
1494-
group=self.pp_group,
1495-
)
1496-
elif self.pp_rank == self.first_pp_rank and self.pp_rank != self.last_pp_rank:
1497-
dist.recv(
1498-
new_token,
1499-
src=self.last_pp_rank_global_id,
1500-
group=self.pp_group,
1501-
)
1502-
#TODO: Why do we get 2d tensor here?
1503-
new_token=new_token[0]
1495+
if self.pp_rank == self.first_pp_rank:
1496+
dist.recv(
1497+
new_token,
1498+
src=self.last_pp_rank_global_id,
1499+
group=self.pp_group,
1500+
)
1501+
#TODO: Why do we get 2d tensor here?
1502+
new_token=new_token[0]
15041503
return new_token, None
15051504

15061505
def sample(

0 commit comments

Comments
 (0)