     get_module_size,
     get_num_params,
     GPUMemoryMonitor,
-    TrackTime,
 )
 from distributed.verification_utils import find_cpu_tensors
 from torch.distributed.pipelining import PipelineStage, ScheduleGPipe
@@ -124,7 +123,7 @@ def _encode_strings(
     strings: List[str],
     tokenizer,
     bos: bool = True,
-    device: str = "cuda",
+    device: torch.device = "cuda:0",
     dtype=torch.int64,
 ) -> List[torch.Tensor]:
     """Encode a list of prompt strings into a list of tensor token ids."""
@@ -142,7 +141,7 @@ def _create_padded_prompts(
     tokenizer,
     seqlen: int,
     start_pos: int,
-    device: str,
+    device: torch.device,
     pad_token_id: Optional[int] = None,
 ) -> Tuple[torch.Tensor, List[int]]:
     """
@@ -284,7 +283,7 @@ def main():

     # Distribute model on TP mesh
     model.distribute(tp_mesh)
-    # logger.info(f"Model: {model}")
+    logger.info(f"Model: {model}")

     mbs = 1  # number of micro-batches
     mb_size = 1  # micro-batch size
@@ -302,7 +301,7 @@ def main():

     # Load weights
     logger.info(f"Loading weights for {pp_rank=} on {device=}")
-    with TrackTime() as timer:
+    with CUDATrackTime() as timer:
         _load_model_weights(model, hf_model_name, device=device, model_config=config)
     logger.info(
         f"{color.green}Total weight loading time: {timer.get_time()} {timer.unit} for stage {rank}{color.reset}"
@@ -316,7 +315,7 @@ def main():
         f"Stage {rank} has {color.blue}{stage_num_params} params{color.reset}, Size: {color.blue}{stage_size_formatted}{color.reset}\n"
     )

-    # Setup input position
+    # Setup input position (input_pos) for prefill: a list of increasing integers from 0 to seqlen
     input_pos = torch.arange(seqlen, device=device)
     model.setup_input_pos(input_pos)
     model.eval()
@@ -398,57 +397,15 @@ def main():
     src = dist.get_global_rank(pp_group, last_pp_group)

     # Decoding
-    num_tokens = 10
-    """
-    with torch.no_grad():
-        for step in range(num_tokens + 1):  # +1 to include the initial prefill step
-            if pp_rank == 0:
-                schedule.step(padded_sequence)
-                dist.recv(x_recv, src, group=pp_group)
-                logger.info(f"RECEIVED {x_recv=}")
-                assert x_recv != 128006, f"next_token is header id={x_recv}"
-                _update_padded_sequence(padded_sequence, x_recv, res, prompt_lengths)
-                logger.info(
-                    f"Updated padded seq start: {prompt_lengths=}, {padded_sequence[:, prompt_lengths[0]-1]=}"
-                )
-
-            elif pp_rank == last_pp_group:
-                output = schedule.step()
-                decode_results = _batch_decode_next_tokens(
-                    output, prompt_lengths, tokenizer
-                )
-                logger.info(
-                    f"\n\n{color.green} {'Prefill' if step == 0 else '* Decode *'} responses ====>>>> {color.blue} {decode_results=} \n{color.reset}"
-                )
-
-                next_token = torch.tensor([decode_results[0][0]], device=device)
-                res.append(decode_results[0][1])
+    num_tokens = 40

-                # increment prompt lengths for next token
-                for i in range(len(prompt_lengths)):
-                    prompt_lengths[i] += 1
-                    logger.info(
-                        f"output review {prompt_lengths[i]=}, {padded_sequence[i, prompt_lengths[i]-1]=}"
-                    )
-
-                if step < num_tokens - 1:
-                    dist.send(next_token, dst, pp_group)
-                    logger.info(
-                        f"SENDING back...from rank={pp_rank} to dst={dst}, data {next_token.shape=}, {next_token=}"
-                    )
-                    assert next_token != 128006, f"next_token is header id={next_token}"
-
-            else:  # middle pp ranks
-                schedule.step()
-    """
     with torch.no_grad():
         for step in range(num_tokens):
             # first
             if pp_rank == 0:
                 schedule.step(padded_sequence)
                 # only receive if not last step
                 if step < num_tokens - 1:
-
                     dist.recv(
                         x_recv,
                         src,
@@ -465,9 +422,10 @@ def main():
                 decode_results = _batch_decode_next_tokens(
                     output=output, prompt_lengths=prompt_lengths, tokenizer=tokenizer
                 )
-                logger.info(
-                    f"\n\n{color.green} {'Prefill' if step == 0 else '* Decode *'} responses ====>>>> {color.blue} {decode_results=} \n{color.reset}"
-                )
+                if tp_rank == 0:
+                    logger.info(
+                        f"\n\n{color.green} {'Prefill' if step == 0 else '* Decode *'} responses ====>>>> {color.blue} {decode_results=} \n{color.reset}"
+                    )

                 next_token = torch.tensor([decode_results[0][0]], device=device)
                 res.append(decode_results[0][1])
@@ -479,9 +437,6 @@ def main():
                         f"output review {prompt_lengths[i]=}, {padded_sequence[i, prompt_lengths[i]-1]=}"
                     )

-                # logger.info(f"SENDING back...from {rank=} to {dst=}")
-                # logger.info(f"SENDING data {next_token.shape=}, {next_token=}")
-
                 # only send if not last step
                 if step < (num_tokens - 1):
                     dist.send(
@@ -496,11 +451,10 @@ def main():

     # output formatted response via last pp group and tp rank 0
     if pp_rank == last_pp_group and tp_rank == 0:
-        logger.info(f"Prompt: {color.green}{prompt[0]} {color.reset}")
+        logger.info(f"\nPrompt: {color.green}{prompt[0]} {color.reset}")
         formatted_response = "".join(res)
-        logger.info(f"$$$$$$ {color.blue}{formatted_response} {color.reset} $$$$$")
+        logger.info(f"$$$$$$ {color.blue}{formatted_response}\n{color.reset} $$$$$")

-    logger.info(f"$$$$$$ {color.red}{res=} {color.reset} $$$$$")
     logger.info(
         f"{color.green}Success{color.white} - {color.blue}Rank {rank} has completed.{color.reset}"
     )
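
Below is a minimal, standalone sketch (not part of the commit) of the next-token handoff the rewritten decode loop relies on: the last pipeline stage decodes a token and sends it back to stage 0, which splices it into the padded sequence before the next schedule step. The file name, the gloo backend, and the fake "decode" are assumptions for illustration only; run with at least two processes, e.g. torchrun --nproc-per-node=2 handoff_demo.py.

import torch
import torch.distributed as dist


def token_handoff_demo(num_tokens: int = 4) -> None:
    # torchrun provides the rendezvous env vars; gloo keeps the demo CPU-only.
    dist.init_process_group(backend="gloo")
    rank = dist.get_rank()
    first, last = 0, dist.get_world_size() - 1

    next_token = torch.zeros(1, dtype=torch.int64)
    for step in range(num_tokens):
        if rank == last and step < num_tokens - 1:
            # Stand-in for argmax over the last stage's output logits.
            next_token = torch.tensor([1000 + step], dtype=torch.int64)
            dist.send(next_token, dst=first)
        elif rank == first and step < num_tokens - 1:
            dist.recv(next_token, src=last)
            # The real code calls _update_padded_sequence(...) here.
            print(f"step {step}: stage 0 received token {next_token.item()}")

    dist.destroy_process_group()


if __name__ == "__main__":
    token_handoff_demo()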