@@ -219,10 +219,11 @@ def _update_padded_sequence(
     new_token: torch.Tensor,
     prompt_lengths: List[int],
 ) -> None:
-    # TODO: this is a hacky way to update the padded sequence: when there is
-    # more than one prompt, the for loop and the assignment is incompatible.
+
     for i in range(len(prompt_lengths)):
-        padded_sequence[i, prompt_lengths[i]] = new_token
+        prompt_lengths[i] += 1
+        padded_sequence[i, prompt_lengths[i] - 1] = new_token[i, 0]
+        logger.info(f"updated prompt {i} with new token {new_token[i, 0]}")
 
 
 def _cleanup():
@@ -242,7 +243,7 @@ def main(args):
     distribution, model_dtype = NAME_TO_DISTRIBUTION_AND_DTYPE[model_name]
     logger.info(f"Using HF model weights from {distribution} and dtype {model_dtype}")
 
-    config = ModelArgs.from_name(distribution).transformer_args['text']
+    config = ModelArgs.from_name(distribution).transformer_args["text"]
     logger.info(f"Chat Model Config: {config}")
 
     tokenizer = _build_chat_tokenizer(model_name)
@@ -295,7 +296,7 @@ def main(args):
         logger.info(f"Model: {model}")
 
     mbs = 1  # number of micro-batches
-    mb_size = 1  # micro-batch size
+    mb_size = 2  # micro-batch size
     batch_size = mbs * mb_size  # total batch size
 
     seqlen = 4096  # sequence length
@@ -343,6 +344,7 @@ def main(args):
 
     prompt = [
         "What is snow?",
+        "Where does Santa Claus live?",
     ]
 
     """
@@ -366,29 +368,36 @@ def main(args):
 
     start_pos = 0
 
+    # pipeline comms setup
+    first_pp_rank = 0
+    last_pp_rank = pp_group_size - 1
+
+    send_destination = dist.get_global_rank(pp_group, first_pp_rank)
+    recv_source = dist.get_global_rank(pp_group, last_pp_rank)
+
     # encode the prompt
     input_ids = _encode_strings(
         prompt, tokenizer, bos=True, device=device, dtype=torch.int64
     )
-    logger.info(f"{input_ids[0:8]=}")
+    logger.info(f"{input_ids[0][0:8]=}")
 
     # create a padded tensor for the input prompt
     padded_sequence, prompt_lengths = _create_padded_prompts(
         input_ids, tokenizer, seqlen, start_pos, device
     )
-    logger.info(f"{prompt_lengths=}")
+    logger.info(f"length of each prompt in the batch: {prompt_lengths=}")
 
     # create schedule
     schedule = ScheduleGPipe(stage, mbs)
 
     # with CUDATrackTime() as timer:
-    first_pp_rank = 0
-    last_pp_rank = pp_group_size - 1
 
     # New token generated each iteration
-    new_token = torch.zeros(1, device=device, dtype=torch.int64)
-    res = []
-    num_tokens = 40
+    total_prompts = len(prompt_lengths)
+    # need a new token dimension (row) for each prompt in the batch
+    new_token = torch.zeros(total_prompts, 1, device=device, dtype=torch.int64)
+    res = [[] for _ in range(total_prompts)]
+    num_tokens = 20
 
     # Decoding
     with torch.no_grad():
@@ -412,39 +421,45 @@ def main(args):
                         f"responses ====>>>> {color.blue} {decode_results=}{color.reset}"
                     )
                 # decode results returns both token_id (int) and token_str (readable), hence [0] and [1]
-                new_token = torch.tensor([decode_results[0][0]], device=device)
-                res.append(decode_results[0][1])
+                for i in range(len(decode_results)):
+                    res[i].append(decode_results[i][1])
+                    new_token[i, 0] = torch.tensor(
+                        [decode_results[i][0]], device=device
+                    )  # decode_results[i][0]
+
+                # increment prompt lengths for next token
+                for i in range(len(prompt_lengths)):
+                    prompt_lengths[i] += 1
 
             # sendrecv between last and first ranks, only if:
             # first_pp_rank != last_pp_rank.
             if pp_rank == last_pp_rank and pp_rank != first_pp_rank:
                 dist.send(
                     new_token,
-                    dst=dist.get_global_rank(pp_group, first_pp_rank),
+                    dst=send_destination,
                     group=pp_group,
                 )
             elif pp_rank == first_pp_rank and pp_rank != last_pp_rank:
                 dist.recv(
                     new_token,
-                    src=dist.get_global_rank(pp_group, last_pp_rank),
+                    src=recv_source,
                     group=pp_group,
                 )
 
             # Update input sequence with new token
             if pp_rank == first_pp_rank:
-                _update_padded_sequence(
-                    padded_sequence, new_token, prompt_lengths
-                )
-
-            # increment prompt lengths for next token
-            for i in range(len(prompt_lengths)):
-                prompt_lengths[i] += 1
+                _update_padded_sequence(padded_sequence, new_token, prompt_lengths)
+                for i in range(len(prompt_lengths)):
+                    logger.info(
+                        f"next submission: {padded_sequence[i, prompt_lengths[i]-4:prompt_lengths[i]+4]}"
+                    )
 
     # output formatted response via last pp group and tp rank 0
     if pp_rank == last_pp_rank and tp_rank == 0:
-        logger.info(f"Prompt:{color.green} {prompt[0]} {color.reset}")
-        formatted_response = " ".join(res)
-        logger.info(f"$$$$$$ {color.blue}{formatted_response} {color.reset} $$$$$")
+        for i in range(len(prompt_lengths)):
+            logger.info(f"Prompt:{color.green} {prompt[i]} {color.reset}")
+            formatted_response = "".join(res[i])
+            logger.info(f"$$ {color.red}{formatted_response} {color.reset} $$")
 
     logger.info(
         f"{color.green}Success{color.white} - {color.blue}Rank {rank} has completed.{color.reset}"
@@ -454,7 +469,12 @@ def main(args):
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument("model_name", type=str, help="Name of the model to load", choices=NAME_TO_DISTRIBUTION_AND_DTYPE.keys())
+    parser.add_argument(
+        "model_name",
+        type=str,
+        help="Name of the model to load",
+        choices=NAME_TO_DISTRIBUTION_AND_DTYPE.keys(),
+    )
     parser.add_argument("--pp", type=int, default=1, help="Pipeline parallel degree")
     args = parser.parse_args()
 
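A minimal standalone sketch of the per-prompt sequence update that _update_padded_sequence performs in the first hunk above: each prompt row advances its recorded length by one and writes its newly decoded token into the slot just past the previous prompt end. Shapes and values below are illustrative only, assuming a [batch, seqlen] padded sequence and a [batch, 1] new-token tensor as in the diff.

import torch
from typing import List

def _update_padded_sequence(
    padded_sequence: torch.Tensor,
    new_token: torch.Tensor,
    prompt_lengths: List[int],
) -> None:
    for i in range(len(prompt_lengths)):
        prompt_lengths[i] += 1
        # write this prompt's new token into the first slot after its current text
        padded_sequence[i, prompt_lengths[i] - 1] = new_token[i, 0]

# toy batch of two prompts with lengths 3 and 5, padded to seqlen 8 with zeros
padded = torch.zeros(2, 8, dtype=torch.int64)
padded[0, :3] = torch.tensor([11, 12, 13])
padded[1, :5] = torch.tensor([21, 22, 23, 24, 25])
lengths = [3, 5]
new = torch.tensor([[99], [88]], dtype=torch.int64)  # one decoded token per prompt

_update_padded_sequence(padded, new, lengths)
# now padded[0, 3] == 99, padded[1, 5] == 88, and lengths == [4, 6]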