@@ -1287,7 +1287,7 @@ def __del__(self):
         dist.destroy_process_group()

     # Helper function to get example inputs and outputs for the stages.
-    def get_example_ins_outs(self, batch_size:int , seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
+    def get_example_ins_outs(self, batch_size: int, seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
         """
         This function generates example inputs and outputs for the prefill and decode stages.

@@ -1308,9 +1308,7 @@ def get_example_ins_outs(self, batch_size:int , seqlen: int) -> Tuple[torch.Tens
         example_outputs = (logits if self.pp_rank == self.last_pp_rank else activation,)
         return example_inputs, example_outputs

-    def create_prefill_stage(
-        self,
-    ):
+    def create_prefill_stage(self):
         """
         Creates a pipeline stage for prefilling.

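Note on `get_example_ins_outs` above: only the last pipeline rank emits logits; every other rank emits an activation tensor. A minimal sketch of the shapes involved, assuming a hypothetical hidden size `dim` and vocab size `vocab_size` (neither name comes from this diff) and assuming the first rank consumes token ids while later ranks consume activations:

import torch

# Illustrative sizes only -- the real model config supplies these.
batch_size, seqlen, dim, vocab_size = 1, 128, 4096, 32000

tokens = torch.randint(0, vocab_size, (batch_size, seqlen), dtype=torch.int64)
activation = torch.empty(batch_size, seqlen, dim, dtype=torch.float16)
logits = torch.empty(batch_size, seqlen, vocab_size, dtype=torch.float16)

# Mirrors the selection above: logits only on the last pp rank.
example_inputs = (tokens,)       # first pp rank; later ranks would use (activation,)
example_outputs = (logits,)      # last pp rank; earlier ranks would use (activation,)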
@@ -1340,9 +1338,7 @@ def create_prefill_stage(
         prefiller = ScheduleGPipe(prefill_stage, 1)
         return prefiller

-    def create_decode_stage(
-        self,
-    ):
+    def create_decode_stage(self):
         """
         Creates a decode stage for the pipeline parallelism.

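The `ScheduleGPipe(prefill_stage, 1)` call above wraps the stage in a GPipe schedule with a single microbatch, so prefill runs as one straight pass through the pipeline. A hedged sketch of how such a schedule is driven, assuming `prefill_stage` is an already-constructed pipeline stage for this rank; the `step` call pattern mirrors the prefill hunk further down, not a verified full API:

from torch.distributed.pipelining import ScheduleGPipe

# Assumption: `prefill_stage` was built elsewhere for this rank's model chunk.
prefiller = ScheduleGPipe(prefill_stage, 1)  # n_microbatches=1: no microbatch splitting

# First rank feeds inputs, last rank returns outputs, middle ranks just step:
#   first rank:  prefiller.step(tokens, **kwargs)
#   last rank:   logits = prefiller.step(**kwargs)
#   middle rank: prefiller.step(**kwargs)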
@@ -1422,24 +1418,22 @@ def prefill(
         else:  # middle pp ranks
             self.prefiller.step(**kwargs)

-        new_token = torch.zeros(1, 1, device=self.device, dtype=torch.int64)
-
         if self.pp_rank == self.last_pp_rank:
             new_token = self.sample(logits[:, :prompt_length], need_probs=False, **sampling_kwargs)[0]
-
-
-        if self.pp_rank == self.last_pp_rank and self.pp_rank != self.first_pp_rank:
-            dist.send(
-                new_token,
-                dst=self.first_pp_rank_global_id,
-                group=self.pp_group,
-            )
-        elif self.pp_rank == self.first_pp_rank and self.pp_rank != self.last_pp_rank:
-            dist.recv(
-                new_token,
-                src=self.last_pp_rank_global_id,
-                group=self.pp_group,
-            )
+            if self.pp_rank != self.first_pp_rank:
+                dist.send(
+                    new_token,
+                    dst=self.first_pp_rank_global_id,
+                    group=self.pp_group,
+                )
+        else:
+            new_token = torch.zeros(1, 1, device=self.device, dtype=torch.int64)
+            if self.pp_rank == self.first_pp_rank:
+                dist.recv(
+                    new_token,
+                    src=self.last_pp_rank_global_id,
+                    group=self.pp_group,
+                )

         return new_token

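The rewritten block above makes the token handoff explicit: only the last rank samples, and if it is not also the first rank it sends the token directly to the first rank; every other rank allocates a placeholder, and only the first rank receives into it. A self-contained sketch of that point-to-point pattern (the helper name and arguments are illustrative, not part of this PR):

import torch
import torch.distributed as dist

def handoff_new_token(pp_rank, first_pp_rank, last_pp_rank,
                      first_global_id, last_global_id, pp_group,
                      sampled=None, device="cuda"):
    # Last rank already holds the sampled token; ship it to the first rank
    # so the next decoding step can start there.
    if pp_rank == last_pp_rank:
        new_token = sampled
        if pp_rank != first_pp_rank:
            dist.send(new_token, dst=first_global_id, group=pp_group)
    else:
        # Placeholder with the same shape/dtype as the sampled token.
        new_token = torch.zeros(1, 1, device=device, dtype=torch.int64)
        if pp_rank == first_pp_rank:
            dist.recv(new_token, src=last_global_id, group=pp_group)
    return new_token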
@@ -1485,7 +1479,7 @@ def decode_one_token(

         # Decode the output
         if self.pp_rank == self.last_pp_rank:
-            new_token, next_prob = self.sample(logits, need_probs=need_probs, **sampling_kwargs)
+            new_token, _ = self.sample(logits, need_probs=need_probs, **sampling_kwargs)
         else:
             new_token = torch.zeros(1, 1, device=self.device, dtype=torch.int64)

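In this decode path the second value returned by `sample` is unused, hence the `_`. As a rough stand-in for what `sample` does on the last rank, a minimal greedy variant (whatever temperature or top-k handling the real `sample` performs is omitted here):

import torch

def greedy_sample(logits: torch.Tensor) -> torch.Tensor:
    # logits: (batch, seqlen, vocab); pick the argmax token at the last position,
    # shaped (batch, 1) to match the zero placeholder used on the other ranks.
    return logits[:, -1, :].argmax(dim=-1, keepdim=True)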