@@ -187,8 +187,8 @@ def _create_padded_prompts(

 def _batch_decode_next_tokens(
     output: torch.Tensor,
-    prompt_lengths: List[int],
     tokenizer,
+    prompt_lengths: Optional[List[int]] = None,
 ) -> List[Tuple[int, str]]:
     """
     Decode the next token for each prompt in the batch.
@@ -201,7 +201,8 @@ def _batch_decode_next_tokens(
     results = []

     for i in range(batch_size):
-        next_token_logits = output[i, prompt_lengths[i] - 1, :]
+        pos = prompt_lengths[i] - 1 if prompt_lengths is not None else 0
+        next_token_logits = output[i, pos, :]

         # Argmax (deterministic) TODO: add temperature
         next_token = torch.argmax(next_token_logits, dim=-1)
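
With prompt_lengths now optional, this helper serves both phases: during prefill the logits for prompt i sit at position prompt_lengths[i] - 1 of the padded sequence, while during single-token decode the output has sequence length 1, so position 0 is the only valid index. A minimal standalone sketch of that indexing, using dummy logits rather than real model output:

import torch

def greedy_next_tokens(output, prompt_lengths=None):
    # output: [batch, seqlen, vocab]. Pick the logits of the last real
    # prompt token (prefill) or of the single decoded position (decode).
    next_tokens = []
    for i in range(output.size(0)):
        pos = prompt_lengths[i] - 1 if prompt_lengths is not None else 0
        next_tokens.append(torch.argmax(output[i, pos, :], dim=-1).item())
    return next_tokens

# Prefill: 2 prompts of lengths 3 and 5, padded to seqlen 8, vocab size 16.
print(greedy_next_tokens(torch.rand(2, 8, 16), prompt_lengths=[3, 5]))
# Decode: seqlen is 1, so position 0 holds the new token's logits.
print(greedy_next_tokens(torch.rand(2, 1, 16)))
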
@@ -276,6 +277,10 @@ def main(args):
     tp_group_size = tp_group.size()
     logger.info(f"{pp_group_size=} {tp_group_size=}")

+    # Convenience variables
+    first_pp_rank = 0
+    last_pp_rank = pp_group_size - 1
+
     # Assuming same number of GPUs per node
     device = torch.device(f"cuda:{rank % torch.cuda.device_count()}")
@@ -293,29 +298,23 @@ def main(args):
     logger.info(f"Model: {model}")

     mbs = 1  # number of micro-batches
-    mb_size = 5  # micro-batch size
+    mb_size = 4  # micro-batch size
     batch_size = mbs * mb_size  # total batch size

-    seqlen = 4096  # sequence length
+    seqlen_prefill = 1024  # sequence length
     dim = 4096  # embedding dimension

     # Setup KV caches (after model distribution)
     # TODO: the setting below only works for 1 micro-batch case. To support
     # multiple micro-batches, we need the KV cache in the model to be aware of
     # the number of micro-batches and the current micro-batch index.
-    model.setup_caches(mb_size, seqlen)
-
-    mb_ids = torch.randint(0, config.vocab_size, (mb_size, seqlen), device=device)
-    activation = torch.rand(
-        mb_size, seqlen, dim, device=device, dtype=model_dtype
-    )
-    example_args = mb_ids if pp_rank == 0 else activation
+    model.setup_caches(mb_size, seqlen_prefill)

     # Load weights
     logger.info(f"Loading weights for {pp_rank=} {device=}")
-
     with CUDATrackTime() as timer:
         _load_model_weights(model, distribution, device=device, model_config=config)
+        model.to(device)

     logger.info(
         f"{color.green}{timer.get_time()} {timer.unit} {rank}{color.reset}"
@@ -330,53 +329,47 @@ def main(args):
     )

     # Setup input position (input_pos) for prefill: a list of increasing integers from 0 to seqlen
-    input_pos = torch.arange(seqlen, device=device)
+    input_pos = torch.arange(seqlen_prefill, device=device)
     model.setup_input_pos(input_pos)
     model.eval()

-    logger.info(f"Creating pipeline stage {pp_rank=} {pp_degree=}")
-    stage = PipelineStage(
+    # Helper function to get example inputs and outputs for the stages.
+    def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
+        mb_ids = torch.randint(0, config.vocab_size, (mb_size, seqlen), device=device)
+        activation = torch.rand(
+            mb_size, seqlen, dim, device=device, dtype=model_dtype
+        )
+        logits = torch.rand(
+            mb_size, seqlen, config.vocab_size, device=device, dtype=model_dtype
+        )
+        example_inputs = (mb_ids if pp_rank == first_pp_rank else activation,)
+        example_outputs = (logits if pp_rank == last_pp_rank else activation,)
+        return example_inputs, example_outputs
+
+    # Create prefill stage
+    logger.info(f"Creating pipeline stage for prefill {pp_rank=} {pp_degree=}")
+    example_inputs, example_outputs = get_example_ins_outs(seqlen_prefill)
+    prefill_stage = PipelineStage(
         model,
         pp_rank,
         pp_degree,
         device,
-        input_args=(example_args,),
+        input_args=example_inputs,
+        output_args=example_outputs,
         group=pp_group,
     )
+    # create schedule
+    prefill_schedule = ScheduleGPipe(prefill_stage, mbs)

     prompt = [
-        "What is snow?",
-        "Where does Santa Claus live?",
-        "What is PyTorch?",
-        "Write a poem about the beauty of the night sky.",
-        "What is the capital of France, Germany and Switzerland?",
-    ]
-
-    """
-    "What is the capital of France?",
-        "What is your name?",
-        "What is the capital of Japan?",
-        "When is Christmas?",
-        "Where does Santa Claus live?",
-        "What is the capital of the United States?",
-        "What is the capital of China?",
-        "What is the capital of Russia?",
-        "What is PyTorch?",
-        "What is the capital of India?",
-        "What is an LLM?",
-        "What is the capital of Brazil?",
-        "What is the capital of Mexico?",
-        "What is the capital of Argentina?",
-        "What is the capital of Canada?",
+        "What is a computer?",
+        "Where does Santa live?",
+        "Who is Abraham Lincoln?",
+        "How are models trained?",
     ]
-    """

     start_pos = 0

-    # pipeline comms setup
-    first_pp_rank = 0
-    last_pp_rank = pp_group_size - 1
-
     # Need these global ids due to the API definition of dist.send and recv
     first_pp_rank_global_id = dist.get_global_rank(pp_group, first_pp_rank)
     last_pp_rank_global_id = dist.get_global_rank(pp_group, last_pp_rank)
@@ -388,15 +381,14 @@ def main(args):

     # create a padded tensor for the input prompt
     padded_sequence, prompt_lengths = _create_padded_prompts(
-        input_ids, tokenizer, seqlen, start_pos, device
+        input_ids, tokenizer, seqlen_prefill, start_pos, device
     )
-
-    # create schedule
-    schedule = ScheduleGPipe(stage, mbs)
+    # TODO: figure out how to set input_pos for each prompt in the batch then we
+    # can remove this limitation.
+    s = set(prompt_lengths)
+    assert len(s) == 1, f"prompt_lengths should be the same, got {s}"

     # with CUDATrackTime() as timer:
-    first_pp_rank = 0
-    last_pp_rank = pp_group_size - 1
     # Need these global ids due to the API definition of dist.send and recv
     first_pp_rank_global_id = dist.get_global_rank(pp_group, first_pp_rank)
     last_pp_rank_global_id = dist.get_global_rank(pp_group, last_pp_rank)
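
The new assert exists because input_pos is a single position tensor shared by the whole batch, so a single decode start position (prompt_lengths[0]) is only correct when every prompt has the same length. The body of _create_padded_prompts is not shown in this diff; the sketch below is only an assumption about what such a helper typically does (right-padding to a common length with a pad token id), not the file's actual implementation:

from typing import List, Tuple
import torch

def create_padded_prompts_sketch(
    tokenized: List[List[int]], seqlen: int, pad_id: int, device: torch.device
) -> Tuple[torch.Tensor, List[int]]:
    # Right-pad every tokenized prompt to a common sequence length and keep
    # the true lengths so the caller can find each prompt's last real token.
    lengths = [len(ids) for ids in tokenized]
    padded = torch.full((len(tokenized), seqlen), pad_id, dtype=torch.long, device=device)
    for i, ids in enumerate(tokenized):
        padded[i, : len(ids)] = torch.tensor(ids, dtype=torch.long, device=device)
    return padded, lengths
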
@@ -408,25 +400,87 @@ def main(args):
     res = [[] for _ in range(total_prompts)]
     num_tokens = 40

+    # Prefill phase
+    # Run context input through pipeline, in 1 step
+    with torch.no_grad():
+        if pp_rank == first_pp_rank:
+            output = prefill_schedule.step(padded_sequence)
+        elif pp_rank == last_pp_rank:
+            output = prefill_schedule.step()
+        else:  # middle pp ranks
+            prefill_schedule.step()
+
+    # Decode the output -- first generated token
+    if pp_rank == last_pp_rank:
+        decode_results = _batch_decode_next_tokens(
+            output=output,
+            tokenizer=tokenizer,
+            prompt_lengths=prompt_lengths,
+        )
+        for i in range(len(decode_results)):
+            new_token[i, 0] = torch.tensor(
+                [decode_results[i][0]], device=device
+            )  # token_id in int form
+        if tp_rank == 0:
+            logger.info(
+                f"{color.green}{'* Prefill *'} "
+                f"responses ====>>>> {color.blue}{decode_results=}{color.reset}"
+            )
+
+    # seqlen = 1 now
+    seqlen_decode = 1
+    input_pos = torch.tensor([prompt_lengths[0]], device=device)
+    model.setup_input_pos(input_pos)
+
+    # Create decode stage
+    logger.info(f"Creating pipeline stage for decode {pp_rank=} {pp_degree=}")
+    example_inputs, example_outputs = get_example_ins_outs(seqlen_decode)
+    decode_stage = PipelineStage(
+        model,
+        pp_rank,
+        pp_degree,
+        device,
+        input_args=example_inputs,
+        output_args=example_outputs,
+        group=pp_group,
+    )
+    # create schedule
+    decode_schedule = ScheduleGPipe(decode_stage, mbs)
+
     # Decoding
     with torch.no_grad():
-        for step in range(num_tokens):
+        for step in range(num_tokens - 1):
+            # sendrecv between last and first ranks, only if:
+            # first_pp_rank != last_pp_rank.
+            if pp_rank == last_pp_rank and pp_rank != first_pp_rank:
+                dist.send(
+                    new_token,
+                    dst=first_pp_rank_global_id,
+                    group=pp_group,
+                )
+            elif pp_rank == first_pp_rank and pp_rank != last_pp_rank:
+                dist.recv(
+                    new_token,
+                    src=last_pp_rank_global_id,
+                    group=pp_group,
+                )
+
             # Run data through pipeline
             if pp_rank == first_pp_rank:
-                output = schedule.step(padded_sequence)
+                output = decode_schedule.step(new_token)
             elif pp_rank == last_pp_rank:
-                output = schedule.step()
+                output = decode_schedule.step()
             else:  # middle pp ranks
-                schedule.step()
+                decode_schedule.step()

             # Decode the output
             if pp_rank == last_pp_rank:
                 decode_results = _batch_decode_next_tokens(
-                    output=output, prompt_lengths=prompt_lengths, tokenizer=tokenizer
+                    output=output, tokenizer=tokenizer
                 )
                 if tp_rank == 0:
                     logger.info(
-                        f"{color.green}{'Prefill' if step == 0 else '* Decode *'} "
+                        f"{color.green}{'* Decode *'} "
                         f"responses ====>>>> {color.blue}{decode_results=}{color.reset}"
                     )
                 # decode results returns both token_id (int) and token_str (readable), hence [0] and [1]
@@ -436,28 +490,8 @@ def main(args):
                         [decode_results[i][0]], device=device
                     )  # decode_results[i][0]

-            # sendrecv between last and first ranks, only if:
-            # first_pp_rank != last_pp_rank.
-            if pp_rank == last_pp_rank and pp_rank != first_pp_rank:
-                dist.send(
-                    new_token,
-                    dst=first_pp_rank_global_id,
-                    group=pp_group,
-                )
-            elif pp_rank == first_pp_rank and pp_rank != last_pp_rank:
-                dist.recv(
-                    new_token,
-                    src=last_pp_rank_global_id,
-                    group=pp_group,
-                )
-
-            # Update input sequence with new token
-            if pp_rank == first_pp_rank:
-                _update_padded_sequence(padded_sequence, new_token, prompt_lengths)
-
-            # increment prompt lengths for next token
-            for i in range(len(prompt_lengths)):
-                prompt_lengths[i] += 1
+            input_pos += 1
+            model.setup_input_pos(input_pos)

     # Display the decoding results

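Taken together, these hunks split generation into a one-shot prefill over the full padded prompt followed by seqlen-1 decode steps: the last pipeline rank sends each new token back to the first rank before the next step, and input_pos advances by one per token. The single-process sketch below shows the same control flow without pipeline parallelism; it assumes a model object whose setup_input_pos behaves as in this file (setting the absolute KV-cache positions for the next forward pass) and uses the same greedy argmax as _batch_decode_next_tokens.

import torch

@torch.no_grad()
def generate_sketch(model, padded_sequence, prompt_len, num_tokens):
    # Prefill: run the whole padded prompt once; the first new token comes
    # from the logits at position prompt_len - 1.
    model.setup_input_pos(torch.arange(padded_sequence.size(1)))
    logits = model(padded_sequence)
    new_token = logits[:, prompt_len - 1, :].argmax(dim=-1, keepdim=True)
    generated = [new_token]

    # Decode: one token per step; the KV cache makes a seqlen-1 forward pass
    # sufficient, provided input_pos tracks the absolute position.
    input_pos = torch.tensor([prompt_len])
    for _ in range(num_tokens - 1):
        model.setup_input_pos(input_pos)
        logits = model(new_token)  # shape [batch, 1, vocab]
        new_token = logits[:, 0, :].argmax(dim=-1, keepdim=True)
        generated.append(new_token)
        input_pos += 1
    return torch.cat(generated, dim=1)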