This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit c9e6152

Replace total_prompts with batch_size
1 parent 9514b54 commit c9e6152

File tree

1 file changed: +3 additions, -4 deletions


dist_run.py

Lines changed: 3 additions & 4 deletions
@@ -399,10 +399,9 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
     last_pp_rank_global_id = dist.get_global_rank(pp_group, last_pp_rank)
 
     # New token generated each iteration
-    total_prompts = len(prompt_lengths)
-    # need a new token dimension (row) for each prompt in the batch
-    new_token = torch.zeros(total_prompts, 1, device=device, dtype=torch.int64)
-    res = [[] for _ in range(total_prompts)]
+    # need a row dimension for each prompt in the batch
+    new_token = torch.zeros(batch_size, 1, device=device, dtype=torch.int64)
+    res = [[] for _ in range(batch_size)]
     num_tokens = 40
 
     # Prefill phase
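The pattern this commit touches can be sketched in isolation: allocate one new-token row per prompt in the batch, then append each step's token to a per-prompt result list. The sketch below uses plain Python lists instead of torch tensors so it runs without dependencies; the `decode_loop` function and its placeholder "model output" are hypothetical stand-ins, not code from dist_run.py.

```python
def decode_loop(batch_size: int, num_tokens: int):
    # one result list per prompt, as in `res = [[] for _ in range(batch_size)]`
    res = [[] for _ in range(batch_size)]
    # stand-in for `new_token = torch.zeros(batch_size, 1, ...)`:
    # one row per prompt in the batch
    new_token = [[0] for _ in range(batch_size)]
    for step in range(num_tokens):
        for i in range(batch_size):
            # placeholder for a real decode step producing a token id
            new_token[i][0] = step
            res[i].append(new_token[i][0])
    return res

res = decode_loop(batch_size=3, num_tokens=5)
print(len(res), len(res[0]))  # 3 prompts, 5 tokens each
```

The point of the change is that the row dimension of `new_token` and the length of `res` must both equal the batch size, so deriving one count (`batch_size`) and using it in both places avoids the two drifting apart.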
