@@ -202,6 +202,17 @@ def _batch_decode_next_tokens(
     return results


+def _update_padded_sequence(
+    padded_sequence: torch.Tensor,
+    x_recv: torch.Tensor,
+    res,
+    prompt_lengths: List[int],
+) -> None:
+    for i in range(len(prompt_lengths)):
+        prompt_lengths[i] += 1
+        padded_sequence[i, prompt_lengths[i] - 1] = x_recv
+
+
 def _cleanup():
     dist.barrier()
     dist.destroy_process_group()
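# --- Illustrative sketch, not part of the commit ---------------------------
# How the new _update_padded_sequence helper behaves: it appends the single
# received token to every row of the padded batch and bumps each prompt
# length by one. The toy tensors and values below are assumptions made up
# for demonstration only.
import torch
from typing import List


def _update_padded_sequence(
    padded_sequence: torch.Tensor,
    x_recv: torch.Tensor,
    res,
    prompt_lengths: List[int],
) -> None:
    # Write the received token right after each prompt and grow its length;
    # note the same x_recv value is broadcast into every batch row.
    for i in range(len(prompt_lengths)):
        prompt_lengths[i] += 1
        padded_sequence[i, prompt_lengths[i] - 1] = x_recv


padded = torch.zeros(2, 8, dtype=torch.long)   # batch of 2, padded to length 8
padded[0, :3] = torch.tensor([10, 11, 12])     # prompt 0 holds 3 tokens
padded[1, :2] = torch.tensor([20, 21])         # prompt 1 holds 2 tokens
lengths = [3, 2]

_update_padded_sequence(padded, torch.tensor(99), [], lengths)
print(lengths)         # [4, 3]
print(padded[0, :5])   # tensor([10, 11, 12, 99,  0])
print(padded[1, :4])   # tensor([20, 21, 99,  0])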
@@ -378,120 +389,101 @@ def main():
 
     last_global_rank = world_size - 1
     res = []
+    dst = None
+    src = None
 
-    # if pp_rank == pp_group_size - 1:
-    #     dst = dist.get_global_rank(pp_group, 0)
-    #     dist.send(tensor, dst, pp_group)
-
-    with torch.no_grad():  # .inference_mode():
-        # for _ in range(1):
-        # first
-        if pp_rank == 0:
-            schedule.step(padded_sequence)
-            src = dist.get_global_rank(pp_group, pp_group_size - 1)
-            dist.recv(
-                x_recv,
-                src,
-                group=pp_group,
-            )
-
-        # last
-        elif pp_rank == last_pp_group:
-            output = schedule.step()
-            # need to decode the output
-            decode_results = _batch_decode_next_tokens(
-                output=output, prompt_lengths=prompt_lengths, tokenizer=tokenizer
-            )
-
-            logger.info(
-                f"\n\n{color.green} Prefill responses ====>>>> {color.blue} {decode_results=} \n{color.reset}"
-            )
-            next_token = torch.tensor([decode_results[0][0]], device=device)
-            res.append(decode_results[0][1])
-            dst = dist.get_global_rank(pp_group, 0)
-            logger.info(f"SENDING back...from {rank=} to {dst=}")
-            logger.info(f"SENDING data {next_token.shape=}, {next_token=}")
-
-            dist.send(
-                next_token,
-                dst,
-                pp_group,
-            )
-
-        # middle pp ranks
-        else:
-            schedule.step()
-
-    def _update_padded_sequence(
-        padded_sequence: torch.Tensor,
-        x_recv: torch.Tensor,
-        res,
-        prompt_lengths: List[int],
-    ) -> None:
-        for i in range(len(prompt_lengths)):
-            prompt_lengths[i] += 1
-            padded_sequence[i, prompt_lengths[i] - 1] = x_recv
-
-    logger.info(f"REVIEW {padded_sequence[0,:15]=}")
-
-    # logger.info(f"{color.green}Total prefill time: {timer.get_time()} {timer.unit}{color.reset}")
-
-    # decoding loop
-    # append first token to the prompt from prefill
-    logger.info(f"\npre update {padded_sequence[0,0:9]=}")
-    _update_padded_sequence(padded_sequence, x_recv, res, prompt_lengths)
-    logger.info(f"{prompt_lengths=}, {padded_sequence[0, prompt_lengths[0]-1]=}")
-    logger.info(f"\npost update {padded_sequence[0,0:9]=}")
+    if pp_rank == last_pp_group:
+        dst = dist.get_global_rank(pp_group, 0)
+    elif pp_rank == 0:
+        src = dist.get_global_rank(pp_group, last_pp_group)
 
-    num_tokens = 5
+    # Decoding
+    num_tokens = 10
+    """
     with torch.no_grad():
-        for step in range(num_tokens):
+        for step in range(num_tokens + 1):  # +1 to include the initial prefill step
             if pp_rank == 0:
-                logger.info(
-                    f"about to send...{prompt_lengths=}, {padded_sequence[0, :prompt_lengths[0]+1]=}"
-                )
                 schedule.step(padded_sequence)
-
-                src = dist.get_global_rank(pp_group, pp_group_size - 1)
-                dist.recv(
-                    x_recv,
-                    src,
-                    group=pp_group,
-                )
+                dist.recv(x_recv, src, group=pp_group)
                 logger.info(f"RECEIVED {x_recv=}")
                 assert x_recv != 128006, f"next_token is header id={x_recv}"
                 _update_padded_sequence(padded_sequence, x_recv, res, prompt_lengths)
                 logger.info(
-                    f"about to send... {prompt_lengths=}, {padded_sequence[:, prompt_lengths[0]-1]=}"
+                    f"Updated padded seq start: {prompt_lengths=}, {padded_sequence[:, prompt_lengths[0]-1]=}"
                 )
-                schedule.step(padded_sequence)
 
             elif pp_rank == last_pp_group:
                 output = schedule.step()
-                # need to decode the output
-
                 decode_results = _batch_decode_next_tokens(
-                    output=output, prompt_lengths=prompt_lengths, tokenizer=tokenizer
+                    output, prompt_lengths, tokenizer
+                )
+                logger.info(
+                    f"\n\n{color.green} {'Prefill' if step == 0 else '* Decode *'} responses ====>>>> {color.blue} {decode_results=} \n{color.reset}"
                 )
 
+                next_token = torch.tensor([decode_results[0][0]], device=device)
+                res.append(decode_results[0][1])
+
+                # increment prompt lengths for next token
                 for i in range(len(prompt_lengths)):
                     prompt_lengths[i] += 1
                     logger.info(
                         f"output review {prompt_lengths[i]=}, {padded_sequence[i, prompt_lengths[i]-1]=}"
                     )
 
+                if step < num_tokens - 1:
+                    dist.send(next_token, dst, pp_group)
                 logger.info(
-                    f"\n\n{color.green} * Decode * responses ====>>>> {color.blue} {decode_results=} \n{color.reset}"
-                )
-                res.append(decode_results[0][1])
-                next_token = torch.tensor([decode_results[0][0]], device=device)
-                dst = dist.get_global_rank(pp_group, 0)
-                logger.info(
-                    f"SENDING back...from {rank=} to {dst=}, data {next_token.shape=}, {next_token=}"
+                    f"SENDING back...from rank={pp_rank} to dst={dst}, data {next_token.shape=}, {next_token=}"
                 )
                 assert next_token != 128006, f"next_token is header id={next_token}"
 
+            else:  # middle pp ranks
+                schedule.step()
+    """
+    with torch.no_grad():
+        for step in range(num_tokens):
+            # first
+            if pp_rank == 0:
+                schedule.step(padded_sequence)
+                # only receive if not last step
                 if step < num_tokens - 1:
+
+                    dist.recv(
+                        x_recv,
+                        src,
+                        group=pp_group,
+                    )
+                    _update_padded_sequence(
+                        padded_sequence, x_recv, res, prompt_lengths
+                    )
+
+            # last
+            elif pp_rank == last_pp_group:
+                output = schedule.step()
+                # need to decode the output
+                decode_results = _batch_decode_next_tokens(
+                    output=output, prompt_lengths=prompt_lengths, tokenizer=tokenizer
+                )
+                logger.info(
+                    f"\n\n{color.green} {'Prefill' if step == 0 else '* Decode *'} responses ====>>>> {color.blue} {decode_results=} \n{color.reset}"
+                )
+
+                next_token = torch.tensor([decode_results[0][0]], device=device)
+                res.append(decode_results[0][1])
+
+                # increment prompt lengths for next token
+                for i in range(len(prompt_lengths)):
+                    prompt_lengths[i] += 1
+                    logger.info(
+                        f"output review {prompt_lengths[i]=}, {padded_sequence[i, prompt_lengths[i]-1]=}"
+                    )
+
+                # logger.info(f"SENDING back...from {rank=} to {dst=}")
+                # logger.info(f"SENDING data {next_token.shape=}, {next_token=}")
+
+                # only send if not last step
+                if step < (num_tokens - 1):
                     dist.send(
                         next_token,
                         dst,
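# --- Illustrative sketch, not part of the commit ---------------------------
# One plausible shape for the _batch_decode_next_tokens helper that the loop
# above calls on the last pipeline stage. The real helper is defined earlier
# in this file and may differ; this sketch assumes greedy argmax decoding
# over the logits at each prompt's last filled position and a tokenizer that
# exposes a decode() method. All names below are assumptions.
from typing import List, Tuple

import torch


def _batch_decode_next_tokens_sketch(
    output: torch.Tensor,          # [batch, seq_len, vocab_size] logits
    prompt_lengths: List[int],
    tokenizer,
) -> List[Tuple[int, str]]:
    results = []
    for i, prompt_len in enumerate(prompt_lengths):
        # The logits at the last non-padded position predict the next token.
        next_token_logits = output[i, prompt_len - 1, :]
        next_token_id = int(torch.argmax(next_token_logits, dim=-1).item())
        # Pair the raw id with its decoded text, matching how the loop uses
        # decode_results[0][0] (token id) and decode_results[0][1] (text).
        results.append((next_token_id, tokenizer.decode([next_token_id])))
    return results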
@@ -502,6 +494,10 @@ def _update_padded_sequence(
             else:
                 schedule.step()
 
+    # logger.info(f"REVIEW {padded_sequence[0,:15]=}")
+
+    # logger.info(f"{color.green}Total prefill time: {timer.get_time()} {timer.unit}{color.reset}")
+
     # Decoding
     """
     if pp_rank == pp_degree - 1 and tp_rank == 0: