@@ -407,6 +407,7 @@ def main():
             f"\n\n{color.green}Prefill responses ====>>>> {color.blue}{decode_results=}\n{color.reset}"
         )
         next_token = torch.tensor([decode_results[0][0]], device=device)
+        res.append(decode_results[0][1])
         dst = dist.get_global_rank(pp_group, 0)
         logger.info(f"SENDING back...from {rank=} to {dst=}")
         logger.info(f"SENDING data {next_token.shape=}, {next_token=}")
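
This first hunk runs on the last pipeline stage after prefill: the added line records the decoded token text in `res` before the token id itself is sent back to stage 0. For orientation, here is a minimal sketch of what `_batch_decode_next_tokens` is assumed to do, a greedy argmax at each prompt's last filled position (the real helper is defined elsewhere in this file, so the names and shapes below are assumptions reconstructed from the call sites):

    import torch

    def _batch_decode_next_tokens(output, prompt_lengths, tokenizer):
        # Assumed shape: output is [batch, seq_len, vocab_size] logits from
        # the final stage. Returns one (token_id, token_text) pair per prompt,
        # matching the decode_results[0][0] / decode_results[0][1] accesses
        # in the diff.
        results = []
        for i, length in enumerate(prompt_lengths):
            next_id = torch.argmax(output[i, length - 1, :]).item()
            results.append((next_id, tokenizer.decode([next_id])))
        return results
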
@@ -431,57 +432,71 @@ def _update_padded_sequence(
             prompt_lengths[i] += 1
             padded_sequence[i, prompt_lengths[i] - 1] = x_recv
 
-    logger.info(f"REVIEW {padded_sequence[0, 4:9]=}")
+    logger.info(f"REVIEW {padded_sequence[0, :15]=}")
 
     # logger.info(f"{color.green}Total prefill time: {timer.get_time()} {timer.unit}{color.reset}")
 
     # decoding loop
     # append first token to the prompt from prefill
-    logger.info(f"{prompt_lengths=}")
-    logger.info(f"{prompt_lengths=}, {padded_sequence[:, prompt_lengths[0] - 1]=}")
-    prompt_lengths[0] += 1
-    padded_sequence[0, prompt_lengths[0] - 1] = x_recv
-    logger.info(f"{padded_sequence[0, prompt_lengths[0] + 1]=}")
+    logger.info(f"\npre update {padded_sequence[0, 0:9]=}")
+    _update_padded_sequence(padded_sequence, x_recv, res, prompt_lengths)
+    logger.info(f"{prompt_lengths=}, {padded_sequence[0, prompt_lengths[0] - 1]=}")
+    logger.info(f"\npost update {padded_sequence[0, 0:9]=}")
 
-    num_tokens = 4
+    num_tokens = 5
     with torch.no_grad():
-        for _ in range(num_tokens):
+        for step in range(num_tokens):
             if pp_rank == 0:
+                logger.info(
+                    f"about to send...{prompt_lengths=}, {padded_sequence[0, :prompt_lengths[0] + 1]=}"
+                )
                 schedule.step(padded_sequence)
+
                 src = dist.get_global_rank(pp_group, pp_group_size - 1)
                 dist.recv(
                     x_recv,
                     src,
                     group=pp_group,
                 )
                 logger.info(f"RECEIVED {x_recv=}")
+                assert x_recv != 128006, f"next_token is header id={x_recv}"
                 _update_padded_sequence(padded_sequence, x_recv, res, prompt_lengths)
                 logger.info(
-                    f"{prompt_lengths=}, {padded_sequence[:, prompt_lengths[0] - 1]=}"
+                    f"about to send... {prompt_lengths=}, {padded_sequence[:, prompt_lengths[0] - 1]=}"
                 )
                 schedule.step(padded_sequence)
 
             elif pp_rank == last_pp_group:
                 output = schedule.step()
                 # need to decode the output
+
                 decode_results = _batch_decode_next_tokens(
                     output=output, prompt_lengths=prompt_lengths, tokenizer=tokenizer
                 )
 
+                for i in range(len(prompt_lengths)):
+                    prompt_lengths[i] += 1
+                    logger.info(
+                        f"output review {prompt_lengths[i]=}, {padded_sequence[i, prompt_lengths[i] - 1]=}"
+                    )
+
                 logger.info(
                     f"\n\n{color.green}* Decode * responses ====>>>> {color.blue}{decode_results=}\n{color.reset}"
                 )
                 res.append(decode_results[0][1])
                 next_token = torch.tensor([decode_results[0][0]], device=device)
                 dst = dist.get_global_rank(pp_group, 0)
-                logger.info(f"SENDING back...from {rank=} to {dst=}")
-                logger.info(f"SENDING data {next_token.shape=}, {next_token=}")
-
-                dist.send(
-                    next_token,
-                    dst,
-                    pp_group,
+                logger.info(
+                    f"SENDING back...from {rank=} to {dst=}, data {next_token.shape=}, {next_token=}"
                 )
+                assert next_token != 128006, f"next_token is header id={next_token}"
+
+                if step < num_tokens - 1:
+                    dist.send(
+                        next_token,
+                        dst,
+                        pp_group,
+                    )
 
             # middle pp ranks
             else:
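
Taken together, each decode step is a ping-pong between the first and last pipeline stages: stage 0 feeds the padded batch through the schedule, the last stage decodes one token and sends its id back, and stage 0 splices it into the batch before the next step. The new `step < num_tokens - 1` guard avoids sending a token that no further decode step will consume, and the new asserts reject token id 128006, Llama 3's `<|start_header_id|>` special token, as a quick sanity check on what came off the wire. A condensed per-step view of the exchange, using the diff's own names (not standalone code; it assumes the process groups, tensors, and schedule set up earlier in the file):

    # stage 0 (pp_rank == 0): run the step, then wait for the next token
    schedule.step(padded_sequence)
    src = dist.get_global_rank(pp_group, pp_group_size - 1)
    dist.recv(x_recv, src, group=pp_group)
    _update_padded_sequence(padded_sequence, x_recv, res, prompt_lengths)

    # last stage (pp_rank == last_pp_group): decode one token, send it back
    output = schedule.step()
    decode_results = _batch_decode_next_tokens(
        output=output, prompt_lengths=prompt_lengths, tokenizer=tokenizer
    )
    next_token = torch.tensor([decode_results[0][0]], device=device)
    if step < num_tokens - 1:
        dist.send(next_token, dist.get_global_rank(pp_group, 0), pp_group)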