@@ -393,7 +393,7 @@ def main():
393393 ],
394394 device = device ,
395395 )
396- x_recv = torch .zeros_like ( padded_sequence )
396+ x_recv = torch .zeros ( 1 , device = device , dtype = torch . int64 )
397397 logger .info (f"{ x_recv .shape = } " )
398398
399399 last_global_rank = world_size - 1
@@ -425,13 +425,13 @@ def main():
425425 logger .info (
426426 f"\n \n { color .green } Prefill responses ====>>>> { color .blue } { decode_results = } \n { color .reset } "
427427 )
428-
428+ next_token = torch . tensor ([ decode_results [ 0 ][ 0 ]], device = device )
429429 dst = dist .get_global_rank (pp_group , 0 )
430430 logger .info (f"SENDING back...from { rank = } to { dst = } " )
431- logger .info (f"{ decode_results .shape = } , { decode_results [ 0 , 4 : 8 ] = } " )
431+ logger .info (f"SENDING data { next_token .shape = } , { next_token = } " )
432432
433433 dist .send (
434- decode_results ,
434+ next_token ,
435435 dst ,
436436 pp_group ,
437437 )
@@ -451,15 +451,15 @@ def main():
451451 f"{ color .red } Success! Rank { rank } - Received output from { src } { color .reset } "
452452 )
453453 logger .info (
454- f"out of loop - Received output: { x_recv [ 0 , 4 : 8 ] = } "
454+ f"out of loop - Received output: { x_recv = } , { x_recv . shape = } , { x_recv . dtype = } "
455455 ) # {padded_sequence[4:8]=}"
456456 # {padded_sequence[0, :prompt_lengths[0]+1]=}")
457457 if rank == 1 :
458458 logger .info (
459459 f"{ color .red } Success! Received { rank } output from { src } { color .reset } "
460460 )
461461 logger .info (
462- f"out of loop Received output: { x_recv [ 0 , 4 : 8 ] = } "
462+ f"out of loop Received output: { x_recv = } "
463463 ) # {padded_sequence[0, :prompt_lengths[0]+1]=}")
464464
465465 # logger.info(f"{color.green}Total prefill time: {timer.get_time()} {timer.unit}{color.reset}")
0 commit comments