
Commit 10fb55a

Authored by mreso and Jack-Khuu
Apply suggestions from code review
Co-authored-by: Jack-Khuu <[email protected]>
1 parent 7cb98c9 commit 10fb55a

1 file changed: +18 −24

torchchat/generate.py

Lines changed: 18 additions & 24 deletions
@@ -1287,7 +1287,7 @@ def __del__(self):
         dist.destroy_process_group()
 
     # Helper function to get example inputs and outputs for the stages.
-    def get_example_ins_outs(self, batch_size:int , seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
+    def get_example_ins_outs(self, batch_size: int , seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
         """
         This function generates example inputs and outputs for the prefill and decode stages.
 
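
For readers without the full file: get_example_ins_outs builds placeholder tensors used to seed the pipeline stages. A minimal self-contained sketch of the pattern, where the hidden dimension and vocab size are assumed placeholders (the real values come from the model config):

import torch
from typing import Tuple

# Hedged sketch of the get_example_ins_outs pattern: the first stage
# consumes token ids, interior stages pass hidden activations, and the
# last stage emits logits. `dim` and `vocab_size` are illustrative
# assumptions, not values taken from torchchat.
def example_ins_outs(
    pp_rank: int, first_pp_rank: int, last_pp_rank: int,
    batch_size: int, seqlen: int,
    dim: int = 4096, vocab_size: int = 32000,
) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]:
    tokens = torch.zeros(batch_size, seqlen, dtype=torch.int64)
    activation = torch.zeros(batch_size, seqlen, dim)
    logits = torch.zeros(batch_size, seqlen, vocab_size)
    example_inputs = (tokens if pp_rank == first_pp_rank else activation,)
    example_outputs = (logits if pp_rank == last_pp_rank else activation,)
    return example_inputs, example_outputs
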
@@ -1308,9 +1308,7 @@ def get_example_ins_outs(self, batch_size:int , seqlen: int) -> Tuple[torch.Tens
         example_outputs = (logits if self.pp_rank == self.last_pp_rank else activation,)
         return example_inputs, example_outputs
 
-    def create_prefill_stage(
-        self,
-    ):
+    def create_prefill_stage(self):
         """
         Creates a pipeline stage for prefilling.
 
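
For context, a sketch of what a method like create_prefill_stage typically wraps. This assumes the PyTorch 2.4-era torch.distributed.pipelining API, in which PipelineStage accepted example input/output tensors for shape inference (later releases infer shapes at the first step); every name besides PipelineStage and ScheduleGPipe is illustrative:

import torch
from torch.distributed.pipelining import PipelineStage, ScheduleGPipe

# Assumes an already-initialized process group. `model_shard` is this
# rank's slice of the transformer; example_inputs/example_outputs come
# from a helper like get_example_ins_outs above.
def create_prefill_stage(model_shard, pp_rank, pp_degree, device, pp_group,
                         example_inputs, example_outputs):
    stage = PipelineStage(
        model_shard,
        pp_rank,                      # stage index
        pp_degree,                    # total number of stages
        device,
        input_args=example_inputs,    # seed shape inference (2.4-era API)
        output_args=example_outputs,
        group=pp_group,
    )
    # One microbatch: prefill pushes the whole prompt through the pipe
    # once, matching ScheduleGPipe(prefill_stage, 1) in the next hunk.
    return ScheduleGPipe(stage, 1)
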
@@ -1340,9 +1338,7 @@ def create_prefill_stage(
         prefiller = ScheduleGPipe(prefill_stage, 1)
         return prefiller
 
-    def create_decode_stage(
-        self,
-    ):
+    def create_decode_stage(self):
         """
         Creates a decode stage for the pipeline parallelism.
 
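
A GPipe schedule is driven differently on each rank, which is the first/middle/last branching the prefill hunk below relies on. A hedged sketch of that calling pattern (argument names are illustrative):

# Only the first stage feeds inputs into the pipe; only the last stage
# gets logits back; middle pp ranks just relay activations, matching the
# `self.prefiller.step(**kwargs)` context line in the next hunk.
def run_prefill(prefiller, tokens, pp_rank, first_pp_rank, last_pp_rank):
    if pp_rank == first_pp_rank:
        prefiller.step(tokens)        # feed token ids into the pipe
        return None
    elif pp_rank == last_pp_rank:
        return prefiller.step()       # drain the pipe, returns logits
    else:                             # middle pp ranks
        prefiller.step()
        return None
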
@@ -1422,24 +1418,22 @@ def prefill(
         else:  # middle pp ranks
             self.prefiller.step(**kwargs)
 
-        new_token = torch.zeros(1, 1, device=self.device, dtype=torch.int64)
-
         if self.pp_rank == self.last_pp_rank:
             new_token = self.sample(logits[:,:prompt_length], need_probs=False, **sampling_kwargs)[0]
-
-
-        if self.pp_rank == self.last_pp_rank and self.pp_rank != self.first_pp_rank:
-            dist.send(
-                new_token,
-                dst=self.first_pp_rank_global_id,
-                group=self.pp_group,
-            )
-        elif self.pp_rank == self.first_pp_rank and self.pp_rank != self.last_pp_rank:
-            dist.recv(
-                new_token,
-                src=self.last_pp_rank_global_id,
-                group=self.pp_group,
-            )
+            if self.pp_rank != self.first_pp_rank:
+                dist.send(
+                    new_token,
+                    dst=self.first_pp_rank_global_id,
+                    group=self.pp_group,
+                )
+        else:
+            new_token = torch.zeros(1, 1, device=self.device, dtype=torch.int64)
+            if self.pp_rank == self.first_pp_rank:
+                dist.recv(
+                    new_token,
+                    src=self.last_pp_rank_global_id,
+                    group=self.pp_group,
+                )
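
The hunk above folds two compound rank checks into one if/else: the last rank is the only one with real logits, so it samples and (when it is not also the first rank) sends the token; every other rank allocates a placeholder, and only the first rank receives into it. A standalone sketch of that flow, assuming an initialized process group and using argmax as a stand-in for self.sample:

import torch
import torch.distributed as dist

# Hedged sketch of the restructured prefill token exchange. Rank ids and
# the group are assumed to come from the surrounding pipeline setup.
def prefill_token_exchange(pp_rank, first_pp_rank, last_pp_rank,
                           first_pp_rank_global_id, last_pp_rank_global_id,
                           pp_group, logits, device):
    if pp_rank == last_pp_rank:
        # Only the last stage holds real logits to sample from.
        new_token = torch.argmax(logits[:, -1:], dim=-1)   # stand-in for sample()
        if pp_rank != first_pp_rank:
            dist.send(new_token, dst=first_pp_rank_global_id, group=pp_group)
    else:
        # Everyone else starts from a placeholder; only the first rank
        # actually receives the sampled token.
        new_token = torch.zeros(1, 1, device=device, dtype=torch.int64)
        if pp_rank == first_pp_rank:
            dist.recv(new_token, src=last_pp_rank_global_id, group=pp_group)
    return new_token

Note that middle ranks simply return the zero placeholder, and in the single-stage case (first == last) neither send nor recv fires, which the old `and`-guards achieved more verbosely.
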

@@ -1485,7 +1479,7 @@ def decode_one_token(
 
         # Decode the output
         if self.pp_rank == self.last_pp_rank:
-            new_token, next_prob = self.sample(logits, need_probs=need_probs, **sampling_kwargs)
+            new_token, _ = self.sample(logits, need_probs=need_probs, **sampling_kwargs)
         else:
             new_token = torch.zeros(1, 1, device=self.device, dtype=torch.int64)
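
This last hunk only renames an unused binding: next_prob was never consumed here, so the second return value of sample is now discarded. A hedged sketch of the call shape, with a simplified multinomial sampler (torchchat's actual sample handles temperature, top-k, and more):

import torch

# Simplified stand-in for self.sample: returns (next_token, probs-or-None),
# matching the two-value unpacking in the hunk above. The sampling details
# are illustrative, not torchchat's exact logic.
def sample(logits: torch.Tensor, need_probs: bool, temperature: float = 1.0):
    probs = torch.softmax(logits[0, -1] / temperature, dim=-1)
    next_token = torch.multinomial(probs, num_samples=1).view(1, 1)
    return next_token, (probs if need_probs else None)

logits = torch.randn(1, 8, 32000)                  # assumed (batch, seq, vocab)
new_token, _ = sample(logits, need_probs=False)    # second value unused
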
