@@ -273,13 +273,11 @@ def main(args):
     pp_rank = pp_mesh.get_local_rank()
     tp_group = tp_mesh.get_group()
     pp_group = pp_mesh.get_group()
-    pp_group_size = pp_group.size()
-    tp_group_size = tp_group.size()
-    logger.info(f"{pp_group_size=} {tp_group_size=}")
+    logger.info(f"{pp_degree=} {tp_degree=}")
 
     # Convenience variables
     first_pp_rank = 0
-    last_pp_rank = pp_group_size - 1
+    last_pp_rank = pp_degree - 1
 
     # Assuming same number of GPUs per node
     device = torch.device(f"cuda:{rank % torch.cuda.device_count()}")
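
For orientation, the `pp_mesh` and `tp_mesh` handles used above typically come from a 2-D device mesh split into pipeline and tensor dimensions. The mesh construction itself is outside these hunks, so the following is only a sketch with assumed dimension names and example degrees:

```python
# Sketch only -- the actual mesh construction is not part of this diff.
# A 2-D device mesh with assumed dimension names ("pp", "tp").
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh

pp_degree, tp_degree = 2, 4  # example values; the script derives these from args

device_mesh = init_device_mesh(
    "cuda", (pp_degree, tp_degree), mesh_dim_names=("pp", "tp")
)
pp_mesh = device_mesh["pp"]  # this rank's pipeline-parallel sub-mesh
tp_mesh = device_mesh["tp"]  # this rank's tensor-parallel sub-mesh
rank = dist.get_rank()       # global rank, used for device selection above
```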
@@ -297,18 +295,22 @@ def main(args):
     if rank == 0:
         logger.info(f"Model: {model}")
 
-    mbs = 1  # number of micro-batches
-    mb_size = 4  # micro-batch size
-    batch_size = mbs * mb_size  # total batch size
-
+    # Batch size. Since we push batches through the pipeline dynamically rather
+    # than chunking them, this is effectively the micro-batch size in the
+    # pipeline sense, so it is interchangeable with the micro-batch size below.
+    batch_size = 4
     seqlen_prefill = 1024  # sequence length
     dim = 4096  # embedding dimension
 
     # Setup KV caches (after model distribution)
-    # TODO: the setting below only works for 1 micro-batch case. To support
-    # multiple micro-batches, we need the KV cache in the model to be aware of
-    # the number of micro-batches and the current micro-batch index.
-    model.setup_caches(mb_size, seqlen_prefill)
+    # The number of cache lanes is the same as the maximum number of
+    # micro-batches that can be "in flight" in parallel -- imagine each
+    # micro-batch taking one "pipeline lane"; they need distinct KV cache spaces.
+    # When decoding is done for certain micro-batches, we can reuse their KV
+    # cache lanes.
+    # TODO: bump up the lane count
+    pipeline_lanes = 1
+    model.setup_caches(batch_size, seqlen_prefill, cache_lanes=pipeline_lanes)
 
     # Load weights
     logger.info(f"Loading weights for {pp_rank=} {device=}")
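
The cache changes themselves are not visible in this diff. Purely as an illustration of the lane idea (shapes and names below are assumptions, not torchchat's actual cache code), a lane-aware KV cache just adds a leading dimension that the `cache_lane` argument selects:

```python
import torch

# Illustrative sketch only -- not the model's real cache implementation.
# Each pipeline lane gets its own slice of the K/V buffers, so micro-batches
# that are in flight at the same time never overwrite each other's cache.
def make_lane_kv_cache(lanes, batch_size, n_heads, max_seq_len, head_dim,
                       dtype=torch.bfloat16, device="cuda"):
    shape = (lanes, batch_size, n_heads, max_seq_len, head_dim)
    k_cache = torch.zeros(shape, dtype=dtype, device=device)
    v_cache = torch.zeros(shape, dtype=dtype, device=device)
    return k_cache, v_cache

# Inside attention, a stage would pick its slice with the cache_lane kwarg:
#   k_lane, v_lane = k_cache[cache_lane], v_cache[cache_lane]
```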
@@ -317,7 +319,7 @@ def main(args):
         model.to(device)
 
     logger.info(
-        f"{color.green}{timer.get_time()} {timer.unit} stage {rank}{color.reset}"
+        f"{color.green}{timer.get_time()} {timer.unit} rank {rank}{color.reset}"
     )
 
     # info on stage size and params
@@ -330,17 +332,16 @@ def main(args):
 
     # Setup input position (input_pos) for prefill: a list of increasing integers from 0 to seqlen
     input_pos = torch.arange(seqlen_prefill, device=device)
-    model.setup_input_pos(input_pos)
     model.eval()
 
     # Helper function to get example inputs and outputs for the stages.
     def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
-        mb_ids = torch.randint(0, config.vocab_size, (mb_size, seqlen), device=device)
+        mb_ids = torch.randint(0, config.vocab_size, (batch_size, seqlen), device=device)
         activation = torch.rand(
-            mb_size, seqlen, dim, device=device, dtype=model_dtype
+            batch_size, seqlen, dim, device=device, dtype=model_dtype
         )
         logits = torch.rand(
-            mb_size, seqlen, config.vocab_size, device=device, dtype=model_dtype
+            batch_size, seqlen, config.vocab_size, device=device, dtype=model_dtype
         )
         example_inputs = (mb_ids if pp_rank == first_pp_rank else activation,)
         example_outputs = (logits if pp_rank == last_pp_rank else activation,)
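
The `PipelineStage` construction that consumes these example tensors is only partially visible in the next hunk (its leading arguments fall outside the context lines). A hedged reconstruction, assuming the `torch.distributed.pipelining.PipelineStage` signature of that era, would be:

```python
# Sketch: the leading constructor arguments are elided from the diff, so this
# is an assumption based on the visible trailing arguments.
from torch.distributed.pipelining import PipelineStage

example_inputs, example_outputs = get_example_ins_outs(seqlen_prefill)
prefill_stage = PipelineStage(
    model,                        # this rank's pipeline stage module
    pp_rank,                      # stage index within the pipeline
    pp_degree,                    # total number of stages
    device,
    input_args=example_inputs,    # shape/dtype metadata for recv buffers
    output_args=example_outputs,  # shape/dtype metadata for send buffers
    group=pp_group,
)
```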
@@ -358,8 +359,13 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
         output_args=example_outputs,
         group=pp_group,
     )
-    # create schedule
-    prefill_schedule = ScheduleGPipe(prefill_stage, mbs)
+
+    # Create schedule
+    # The number of micro-batches for the schedule is 1, because each step()
+    # call pushes only 1 micro-batch into the pipeline. But we can continuously
+    # push new micro-batches into the pipeline as they arrive, achieving the
+    # same pipelining effect.
+    prefiller = ScheduleGPipe(prefill_stage, 1)
 
     prompt = [
         "What is a computer?",
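
Since the schedule is built with a single micro-batch, pipelining across requests comes from calling `step()` repeatedly as new work arrives rather than from chunking one large batch. A rough sketch of that usage pattern (the `incoming_batches` iterable is hypothetical):

```python
# Sketch: each arriving batch is pushed through the pipeline as its own
# micro-batch; only the first and last stages handle real input/output data.
with torch.no_grad():
    for tokens in incoming_batches:  # hypothetical stream of token tensors
        if pp_rank == first_pp_rank:
            prefiller.step(tokens, input_pos=input_pos, cache_lane=0)
        elif pp_rank == last_pp_rank:
            logits = prefiller.step(input_pos=input_pos, cache_lane=0)
        else:  # middle stages just participate in the schedule
            prefiller.step(input_pos=input_pos, cache_lane=0)
```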
@@ -388,7 +394,6 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
     s = set(prompt_lengths)
     assert len(s) == 1, f"prompt_lengths should be the same, got {s}"
 
-    # with CUDATrackTime() as timer:
     # Need these global ids due to the API definition of dist.send and recv
     first_pp_rank_global_id = dist.get_global_rank(pp_group, first_pp_rank)
     last_pp_rank_global_id = dist.get_global_rank(pp_group, last_pp_rank)
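
The actual send/recv of the freshly generated token between the last and first stages is elided by the hunk boundaries below. As a hedged sketch of why the global ids are needed: `dist.send`/`dist.recv` address peers by global rank, so the group-local `first_pp_rank`/`last_pp_rank` must be translated first.

```python
# Sketch (the real send/recv lines are not shown in this diff): the last stage
# ships the newly sampled token back to the first stage for the next step.
if pp_rank == last_pp_rank and pp_rank != first_pp_rank:
    dist.send(new_token, dst=first_pp_rank_global_id, group=pp_group)
elif pp_rank == first_pp_rank and pp_rank != last_pp_rank:
    dist.recv(new_token, src=last_pp_rank_global_id, group=pp_group)
```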
@@ -401,14 +406,21 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
     num_tokens = 40
 
     # Prefill phase
-    # Run context input through pipeline, in 1 step
-    with torch.no_grad():
+    # Run context input through pipeline
+    # TODO: we need to pass `input_pos` and `cache_lane` to each stage.
+    lane = 0
+    kwargs = {"input_pos": input_pos, "cache_lane": lane}
+    with torch.no_grad(), CUDATrackTime() as timer:
         if pp_rank == first_pp_rank:
-            output = prefill_schedule.step(padded_sequence)
+            output = prefiller.step(padded_sequence, **kwargs)
         elif pp_rank == last_pp_rank:
-            output = prefill_schedule.step()
+            output = prefiller.step(**kwargs)
         else:  # middle pp ranks
-            prefill_schedule.step()
+            prefiller.step(**kwargs)
+
+    logger.info(
+        f"{color.green}{timer.get_time()} {timer.unit} {rank}{color.reset}"
+    )
 
     # Decode the output -- first generated token
     if pp_rank == last_pp_rank:
@@ -430,7 +442,6 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
     # seqlen = 1 now
     seqlen_decode = 1
     input_pos = torch.tensor([prompt_lengths[0]], device=device)
-    model.setup_input_pos(input_pos)
 
     # Create decode stage
     logger.info(f"Creating pipeline stage for decode {pp_rank=} {pp_degree=}")
@@ -445,11 +456,12 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
         group=pp_group,
     )
     # create schedule
-    decode_schedule = ScheduleGPipe(decode_stage, mbs)
+    decorder = ScheduleGPipe(decode_stage, 1)
 
     # Decoding
-    with torch.no_grad():
+    with torch.no_grad(), CUDATrackTime() as timer:
         for step in range(num_tokens - 1):
+            kwargs = {"input_pos": input_pos, "cache_lane": lane}
             # sendrecv between last and first ranks, only if:
             # first_pp_rank != last_pp_rank.
             if pp_rank == last_pp_rank and pp_rank != first_pp_rank:
@@ -467,11 +479,11 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
 
             # Run data through pipeline
             if pp_rank == first_pp_rank:
-                output = decode_schedule.step(new_token)
+                output = decorder.step(new_token, **kwargs)
             elif pp_rank == last_pp_rank:
-                output = decode_schedule.step()
+                output = decorder.step(**kwargs)
             else:  # middle pp ranks
-                decode_schedule.step()
+                decorder.step(**kwargs)
 
             # Decode the output
             if pp_rank == last_pp_rank:
@@ -491,7 +503,10 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
                     )  # decode_results[i][0]
 
             input_pos += 1
-            model.setup_input_pos(input_pos)
+
+    logger.info(
+        f"{color.green}{timer.get_time()} {timer.unit} {rank}{color.reset}"
+    )
 
     # Display the decoding results
 
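The token-selection helper used on the last stage (`decode_results` above) is elided from these hunks. Purely as a hypothetical illustration of that step, a greedy variant would take the logits from the final position:

```python
import torch

# Hypothetical illustration -- not the file's actual decoding helper.
# Pick the most likely next token id from the last stage's logits.
def greedy_next_token(logits: torch.Tensor) -> torch.Tensor:
    # logits: (batch_size, seqlen, vocab_size); use the final position only.
    return torch.argmax(logits[:, -1, :], dim=-1, keepdim=True)  # (batch_size, 1)
```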