This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit e3fe1bf

second decoded token

1 parent 13fbee6

1 file changed: +31 -16 lines

dist_run.py: 31 additions & 16 deletions
@@ -407,6 +407,7 @@ def main():
             f"\n\n{color.green} Prefill responses ====>>>> {color.blue} {decode_results=} \n{color.reset}"
         )
         next_token = torch.tensor([decode_results[0][0]], device=device)
+        res.append(decode_results[0][1])
         dst = dist.get_global_rank(pp_group, 0)
         logger.info(f"SENDING back...from {rank=} to {dst=}")
         logger.info(f"SENDING data {next_token.shape=}, {next_token=}")
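In the prefill path, the decoded token's text is now appended to res alongside the token id that gets sent back through the pipeline. A minimal sketch of the (token_id, token_text) pairing this indexing implies; the concrete values and the return shape of _batch_decode_next_tokens are assumptions inferred from the diff, not confirmed by it:

import torch

# Hypothetical return value of _batch_decode_next_tokens: one
# (token_id, token_text) pair per prompt, inferred from the
# decode_results[0][0] / decode_results[0][1] indexing above.
decode_results = [(791, " The")]

device = "cpu"  # placeholder; the real script derives the device elsewhere
next_token = torch.tensor([decode_results[0][0]], device=device)  # id sent onward
res = [decode_results[0][1]]  # readable text accumulated for the final response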
@@ -431,57 +432,71 @@ def _update_padded_sequence(
             prompt_lengths[i] += 1
             padded_sequence[i, prompt_lengths[i] - 1] = x_recv
 
-        logger.info(f"REVIEW {padded_sequence[0,4:9]=}")
+        logger.info(f"REVIEW {padded_sequence[0,:15]=}")
 
     # logger.info(f"{color.green}Total prefill time: {timer.get_time()} {timer.unit}{color.reset}")
 
     # decoding loop
     # append first token to the prompt from prefill
-    logger.info(f"{prompt_lengths=}")
-    logger.info(f"{prompt_lengths=}, {padded_sequence[:, prompt_lengths[0]-1]=}")
-    prompt_lengths[0] += 1
-    padded_sequence[0, prompt_lengths[0] - 1] = x_recv
-    logger.info(f"{padded_sequence[0, prompt_lengths[0]+1]=}")
+    logger.info(f"\npre update {padded_sequence[0,0:9]=}")
+    _update_padded_sequence(padded_sequence, x_recv, res, prompt_lengths)
+    logger.info(f"{prompt_lengths=}, {padded_sequence[0, prompt_lengths[0]-1]=}")
+    logger.info(f"\n post update {padded_sequence[0,0:9]=}")
 
-    num_tokens = 4
+    num_tokens = 5
     with torch.no_grad():
-        for _ in range(num_tokens):
+        for step in range(num_tokens):
             if pp_rank == 0:
+                logger.info(
+                    f"about to send...{prompt_lengths=}, {padded_sequence[0, :prompt_lengths[0]+1]=}"
+                )
                 schedule.step(padded_sequence)
+
                 src = dist.get_global_rank(pp_group, pp_group_size - 1)
                 dist.recv(
                     x_recv,
                     src,
                     group=pp_group,
                 )
                 logger.info(f"RECEIVED {x_recv=}")
+                assert x_recv != 128006, f"next_token is header id={x_recv}"
                 _update_padded_sequence(padded_sequence, x_recv, res, prompt_lengths)
                 logger.info(
-                    f"{prompt_lengths=}, {padded_sequence[:, prompt_lengths[0]-1]=}"
+                    f"about to send...{prompt_lengths=}, {padded_sequence[:, prompt_lengths[0]-1]=}"
                 )
                 schedule.step(padded_sequence)
 
             elif pp_rank == last_pp_group:
                 output = schedule.step()
                 # need to decode the output
+
                 decode_results = _batch_decode_next_tokens(
                     output=output, prompt_lengths=prompt_lengths, tokenizer=tokenizer
                 )
 
+                for i in range(len(prompt_lengths)):
+                    prompt_lengths[i] += 1
+                    logger.info(
+                        f"output review {prompt_lengths[i]=}, {padded_sequence[i, prompt_lengths[i]-1]=}"
+                    )
+
                 logger.info(
                     f"\n\n{color.green} * Decode * responses ====>>>> {color.blue} {decode_results=} \n{color.reset}"
                 )
                 res.append(decode_results[0][1])
                 next_token = torch.tensor([decode_results[0][0]], device=device)
                 dst = dist.get_global_rank(pp_group, 0)
-                logger.info(f"SENDING back...from {rank=} to {dst=}")
-                logger.info(f"SENDING data {next_token.shape=}, {next_token=}")
-
-                dist.send(
-                    next_token,
-                    dst,
-                    pp_group,
+                logger.info(
+                    f"SENDING back...from {rank=} to {dst=}, data {next_token.shape=}, {next_token=}"
                 )
+                assert next_token != 128006, f"next_token is header id={next_token}"
+
+                if step < num_tokens - 1:
+                    dist.send(
+                        next_token,
+                        dst,
+                        pp_group,
+                    )
 
             # middle pp ranks
             else:
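For reference, the first two context lines of this hunk are the loop body of _update_padded_sequence itself. A self-contained sketch of that helper, reconstructed from those context lines; the res parameter is taken from the call sites above and appears unused in the body shown, and the toy shapes are illustrative:

import torch

def _update_padded_sequence(padded_sequence, x_recv, res, prompt_lengths) -> None:
    # Write the newly received token into each prompt's next free slot
    # and advance that prompt's recorded length.
    for i in range(len(prompt_lengths)):
        prompt_lengths[i] += 1
        padded_sequence[i, prompt_lengths[i] - 1] = x_recv

# Toy usage: a batch of one prompt, 4 tokens long, in a 16-slot buffer.
padded = torch.zeros(1, 16, dtype=torch.long)
lengths = [4]
_update_padded_sequence(padded, torch.tensor(99), res=[], prompt_lengths=lengths)
assert lengths == [5] and padded[0, 4] == 99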

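The decode loop is a point-to-point handoff: the last pipeline stage samples a token, asserts it is not 128006 (which the new assert messages call a header token id), and sends it back to stage 0, which receives it into x_recv and splices it into the padded batch before the next schedule.step. The new step < num_tokens - 1 guard skips the send on the final iteration, when no rank is still waiting to receive. A minimal standalone sketch of the send/recv pattern with torch.distributed; the two-rank setup, gloo backend, and toy token are illustrative, not from this commit:

import torch
import torch.distributed as dist

def token_handoff(rank: int, world_size: int) -> None:
    # Illustrative 2-process "pipeline": rank 1 stands in for the last
    # stage, rank 0 for the first. Requires MASTER_ADDR/MASTER_PORT set.
    dist.init_process_group("gloo", rank=rank, world_size=world_size)
    if rank == world_size - 1:
        next_token = torch.tensor([42])        # stand-in for the sampled token id
        dist.send(next_token, dst=0)           # ship it back to the first stage
    elif rank == 0:
        x_recv = torch.zeros(1, dtype=torch.long)
        dist.recv(x_recv, src=world_size - 1)  # blocks until the token arrives
        # here the real script calls _update_padded_sequence(...) with x_recv
    dist.destroy_process_group()

# e.g. torch.multiprocessing.spawn(token_handoff, args=(2,), nprocs=2)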