     get_num_params,
     GPUMemoryMonitor,
 )
-from distributed.verification_utils import find_cpu_tensors
 from torch.distributed.pipelining import PipelineStage, ScheduleGPipe
 from torchchat.cli.builder import _initialize_tokenizer, TokenizerArgs
 from torchchat.model import ModelArgs, Transformer
@@ -219,10 +218,9 @@ def _update_padded_sequence(
     new_token: torch.Tensor,
     prompt_lengths: List[int],
 ) -> None:
-    # TODO: this is a hacky way to update the padded sequence: when there is
-    # more than one prompt, the for loop and the assignment is incompatible.
     for i in range(len(prompt_lengths)):
-        padded_sequence[i, prompt_lengths[i]] = new_token
+        padded_sequence[i, prompt_lengths[i]] = new_token[i, 0]
+        # logger.info(f"updated prompt {i} with new token {new_token[i, 0]}")
 
 
 def _cleanup():
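The rewrite above makes `_update_padded_sequence` index `new_token` per prompt, which only works if `new_token` carries one row per prompt. A minimal sketch of that indexing, with made-up shapes and token ids purely for illustration:

```python
import torch

# Illustrative values only: one row of padded_sequence per prompt, and one
# freshly sampled token per prompt in new_token (shape [batch, 1]).
padded_sequence = torch.zeros(3, 16, dtype=torch.int64)
prompt_lengths = [4, 7, 5]                        # current length of each prompt
new_token = torch.tensor([[101], [202], [303]])   # one token id per prompt

for i in range(len(prompt_lengths)):
    # write prompt i's new token at that prompt's own end position
    padded_sequence[i, prompt_lengths[i]] = new_token[i, 0]
```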
@@ -242,7 +240,7 @@ def main(args):
     distribution, model_dtype = NAME_TO_DISTRIBUTION_AND_DTYPE[model_name]
     logger.info(f"Using HF model weights from {distribution} and dtype {model_dtype}")
 
-    config = ModelArgs.from_name(distribution).transformer_args['text']
+    config = ModelArgs.from_name(distribution).transformer_args["text"]
     logger.info(f"Chat Model Config: {config}")
 
     tokenizer = _build_chat_tokenizer(model_name)
@@ -295,7 +293,7 @@ def main(args):
     logger.info(f"Model: {model}")
 
     mbs = 1  # number of micro-batches
-    mb_size = 1  # micro-batch size
+    mb_size = 5  # micro-batch size
     batch_size = mbs * mb_size  # total batch size
 
     seqlen = 4096  # sequence length
@@ -343,6 +341,10 @@ def main(args):
 
     prompt = [
         "What is snow?",
+        "Where does Santa Claus live?",
+        "What is PyTorch?",
+        "Write a poem about the beauty of the night sky.",
+        "What is the capital of France, Germany and Switzerland?",
     ]
 
     """
@@ -366,28 +368,39 @@ def main(args):
 
     start_pos = 0
 
+    # pipeline comms setup
+    first_pp_rank = 0
+    last_pp_rank = pp_group_size - 1
+
+    # Need these global ids due to the API definition of dist.send and recv
+    first_pp_rank_global_id = dist.get_global_rank(pp_group, first_pp_rank)
+    last_pp_rank_global_id = dist.get_global_rank(pp_group, last_pp_rank)
+
     # encode the prompt
     input_ids = _encode_strings(
         prompt, tokenizer, bos=True, device=device, dtype=torch.int64
     )
-    logger.info(f"{input_ids[0:8]=}")
 
     # create a padded tensor for the input prompt
     padded_sequence, prompt_lengths = _create_padded_prompts(
         input_ids, tokenizer, seqlen, start_pos, device
     )
-    logger.info(f"{prompt_lengths=}")
 
     # create schedule
     schedule = ScheduleGPipe(stage, mbs)
 
     # with CUDATrackTime() as timer:
     first_pp_rank = 0
     last_pp_rank = pp_group_size - 1
+    # Need these global ids due to the API definition of dist.send and recv
+    first_pp_rank_global_id = dist.get_global_rank(pp_group, first_pp_rank)
+    last_pp_rank_global_id = dist.get_global_rank(pp_group, last_pp_rank)
 
     # New token generated each iteration
-    new_token = torch.zeros(1, device=device, dtype=torch.int64)
-    res = []
+    total_prompts = len(prompt_lengths)
+    # need a new token dimension (row) for each prompt in the batch
+    new_token = torch.zeros(total_prompts, 1, device=device, dtype=torch.int64)
+    res = [[] for _ in range(total_prompts)]
     num_tokens = 40
 
     # Decoding
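The `*_global_id` variables introduced above exist because `dist.send` and `dist.recv` address peers by global rank even when a sub-group is passed. A sketch of the pattern, under the assumption that `init_process_group` has already run and `pp_group` was created earlier in the script (the helper name is hypothetical):

```python
import torch.distributed as dist

def exchange_new_token(new_token, pp_group, pp_rank, pp_group_size):
    """Hypothetical helper mirroring the send/recv pattern in this diff."""
    first_pp_rank = 0
    last_pp_rank = pp_group_size - 1
    # dist.send/recv expect global ranks, so map the group-local ranks first.
    first_pp_rank_global_id = dist.get_global_rank(pp_group, first_pp_rank)
    last_pp_rank_global_id = dist.get_global_rank(pp_group, last_pp_rank)

    if pp_rank == last_pp_rank and pp_rank != first_pp_rank:
        # the last stage sampled the token; ship it back to the first stage
        dist.send(new_token, dst=first_pp_rank_global_id, group=pp_group)
    elif pp_rank == first_pp_rank and pp_rank != last_pp_rank:
        # the first stage needs the token to extend the padded input sequence
        dist.recv(new_token, src=last_pp_rank_global_id, group=pp_group)
```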
@@ -412,40 +425,50 @@ def main(args):
                 f"responses ====>>>> {color.blue} {decode_results=}{color.reset}"
             )
             # decode results returns both token_id (int) and token_str (readable), hence [0] and [1]
-            new_token = torch.tensor([decode_results[0][0]], device=device)
-            res.append(decode_results[0][1])
+            for i in range(len(decode_results)):
+                res[i].append(decode_results[i][1])
+                new_token[i, 0] = torch.tensor(
+                    [decode_results[i][0]], device=device
+                )  # decode_results[i][0]
 
         # sendrecv between last and first ranks, only if:
         # first_pp_rank != last_pp_rank.
         if pp_rank == last_pp_rank and pp_rank != first_pp_rank:
             dist.send(
                 new_token,
-                dst=dist.get_global_rank(pp_group, first_pp_rank),
+                dst=first_pp_rank_global_id,
                 group=pp_group,
             )
         elif pp_rank == first_pp_rank and pp_rank != last_pp_rank:
             dist.recv(
                 new_token,
-                src=dist.get_global_rank(pp_group, last_pp_rank),
+                src=last_pp_rank_global_id,
                 group=pp_group,
             )
 
         # Update input sequence with new token
         if pp_rank == first_pp_rank:
-            _update_padded_sequence(
-                padded_sequence, new_token, prompt_lengths
-            )
+            _update_padded_sequence(padded_sequence, new_token, prompt_lengths)
 
         # increment prompt lengths for next token
         for i in range(len(prompt_lengths)):
             prompt_lengths[i] += 1
 
+    # Display the decoding results
+
     # output formatted response via last pp group and tp rank 0
     if pp_rank == last_pp_rank and tp_rank == 0:
-        logger.info(f"Prompt:{color.green}{prompt[0]} {color.reset}")
-        formatted_response = " ".join(res)
-        logger.info(f"$$$$$$ {color.blue}{formatted_response} {color.reset} $$$$$")
+        for i in range(len(prompt_lengths)):
+            logger.info(f"\nPrompt:{color.green}{prompt[i]} {color.reset}")
 
+            # TODO: resolve issue with llama2-7b-chat model and "".join
+            if model_name != "llama2-7b-chat":
+                formatted_response = "".join(res[i])
+            else:
+                formatted_response = " ".join(res[i])
+            logger.info(f"$$ {color.red}{formatted_response} {color.reset} $$\n")
+
+    # Cleanup
     logger.info(
         f"{color.green}Success{color.white} - {color.blue}Rank {rank} has completed.{color.reset}"
     )
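The per-prompt bookkeeping in the loop above boils down to appending each prompt's decoded token string to its own list and joining per prompt at the end. A tiny sketch with hypothetical `(token_id, token_str)` pairs standing in for `decode_results`:

```python
# Hypothetical single decoding step, mirroring the per-prompt bookkeeping above.
total_prompts = 3
res = [[] for _ in range(total_prompts)]                      # token strings per prompt
decode_results = [(42, " Snow"), (77, " Santa"), (9, " Py")]  # (token_id, token_str)

for i, (token_id, token_str) in enumerate(decode_results):
    res[i].append(token_str)

# after the generation loop, each prompt's response is assembled independently
responses = ["".join(tokens) for tokens in res]
```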
@@ -454,7 +477,12 @@ def main(args):
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument("model_name", type=str, help="Name of the model to load", choices=NAME_TO_DISTRIBUTION_AND_DTYPE.keys())
+    parser.add_argument(
+        "model_name",
+        type=str,
+        help="Name of the model to load",
+        choices=NAME_TO_DISTRIBUTION_AND_DTYPE.keys(),
+    )
     parser.add_argument("--pp", type=int, default=1, help="Pipeline parallel degree")
     args = parser.parse_args()
 