@@ -226,7 +226,7 @@ def _batch_decode_next_tokens_new(


 def _batch_decode_next_tokens(
-    output: torch.Tensor, pos: int, step: int = -1
+    output: torch.Tensor, pos: List[int], step: int = -1, temperature: float = 1.0
 ) -> torch.Tensor:
     """
     Decode the next token for each prompt in the batch.
@@ -239,19 +239,29 @@ def _batch_decode_next_tokens(
239239 """
240240 # Take the next token logits for each prompt
241241 res = []
242- logger .info (f"{ color .green } output shape = { output .shape } { color .reset } " )
243- logger .info (f"{ color .green } pos = { pos } { color .reset } " )
244- for i in range (output .shape [0 ]):
245- token_pos = 0 if step != - 1 else pos [i ] - 1
246- next_token_logits = output [i , token_pos , :]
242+ # logger.info(f"{color.green}output shape = {output.shape}{color.reset}")
243+ # logger.info(f"{color.green}pos = {pos}{color.reset}")
244+ batch_size , seq_len , vocab_size = output .shape
247245
248- # Argmax (deterministic) TODO: add temperature
246+ if step != - 1 :
247+ next_token_logits = output [:, 0 , :]
249248 next_token = torch .argmax (next_token_logits , dim = - 1 )
250- logger .info (f"{ color .blue } next_token = { next_token } { color .reset } " )
251249 res .append (next_token )
252- # Token ids in int tensor form
253- res = torch .stack (res , dim = 0 )
254- logger .info (f"{ color .green } next_token = { res } { color .reset } " )
250+ res = torch .stack (res , dim = 0 )
251+ res = res .squeeze (0 )
252+ else :
253+ for i in range (batch_size ):
254+ token_pos = pos [i ] - 1
255+ next_token_logits = output [i , token_pos , :]
256+
257+ # Argmax (deterministic) TODO: add temperature
258+ next_token = torch .argmax (next_token_logits , dim = - 1 )
259+ # logger.info(f"{color.blue}next_token = {next_token}{color.reset}")
260+ res .append (next_token )
261+ # Token ids in int tensor form
262+ res = torch .stack (res , dim = 0 )
263+
264+ logger .info (f"{ color .yellow } next_token = { res } { color .reset } " )
255265 return res # next_token
256266
257267
@@ -340,7 +350,7 @@ def main(args):
     # Batch size. Since we push batches dynamically through the pipeline rather
     # than chunking them, this is effectively micro-batch size in pipeline
     # sense. Thus it is interchangeable with micro-batch size below.
-    batch_size = 2
+    batch_size = 3
     seqlen_prefill = 1024  # sequence length
     dim = 4096  # embedding dimension

@@ -414,7 +424,7 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
     prompt = [
         "What is Snow?",
         "Who is Santa Claus?",
-        # "Where does Santa live?",
+        "Where does Santa live?",
         # "Who is Abraham Lincoln?",
         # "How are models trained?",
     ]
@@ -455,7 +465,6 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
     # Run context input through pipeline
     # TODO: we need to pass `input_pos` and `cache_lane` to each stage.
     lane = 0
-    logger.info(f"{color.green}Prefilling...{input_pos=}{color.reset}")
     kwargs = {"input_pos": input_pos, "cache_lane": lane}
     with torch.no_grad(), CUDATrackTime() as timer:
         if pp_rank == first_pp_rank:
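
For readers skimming the diff, below is a minimal standalone sketch of the token-selection behavior that the updated _batch_decode_next_tokens implements: a vectorized argmax at position 0 during decode steps, and a per-prompt argmax at each prompt's last real token (pos[i] - 1) during prefill. The function name batch_decode_next_tokens_sketch, the tensor shapes in the demo, and the omission of the logger/color calls and the still-unused temperature argument are illustrative assumptions, not part of the commit.

# Hypothetical mirror of the selection logic in the diff above; logging and
# the unused `temperature` argument are intentionally left out.
from typing import List

import torch


def batch_decode_next_tokens_sketch(
    output: torch.Tensor, pos: List[int], step: int = -1
) -> torch.Tensor:
    """Pick one next-token id per prompt from [batch, seq, vocab] logits."""
    batch_size, _seq_len, _vocab_size = output.shape

    if step != -1:
        # Decode step: each row holds exactly one new position, so the
        # relevant logits sit at index 0 for every prompt at once.
        return torch.argmax(output[:, 0, :], dim=-1)

    # Prefill: prompts have different lengths, so take each prompt's logits
    # at its last real token (pos[i] - 1) and stack the per-prompt argmaxes.
    next_tokens = [
        torch.argmax(output[i, pos[i] - 1, :], dim=-1) for i in range(batch_size)
    ]
    return torch.stack(next_tokens, dim=0)


if __name__ == "__main__":
    logits = torch.randn(3, 16, 128)  # illustrative [batch, seq, vocab] shape
    print(batch_decode_next_tokens_sketch(logits, pos=[5, 9, 12]))          # prefill path
    print(batch_decode_next_tokens_sketch(logits, pos=[5, 9, 12], step=0))  # decode path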