@@ -223,7 +223,7 @@ def _update_padded_sequence(
     for i in range(len(prompt_lengths)):
         prompt_lengths[i] += 1
         padded_sequence[i, prompt_lengths[i] - 1] = new_token[i, 0]
-        logger.info(f"updated prompt {i} with new token {new_token[i, 0]}")
+        # logger.info(f"updated prompt {i} with new token {new_token[i, 0]}")


 def _cleanup():
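For context on the hunk above: _update_padded_sequence advances each prompt by one slot in a fixed-size token buffer and writes the freshly decoded token into that slot; the log line it silences fired once per prompt per step. A minimal, self-contained sketch of the same update, with made-up shapes and token ids rather than the example's real values:

import torch

padded_sequence = torch.zeros(2, 8, dtype=torch.int64)  # (batch, seqlen) token buffer
prompt_lengths = [3, 5]                                  # tokens filled so far, per prompt
new_token = torch.tensor([[11], [22]])                   # (batch, 1) freshly decoded tokens

for i in range(len(prompt_lengths)):
    prompt_lengths[i] += 1                               # each prompt grows by one token
    padded_sequence[i, prompt_lengths[i] - 1] = new_token[i, 0]

print(padded_sequence[0, :5])  # tensor([ 0,  0,  0, 11,  0])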
@@ -296,7 +296,7 @@ def main(args):
     logger.info(f"Model: {model}")

     mbs = 1  # number of micro-batches
-    mb_size = 2  # micro-batch size
+    mb_size = 5  # micro-batch size
     batch_size = mbs * mb_size  # total batch size

     seqlen = 4096  # sequence length
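The micro-batch bump is coupled to the prompt list grown in the next hunk: with mbs = 1 and mb_size = 5, batch_size = 1 * 5 = 5, one sequence per prompt, so the GPipe schedule runs a single micro-batch of five sequences per step. A hypothetical guard (not in the diff) that would make the coupling explicit:

mbs, mb_size = 1, 5
batch_size = mbs * mb_size
prompt = ["p1", "p2", "p3", "p4", "p5"]  # stand-in for the real prompt list
assert len(prompt) == batch_size, f"{len(prompt)} prompts != batch_size {batch_size}"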
@@ -345,6 +345,9 @@ def main(args):
     prompt = [
         "What is snow?",
         "Where does Santa Claus live?",
+        "What is PyTorch?",
+        "Write a poem about the beauty of the night sky.",
+        "What is the capital of France, Germany and Switzerland?",
     ]

     """
@@ -379,13 +382,11 @@ def main(args):
     input_ids = _encode_strings(
         prompt, tokenizer, bos=True, device=device, dtype=torch.int64
     )
-    logger.info(f"{input_ids[0][0:8]=}")

     # create a padded tensor for the input prompt
     padded_sequence, prompt_lengths = _create_padded_prompts(
         input_ids, tokenizer, seqlen, start_pos, device
     )
-    logger.info(f"length of each prompt in the batch: {prompt_lengths=}")

     # create schedule
     schedule = ScheduleGPipe(stage, mbs)
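The two deleted logger.info calls above were per-run debug output; the helpers themselves are unchanged. As a rough, hypothetical stand-in for what _create_padded_prompts produces (the real helper also takes a tokenizer and start_pos): left-align each prompt's ids in a (batch, seqlen) buffer and return the lengths the decode loop keeps advancing.

import torch

def create_padded_prompts(input_ids, seqlen, pad_id=0, device="cpu"):
    lengths = [len(ids) for ids in input_ids]
    padded = torch.full((len(input_ids), seqlen), pad_id, dtype=torch.int64, device=device)
    for i, ids in enumerate(input_ids):
        padded[i, : lengths[i]] = torch.tensor(ids, dtype=torch.int64, device=device)
    return padded, lengths

padded, lengths = create_padded_prompts([[1, 5, 9], [1, 7]], seqlen=8)
print(lengths)  # [3, 2]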
@@ -397,7 +398,7 @@ def main(args):
     # need a new token dimension (row) for each prompt in the batch
     new_token = torch.zeros(total_prompts, 1, device=device, dtype=torch.int64)
     res = [[] for _ in range(total_prompts)]
-    num_tokens = 20
+    num_tokens = 40

     # Decoding
     with torch.no_grad():
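num_tokens = 40 doubles the decode budget: the loop under torch.no_grad() runs 40 steps, each filling the (total_prompts, 1) new_token tensor with one token per prompt. A hedged sketch of a single greedy step, with toy dimensions and random logits standing in for the last stage's output:

import torch

batch, seqlen, vocab = 5, 16, 100                    # toy sizes, not the example's
logits = torch.randn(batch, seqlen, vocab)           # stand-in for model output
prompt_lengths = torch.tensor([4, 6, 3, 9, 11])      # current length of each prompt
last_pos = prompt_lengths - 1                        # index of each prompt's final token
next_logits = logits[torch.arange(batch), last_pos]  # (batch, vocab)
new_token = next_logits.argmax(dim=-1, keepdim=True) # (batch, 1)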
@@ -449,18 +450,17 @@ def main(args):
         # Update input sequence with new token
         if pp_rank == first_pp_rank:
             _update_padded_sequence(padded_sequence, new_token, prompt_lengths)
-            for i in range(len(prompt_lengths)):
-                logger.info(
-                    f"next submission: {padded_sequence[i, prompt_lengths[i]-4:prompt_lengths[i]+4]}"
-                )
+
+    # Display the decoding results

     # output formatted response via last pp group and tp rank 0
     if pp_rank == last_pp_rank and tp_rank == 0:
         for i in range(len(prompt_lengths)):
-            logger.info(f"Prompt:{color.green}{prompt[i]}{color.reset}")
+            logger.info(f"\nPrompt:{color.green}{prompt[i]}{color.reset}")
             formatted_response = "".join(res[i])
-            logger.info(f"$${color.red}{formatted_response}{color.reset}$$")
+            logger.info(f"$${color.red}{formatted_response}{color.reset}$$\n")

+    # Cleanup
     logger.info(
         f"{color.green}Success{color.white} - {color.blue}Rank {rank} has completed.{color.reset}"
     )
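On the output side: each decode step appends one decoded piece per prompt to res[i], and the "".join at the end is exactly what the colored log lines print; the added \n characters just separate prompt/response pairs in the log. A toy of that accumulation pattern, with stand-in strings in place of tokenizer output:

res = [[] for _ in range(2)]                   # one list per prompt
for step in range(3):                          # stands in for the num_tokens loop
    pieces = [f" tok{step}a", f" tok{step}b"]  # stand-ins for decoded tokens
    for i, piece in enumerate(pieces):
        res[i].append(piece)

formatted_response = "".join(res[0])           # " tok0a tok1a tok2a"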