This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit fe9dae9

single prompt prefill + decoding all working
1 parent e3fe1bf commit fe9dae9

File tree

1 file changed: +83 -87 lines changed

dist_run.py

Lines changed: 83 additions & 87 deletions
@@ -202,6 +202,17 @@ def _batch_decode_next_tokens(
     return results
 
 
+def _update_padded_sequence(
+    padded_sequence: torch.Tensor,
+    x_recv: torch.Tensor,
+    res,
+    prompt_lengths: List[int],
+) -> None:
+    for i in range(len(prompt_lengths)):
+        prompt_lengths[i] += 1
+        padded_sequence[i, prompt_lengths[i] - 1] = x_recv
+
+
 def _cleanup():
     dist.barrier()
     dist.destroy_process_group()
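
For reference, a minimal standalone sketch of what the new _update_padded_sequence helper does: it bumps each prompt's recorded length by one and writes the token just received into the first free padding slot of that row. The toy tensor shape and token ids below are made up for illustration.

import torch
from typing import List

def _update_padded_sequence(
    padded_sequence: torch.Tensor,
    x_recv: torch.Tensor,
    res,
    prompt_lengths: List[int],
) -> None:
    # Same body as the helper added above: advance each prompt length by one
    # and write the received token into the newly claimed slot.
    for i in range(len(prompt_lengths)):
        prompt_lengths[i] += 1
        padded_sequence[i, prompt_lengths[i] - 1] = x_recv

# Toy example: one prompt of length 3 inside a padded buffer of width 8.
padded = torch.zeros(1, 8, dtype=torch.long)
padded[0, :3] = torch.tensor([101, 2009, 2003])
lengths = [3]
_update_padded_sequence(padded, torch.tensor(42), [], lengths)
print(lengths)        # [4]
print(padded[0, :5])  # tensor([ 101, 2009, 2003,   42,    0])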
@@ -378,120 +389,101 @@ def main():
 
     last_global_rank = world_size - 1
     res = []
+    dst = None
+    src = None
 
-    # if pp_rank == pp_group_size - 1:
-    #     dst = dist.get_global_rank(pp_group, 0)
-    #     dist.send(tensor, dst, pp_group)
-
-    with torch.no_grad():  # .inference_mode():
-        # for _ in range(1):
-        # first
-        if pp_rank == 0:
-            schedule.step(padded_sequence)
-            src = dist.get_global_rank(pp_group, pp_group_size - 1)
-            dist.recv(
-                x_recv,
-                src,
-                group=pp_group,
-            )
-
-        # last
-        elif pp_rank == last_pp_group:
-            output = schedule.step()
-            # need to decode the output
-            decode_results = _batch_decode_next_tokens(
-                output=output, prompt_lengths=prompt_lengths, tokenizer=tokenizer
-            )
-
-            logger.info(
-                f"\n\n{color.green} Prefill responses ====>>>> {color.blue} {decode_results=} \n{color.reset}"
-            )
-            next_token = torch.tensor([decode_results[0][0]], device=device)
-            res.append(decode_results[0][1])
-            dst = dist.get_global_rank(pp_group, 0)
-            logger.info(f"SENDING back...from {rank=} to {dst=}")
-            logger.info(f"SENDING data {next_token.shape=}, {next_token=}")
-
-            dist.send(
-                next_token,
-                dst,
-                pp_group,
-            )
-
-        # middle pp ranks
-        else:
-            schedule.step()
-
-    def _update_padded_sequence(
-        padded_sequence: torch.Tensor,
-        x_recv: torch.Tensor,
-        res,
-        prompt_lengths: List[int],
-    ) -> None:
-        for i in range(len(prompt_lengths)):
-            prompt_lengths[i] += 1
-            padded_sequence[i, prompt_lengths[i] - 1] = x_recv
-
-    logger.info(f"REVIEW {padded_sequence[0,:15]=}")
-
-    # logger.info(f"{color.green}Total prefill time: {timer.get_time()} {timer.unit}{color.reset}")
-
-    # decoding loop
-    # append first token to the prompt from prefill
-    logger.info(f"\npre update {padded_sequence[0,0:9]=}")
-    _update_padded_sequence(padded_sequence, x_recv, res, prompt_lengths)
-    logger.info(f"{prompt_lengths=}, {padded_sequence[0, prompt_lengths[0]-1]=}")
-    logger.info(f"\n post update {padded_sequence[0,0:9]=}")
+    if pp_rank == last_pp_group:
+        dst = dist.get_global_rank(pp_group, 0)
+    elif pp_rank == 0:
+        src = dist.get_global_rank(pp_group, last_pp_group)
 
-    num_tokens = 5
+    # Decoding
+    num_tokens = 10
+    """
     with torch.no_grad():
-        for step in range(num_tokens):
+        for step in range(num_tokens + 1):  # +1 to include the initial prefill step
             if pp_rank == 0:
-                logger.info(
-                    f"about to send...{prompt_lengths=}, {padded_sequence[0, :prompt_lengths[0]+1]=}"
-                )
                 schedule.step(padded_sequence)
-
-                src = dist.get_global_rank(pp_group, pp_group_size - 1)
-                dist.recv(
-                    x_recv,
-                    src,
-                    group=pp_group,
-                )
+                dist.recv(x_recv, src, group=pp_group)
                 logger.info(f"RECEIVED {x_recv=}")
                 assert x_recv != 128006, f"next_token is header id={x_recv}"
                 _update_padded_sequence(padded_sequence, x_recv, res, prompt_lengths)
                 logger.info(
-                    f"about to send...{prompt_lengths=}, {padded_sequence[:, prompt_lengths[0]-1]=}"
+                    f"Updated padded seq start: {prompt_lengths=}, {padded_sequence[:, prompt_lengths[0]-1]=}"
                 )
-                schedule.step(padded_sequence)
 
             elif pp_rank == last_pp_group:
                 output = schedule.step()
-                # need to decode the output
-
                 decode_results = _batch_decode_next_tokens(
-                    output=output, prompt_lengths=prompt_lengths, tokenizer=tokenizer
+                    output, prompt_lengths, tokenizer
+                )
+                logger.info(
+                    f"\n\n{color.green} {'Prefill' if step == 0 else '* Decode *'} responses ====>>>> {color.blue} {decode_results=} \n{color.reset}"
                 )
 
+                next_token = torch.tensor([decode_results[0][0]], device=device)
+                res.append(decode_results[0][1])
+
+                # increment prompt lengths for next token
                 for i in range(len(prompt_lengths)):
                     prompt_lengths[i] += 1
                     logger.info(
                         f"output review {prompt_lengths[i]=}, {padded_sequence[i, prompt_lengths[i]-1]=}"
                     )
 
+                if step < num_tokens - 1:
+                    dist.send(next_token, dst, pp_group)
                 logger.info(
-                    f"\n\n{color.green} * Decode * responses ====>>>> {color.blue} {decode_results=} \n{color.reset}"
-                )
-                res.append(decode_results[0][1])
-                next_token = torch.tensor([decode_results[0][0]], device=device)
-                dst = dist.get_global_rank(pp_group, 0)
-                logger.info(
-                    f"SENDING back...from {rank=} to {dst=}, data {next_token.shape=}, {next_token=}"
+                    f"SENDING back...from rank={pp_rank} to dst={dst}, data {next_token.shape=}, {next_token=}"
                 )
                 assert next_token != 128006, f"next_token is header id={next_token}"
 
+            else:  # middle pp ranks
+                schedule.step()
+    """
+    with torch.no_grad():
+        for step in range(num_tokens):
+            # first
+            if pp_rank == 0:
+                schedule.step(padded_sequence)
+                # only receive if not last step
                 if step < num_tokens - 1:
+
+                    dist.recv(
+                        x_recv,
+                        src,
+                        group=pp_group,
+                    )
+                    _update_padded_sequence(
+                        padded_sequence, x_recv, res, prompt_lengths
+                    )
+
+            # last
+            elif pp_rank == last_pp_group:
+                output = schedule.step()
+                # need to decode the output
+                decode_results = _batch_decode_next_tokens(
+                    output=output, prompt_lengths=prompt_lengths, tokenizer=tokenizer
+                )
+                logger.info(
+                    f"\n\n{color.green} {'Prefill' if step == 0 else '* Decode *'} responses ====>>>> {color.blue} {decode_results=} \n{color.reset}"
+                )
+
+                next_token = torch.tensor([decode_results[0][0]], device=device)
+                res.append(decode_results[0][1])
+
+                # increment prompt lengths for next token
+                for i in range(len(prompt_lengths)):
+                    prompt_lengths[i] += 1
+                    logger.info(
+                        f"output review {prompt_lengths[i]=}, {padded_sequence[i, prompt_lengths[i]-1]=}"
+                    )
+
+                # logger.info(f"SENDING back...from {rank=} to {dst=}")
+                # logger.info(f"SENDING data {next_token.shape=}, {next_token=}")
+
+                # only send if not last step
+                if step < (num_tokens - 1):
                     dist.send(
                         next_token,
                         dst,
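
A note on the new dst/src setup above: dist.send and dist.recv address peers by global rank (even when a group is passed), while pp_rank is an index local to the pipeline group, so the destination and source are resolved once up front with dist.get_global_rank. A rough sketch of that mapping, assuming an already-initialized process group and a pipeline group of four ranks (the group layout is illustrative, not taken from this repo):

import torch.distributed as dist

# Assume pp_group was built from four consecutive global ranks, e.g. [4, 5, 6, 7],
# so pp_rank (0..3) is this process's position inside that group.
pp_group_size = dist.get_world_size(pp_group)
last_pp_group = pp_group_size - 1

dst = None
src = None
if pp_rank == last_pp_group:
    # Last stage sends each decoded token back to the first stage.
    dst = dist.get_global_rank(pp_group, 0)               # global rank 4 in this example
elif pp_rank == 0:
    # First stage waits for the decoded token from the last stage.
    src = dist.get_global_rank(pp_group, last_pp_group)   # global rank 7 in this example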
@@ -502,6 +494,10 @@ def _update_padded_sequence(
             else:
                 schedule.step()
 
+    # logger.info(f"REVIEW {padded_sequence[0,:15]=}")
+
+    # logger.info(f"{color.green}Total prefill time: {timer.get_time()} {timer.unit}{color.reset}")
+
     # Decoding
     """
     if pp_rank == pp_degree - 1 and tp_rank == 0:
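
From the way the last pipeline stage consumes the results above (decode_results[0][0] becomes the next token id, decode_results[0][1] is appended to res), _batch_decode_next_tokens appears to return one (token_id, token_text) pair per prompt. A hypothetical greedy sketch with that shape, inferred from the call sites rather than copied from dist_run.py:

import torch
from typing import List, Tuple

def _batch_decode_next_tokens_sketch(
    output: torch.Tensor,          # [batch, seq_len, vocab] logits from the last stage
    prompt_lengths: List[int],
    tokenizer,
) -> List[Tuple[int, str]]:
    # Greedily pick the next token at the last real (non-padded) position of each prompt.
    results = []
    for i, plen in enumerate(prompt_lengths):
        token_id = int(torch.argmax(output[i, plen - 1], dim=-1))
        results.append((token_id, tokenizer.decode([token_id])))
    return results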
