This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit a512141

improve batch_decode_next_tokens

1 parent: 1c7368f

File tree

1 file changed: +23 -14 lines changed


dist_run.py

Lines changed: 23 additions & 14 deletions
@@ -226,7 +226,7 @@ def _batch_decode_next_tokens_new(
 
 
 def _batch_decode_next_tokens(
-    output: torch.Tensor, pos: int, step: int = -1
+    output: torch.Tensor, pos: List[int], step: int = -1, temperature: float = 1.0
 ) -> torch.Tensor:
     """
     Decode the next token for each prompt in the batch.
@@ -239,19 +239,29 @@ def _batch_decode_next_tokens(
     """
     # Take the next token logits for each prompt
     res = []
-    logger.info(f"{color.green}output shape = {output.shape}{color.reset}")
-    logger.info(f"{color.green}pos = {pos}{color.reset}")
-    for i in range(output.shape[0]):
-        token_pos = 0 if step != -1 else pos[i] - 1
-        next_token_logits = output[i, token_pos, :]
+    # logger.info(f"{color.green}output shape = {output.shape}{color.reset}")
+    # logger.info(f"{color.green}pos = {pos}{color.reset}")
+    batch_size, seq_len, vocab_size = output.shape
 
-        # Argmax (deterministic) TODO: add temperature
+    if step != -1:
+        next_token_logits = output[:, 0, :]
         next_token = torch.argmax(next_token_logits, dim=-1)
-        logger.info(f"{color.blue}next_token = {next_token}{color.reset}")
         res.append(next_token)
-    # Token ids in int tensor form
-    res = torch.stack(res, dim=0)
-    logger.info(f"{color.green}next_token = {res}{color.reset}")
+        res = torch.stack(res, dim=0)
+        res = res.squeeze(0)
+    else:
+        for i in range(batch_size):
+            token_pos = pos[i] - 1
+            next_token_logits = output[i, token_pos, :]
+
+            # Argmax (deterministic) TODO: add temperature
+            next_token = torch.argmax(next_token_logits, dim=-1)
+            # logger.info(f"{color.blue}next_token = {next_token}{color.reset}")
+            res.append(next_token)
+        # Token ids in int tensor form
+        res = torch.stack(res, dim=0)
+
+    logger.info(f"{color.yellow}next_token = {res}{color.reset}")
     return res  # next_token

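To make the two paths concrete, here is a minimal, self-contained sketch of the updated decode logic. It is an illustration, not the file's code: the function is renamed to avoid clashing with dist_run.py's own, the logging and color helpers are dropped, the append/stack/squeeze sequence in the step != -1 branch is collapsed into a direct argmax (which yields the same tensor), and temperature is accepted but unused, matching the TODO in the diff.

from typing import List

import torch


def batch_decode_next_tokens(
    output: torch.Tensor, pos: List[int], step: int = -1, temperature: float = 1.0
) -> torch.Tensor:
    """Pick the next token per prompt from (batch, seq_len, vocab) logits via argmax."""
    batch_size, seq_len, vocab_size = output.shape
    if step != -1:
        # Decode step: every batch row holds exactly one fresh position, at index 0.
        return torch.argmax(output[:, 0, :], dim=-1)
    # Prefill: each prompt's next token comes from that prompt's last filled position.
    return torch.stack(
        [torch.argmax(output[i, pos[i] - 1, :], dim=-1) for i in range(batch_size)]
    )


# Toy shapes: 3 prompts, prefill sequence length 8, vocabulary of 11 tokens.
logits = torch.randn(3, 8, 11)
print(batch_decode_next_tokens(logits, pos=[3, 4, 4]).shape)             # prefill: (3,)
print(batch_decode_next_tokens(logits[:, :1, :], pos=[], step=0).shape)  # decode: (3,)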
@@ -340,7 +350,7 @@ def main(args):
     # Batch size. Since we push batches dynamically through the pipeline rather
     # than chunking them, this is effectively micro-batch size in pipeline
     # sense. Thus it is interchangeable with micro-batch size below.
-    batch_size = 2
+    batch_size = 3
     seqlen_prefill = 1024  # sequence length
     dim = 4096  # embedding dimension

@@ -414,7 +424,7 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
     prompt = [
         "What is Snow?",
         "Who is Santa Claus?",
-        # "Where does Santa live?",
+        "Where does Santa live?",
         # "Who is Abraham Lincoln?",
         # "How are models trained?",
     ]
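One consequence of the new signature is that the prefill caller must now supply one position per prompt rather than a single int. A hypothetical sketch of where such a list could come from; the encode helper below is a whitespace stand-in for a real tokenizer, not dist_run.py's actual API:

prompt = [
    "What is Snow?",
    "Who is Santa Claus?",
    "Where does Santa live?",
]

def encode(text: str) -> list:
    # Stand-in tokenizer (whitespace split), purely for illustration.
    return text.split()

# One length per batch entry: the kind of data the new `pos: List[int]` expects.
prompt_lengths = [len(encode(p)) for p in prompt]  # -> [3, 4, 4]
# Illustrative prefill call: _batch_decode_next_tokens(logits, pos=prompt_lengths)
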
@@ -455,7 +465,6 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
     # Run context input through pipeline
     # TODO: we need to pass `input_pos` and `cache_lane` to each stage.
     lane = 0
-    logger.info(f"{color.green}Prefilling...{input_pos=}{color.reset}")
     kwargs = {"input_pos": input_pos, "cache_lane": lane}
     with torch.no_grad(), CUDATrackTime() as timer:
         if pp_rank == first_pp_rank:
