Skip to content
This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit 39a974c

Browse files
committed
fix prompt incrementing
add formatting exception for llama2 "".res
1 parent 54eeece commit 39a974c

File tree

1 file changed

+4
-5
lines changed

1 file changed

+4
-5
lines changed

dist_run.py

Lines changed: 4 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -219,7 +219,6 @@ def _update_padded_sequence(
219219
prompt_lengths: List[int],
220220
) -> None:
221221
for i in range(len(prompt_lengths)):
222-
prompt_lengths[i] += 1
223222
padded_sequence[i, prompt_lengths[i] - 1] = new_token[i, 0]
224223
# logger.info(f"updated prompt {i} with new token {new_token[i, 0]}")
225224

@@ -427,10 +426,6 @@ def main(args):
427426
[decode_results[i][0]], device=device
428427
) # decode_results[i][0]
429428

430-
# increment prompt lengths for next token
431-
for i in range(len(prompt_lengths)):
432-
prompt_lengths[i] += 1
433-
434429
# sendrecv between last and first ranks, only if:
435430
# first_pp_rank != last_pp_rank.
436431
if pp_rank == last_pp_rank and pp_rank != first_pp_rank:
@@ -446,6 +441,10 @@ def main(args):
446441
group=pp_group,
447442
)
448443

444+
# increment prompt lengths for next token
445+
for i in range(len(prompt_lengths)):
446+
prompt_lengths[i] += 1
447+
449448
# Update input sequence with new token
450449
if pp_rank == first_pp_rank:
451450
_update_padded_sequence(padded_sequence, new_token, prompt_lengths)

0 commit comments

Comments (0)