@@ -273,13 +273,11 @@ def main(args):
     pp_rank = pp_mesh.get_local_rank()
     tp_group = tp_mesh.get_group()
     pp_group = pp_mesh.get_group()
-    pp_group_size = pp_group.size()
-    tp_group_size = tp_group.size()
-    logger.info(f"{pp_group_size=}, {tp_group_size=}")
+    logger.info(f"{pp_degree=}, {tp_degree=}")
 
     # Convenience variables
     first_pp_rank = 0
-    last_pp_rank = pp_group_size - 1
+    last_pp_rank = pp_degree - 1
 
     # Assuming same number of GPUs per node
     device = torch.device(f"cuda:{rank % torch.cuda.device_count()}")
@@ -297,18 +295,22 @@ def main(args):
     if rank == 0:
         logger.info(f"Model: {model}")
 
-    mbs = 1  # number of micro-batches
-    mb_size = 4  # micro-batch size
-    batch_size = mbs * mb_size  # total batch size
-
+    # Batch size. Since we push batches dynamically through the pipeline rather
+    # than chunking them, this is effectively the micro-batch size in the
+    # pipeline sense. Thus it is interchangeable with the micro-batch size below.
+    batch_size = 4
     seqlen_prefill = 1024  # sequence length
     dim = 4096  # embedding dimension
 
     # Setup KV caches (after model distribution)
-    # TODO: the setting below only works for 1 micro-batch case. To support
-    # multiple micro-batches, we need the KV cache in the model to be aware of
-    # the number of micro-batches and the current micro-batch index.
-    model.setup_caches(mb_size, seqlen_prefill)
+    # The number of cache lanes is the same as the maximum number of
+    # micro-batches that can be "in flight" in parallel -- imagine each
+    # micro-batch occupying one "pipeline lane"; they need distinct KV cache
+    # spaces. Once decoding is done for a micro-batch, its KV cache lane can
+    # be reused.
+    # TODO: bump up the lane count
+    pipeline_lanes = 1
+    model.setup_caches(batch_size, seqlen_prefill, cache_lanes=pipeline_lanes)
 
     # Load weights
     logger.info(f"Loading weights for {pp_rank=} on {device=}")
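The `cache_lanes` argument added above gives each in-flight micro-batch its own KV-cache slot. As a rough sketch of the idea only (the `LanedKVCache` class name, constructor arguments, and shapes are hypothetical, not the model's actual `setup_caches()` internals), a lane-aware cache just adds a leading lane dimension and writes at `input_pos` within the selected lane:

```python
# Hypothetical sketch of a lane-aware KV cache. One lane per in-flight
# micro-batch keeps concurrent micro-batches from overwriting each other's
# cached keys/values.
import torch

class LanedKVCache:
    def __init__(self, lanes, batch_size, n_heads, max_seqlen, head_dim, dtype=torch.bfloat16):
        shape = (lanes, batch_size, n_heads, max_seqlen, head_dim)
        self.k_cache = torch.zeros(shape, dtype=dtype)
        self.v_cache = torch.zeros(shape, dtype=dtype)

    def update(self, cache_lane, input_pos, k_val, v_val):
        # Write this step's keys/values into the chosen lane at input_pos,
        # then return that lane's full cache views for attention.
        self.k_cache[cache_lane, :, :, input_pos] = k_val
        self.v_cache[cache_lane, :, :, input_pos] = v_val
        return self.k_cache[cache_lane], self.v_cache[cache_lane]
```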
@@ -317,7 +319,7 @@ def main(args):
         model.to(device)
 
     logger.info(
-        f"{color.green}Total weight loading time: {timer.get_time()} {timer.unit} for stage {rank}{color.reset}"
+        f"{color.green}Total weight loading time: {timer.get_time()} {timer.unit} for rank {rank}{color.reset}"
     )
 
     # info on stage size and params
@@ -330,17 +332,16 @@ def main(args):
 
     # Setup input position (input_pos) for prefill: a list of increasing integers from 0 to seqlen
     input_pos = torch.arange(seqlen_prefill, device=device)
-    model.setup_input_pos(input_pos)
     model.eval()
 
     # Helper function to get example inputs and outputs for the stages.
     def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
-        mb_ids = torch.randint(0, config.vocab_size, (mb_size, seqlen), device=device)
+        mb_ids = torch.randint(0, config.vocab_size, (batch_size, seqlen), device=device)
         activation = torch.rand(
-            mb_size, seqlen, dim, device=device, dtype=model_dtype
+            batch_size, seqlen, dim, device=device, dtype=model_dtype
         )
         logits = torch.rand(
-            mb_size, seqlen, config.vocab_size, device=device, dtype=model_dtype
+            batch_size, seqlen, config.vocab_size, device=device, dtype=model_dtype
         )
         example_inputs = (mb_ids if pp_rank == first_pp_rank else activation,)
         example_outputs = (logits if pp_rank == last_pp_rank else activation,)
@@ -358,8 +359,13 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
         output_args=example_outputs,
         group=pp_group,
     )
-    # create schedule
-    prefill_schedule = ScheduleGPipe(prefill_stage, mbs)
+
+    # Create schedule
+    # The number of micro-batches for the schedule is 1, because each step()
+    # call pushes only 1 micro-batch into the pipeline. We can still push new
+    # micro-batches into the pipeline continuously as they arrive, achieving
+    # the same pipelining effect.
+    prefiller = ScheduleGPipe(prefill_stage, 1)
 
     prompt = [
         "What is a computer?",
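The `n_microbatches=1` choice above means the schedule itself never splits a batch; overlap comes from issuing `step()` repeatedly as new work arrives. A rough sketch of that calling pattern (the `serve_batches` helper and its arguments are illustrative and not part of this change):

```python
# Illustrative streaming pattern for a 1-micro-batch GPipe schedule: each
# incoming batch is pushed with its own step() call, and the per-rank dispatch
# mirrors the prefill/decode code in this file.
def serve_batches(schedule, batches, pp_rank, first_pp_rank, last_pp_rank):
    outputs = []
    for batch in batches:                    # e.g. requests arriving over time
        if pp_rank == first_pp_rank:
            schedule.step(batch)             # feed tokens / activations in
        elif pp_rank == last_pp_rank:
            outputs.append(schedule.step())  # collect logits out
        else:
            schedule.step()                  # middle ranks only relay
    return outputs
```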
@@ -388,7 +394,6 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
     s = set(prompt_lengths)
     assert len(s) == 1, f"prompt_lengths should be the same, got {s}"
 
-    # with CUDATrackTime() as timer:
     # Need these global ids due to the API definition of dist.send and recv
     first_pp_rank_global_id = dist.get_global_rank(pp_group, first_pp_rank)
     last_pp_rank_global_id = dist.get_global_rank(pp_group, last_pp_rank)
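The conversion above matters because `dist.send`/`dist.recv` address peers by global rank even when a `group` argument is passed, while the pipeline logic reasons in group-local ranks. A minimal sketch of the token hand-off this enables (function names, tensor shape, and dtype are assumptions for illustration):

```python
# Sketch: the last pipeline stage ships the newly generated token back to the
# first stage. dist.get_global_rank() maps a group-local rank to the global
# rank that dist.send/dist.recv expect.
import torch
import torch.distributed as dist

def send_new_token(new_token, pp_group, first_pp_rank):
    dst = dist.get_global_rank(pp_group, first_pp_rank)  # group-local -> global
    dist.send(new_token, dst=dst, group=pp_group)

def recv_new_token(pp_group, last_pp_rank, device, batch_size=1):
    buf = torch.empty((batch_size, 1), dtype=torch.long, device=device)
    src = dist.get_global_rank(pp_group, last_pp_rank)
    dist.recv(buf, src=src, group=pp_group)
    return buf
```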
@@ -401,14 +406,21 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
     num_tokens = 40
 
     # Prefill phase
-    # Run context input through pipeline, in 1 step
-    with torch.no_grad():
+    # Run context input through pipeline
+    # TODO: we need to pass `input_pos` and `cache_lane` to each stage.
+    lane = 0
+    kwargs = {"input_pos": input_pos, "cache_lane": lane}
+    with torch.no_grad(), CUDATrackTime() as timer:
         if pp_rank == first_pp_rank:
-            output = prefill_schedule.step(padded_sequence)
+            output = prefiller.step(padded_sequence, **kwargs)
         elif pp_rank == last_pp_rank:
-            output = prefill_schedule.step()
+            output = prefiller.step(**kwargs)
         else:  # middle pp ranks
-            prefill_schedule.step()
+            prefiller.step(**kwargs)
+
+    logger.info(
+        f"{color.green}Prefilling time: {timer.get_time()} {timer.unit} for rank {rank}{color.reset}"
+    )
 
     # Decode the output -- first generated token
     if pp_rank == last_pp_rank:
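Only the last pp rank holds logits after prefill, and the first generated token comes from each prompt's final input position. A hedged sketch of greedy selection (the helper name `greedy_first_token` is made up; the actual decoding helper in this file may differ):

```python
# Sketch of greedy next-token selection after prefill. logits has shape
# (batch, seqlen, vocab_size); the "next" token for each sample lives at the
# position of its last prompt token.
import torch

def greedy_first_token(logits, prompt_lengths):
    pos = torch.tensor(prompt_lengths, device=logits.device) - 1  # last prompt index per sample
    rows = torch.arange(logits.size(0), device=logits.device)
    next_logits = logits[rows, pos]                               # (batch, vocab_size)
    return next_logits.argmax(dim=-1)                             # (batch,) token ids
```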
@@ -430,7 +442,6 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
     # seqlen = 1 now
     seqlen_decode = 1
     input_pos = torch.tensor([prompt_lengths[0]], device=device)
-    model.setup_input_pos(input_pos)
 
     # Create decode stage
     logger.info(f"Creating pipeline stage for decode {pp_rank=}, {pp_degree=}")
@@ -445,11 +456,12 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
         group=pp_group,
     )
     # create schedule
-    decode_schedule = ScheduleGPipe(decode_stage, mbs)
+    decorder = ScheduleGPipe(decode_stage, 1)
 
     # Decoding
-    with torch.no_grad():
+    with torch.no_grad(), CUDATrackTime() as timer:
         for step in range(num_tokens - 1):
+            kwargs = {"input_pos": input_pos, "cache_lane": lane}
             # sendrecv between last and first ranks, only if:
             # first_pp_rank != last_pp_rank.
             if pp_rank == last_pp_rank and pp_rank != first_pp_rank:
@@ -467,11 +479,11 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
 
             # Run data through pipeline
             if pp_rank == first_pp_rank:
-                output = decode_schedule.step(new_token)
+                output = decorder.step(new_token, **kwargs)
             elif pp_rank == last_pp_rank:
-                output = decode_schedule.step()
+                output = decorder.step(**kwargs)
             else:  # middle pp ranks
-                decode_schedule.step()
+                decorder.step(**kwargs)
 
             # Decode the output
             if pp_rank == last_pp_rank:
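Each decode iteration rebuilds `kwargs` with the current `input_pos` and the fixed `cache_lane`, so the KV cache knows where the single new token should be written. A single-process sketch of that contract (the direct `model(...)` call and argument names are assumptions for illustration; in this file the kwargs travel through the schedule instead):

```python
# Sketch of the incremental-decoding contract: one token per step, input_pos
# advances by 1 after every generated token, cache_lane stays fixed.
import torch

def decode_sketch(model, first_token: int, prompt_len: int, num_tokens: int, device, lane: int = 0):
    tokens = [first_token]
    input_pos = torch.tensor([prompt_len], device=device)
    cur = torch.tensor([[first_token]], device=device)    # (batch=1, seqlen=1)
    for _ in range(num_tokens - 1):
        logits = model(cur, input_pos=input_pos, cache_lane=lane)
        cur = logits[:, -1].argmax(dim=-1, keepdim=True)  # greedy pick, shape (1, 1)
        tokens.append(int(cur))
        input_pos += 1                                    # next KV-cache write position
    return tokens
```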
@@ -491,7 +503,10 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
                 )  # decode_results[i][0]
 
             input_pos += 1
-            model.setup_input_pos(input_pos)
+
+    logger.info(
+        f"{color.green}Decoding time: {timer.get_time()} {timer.unit} for rank {rank}{color.reset}"
+    )
 
     # Display the decoding results
 