Skip to content
This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit 5abb444

Browse files
committed
revert prompt incrementing to pp=1 state
1 parent 39a974c commit 5abb444

File tree

1 file changed

+5
-5
lines changed

1 file changed

+5
-5
lines changed

dist_run.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -219,7 +219,7 @@ def _update_padded_sequence(
219219
prompt_lengths: List[int],
220220
) -> None:
221221
for i in range(len(prompt_lengths)):
222-
padded_sequence[i, prompt_lengths[i] - 1] = new_token[i, 0]
222+
padded_sequence[i, prompt_lengths[i]] = new_token[i, 0]
223223
# logger.info(f"updated prompt {i} with new token {new_token[i, 0]}")
224224

225225

@@ -441,14 +441,14 @@ def main(args):
441441
group=pp_group,
442442
)
443443

444-
# increment prompt lengths for next token
445-
for i in range(len(prompt_lengths)):
446-
prompt_lengths[i] += 1
447-
448444
# Update input sequence with new token
449445
if pp_rank == first_pp_rank:
450446
_update_padded_sequence(padded_sequence, new_token, prompt_lengths)
451447

448+
# increment prompt lengths for next token
449+
for i in range(len(prompt_lengths)):
450+
prompt_lengths[i] += 1
451+
452452
# Display the decoding results
453453

454454
# output formatted response via last pp group and tp rank 0

0 commit comments

Comments
 (0)