This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit 8cba3d1

faster batch_decode_next_tokens, add topk/temperature option

1 parent: 315a023

File tree

1 file changed: +32 −25 lines

dist_run.py

Lines changed: 32 additions & 25 deletions
@@ -189,43 +189,50 @@ def _create_padded_prompts(
 
 
 def _batch_decode_next_tokens(
-    output: torch.Tensor, pos: List[int], step: int = -1, temperature: float = 1.0
+    output: torch.Tensor,
+    pos: List[int],
+    step: int = -1,
+    temperature: float = 1.0,
+    topk: int = 10,
 ) -> torch.Tensor:
     """
-    Decode the next token for each prompt in the batch.
+    Decode the next token for each prompt in the batch. Adds temperature option for non-deterministic decoding.
+
     Args:
         output (torch.Tensor): The output tensor to decode.
-        pos: the position of the `output` to decode in the sequence length dimension.
+        pos (List[int]): The positions of the `output` to decode in the sequence length dimension.
+        step (int): Step indicator. If -1, use positions from `pos`. Otherwise, use the first token.
+        temperature (float): Sampling temperature for non-deterministic decoding.
 
     Returns:
-        Decoded token ids.
+        torch.Tensor: Decoded token ids.
     """
-    # Take the next token logits for each prompt
-    res = []
-    # logger.info(f"{color.green}output shape = {output.shape}{color.reset}")
-    # logger.info(f"{color.green}pos = {pos}{color.reset}")
     batch_size, seq_len, vocab_size = output.shape
 
     if step != -1:
         next_token_logits = output[:, 0, :]
-        next_token = torch.argmax(next_token_logits, dim=-1)
-        res.append(next_token)
-        res = torch.stack(res, dim=0)
-        res = res.squeeze(0)
     else:
-        for i in range(batch_size):
-            token_pos = pos[i] - 1
-            next_token_logits = output[i, token_pos, :]
-
-            # Argmax (deterministic) TODO: add temperature
-            next_token = torch.argmax(next_token_logits, dim=-1)
-            # logger.info(f"{color.blue}next_token = {next_token}{color.reset}")
-            res.append(next_token)
-        # Token ids in int tensor form
-        res = torch.stack(res, dim=0)
-
-    logger.info(f"{color.yellow}next_token = {res}{color.reset}")
-    return res  # next_token
+        # get the logits for each prompt at the specified positions
+        next_token_logits = output[torch.arange(batch_size), torch.tensor(pos) - 1]
+
+    if temperature != 1.0:
+        next_token_logits = next_token_logits / temperature
+
+    # Uses top-k sampling if temperature is not 1.0, otherwise use argmax
+    if temperature != 1.0:
+        top_k = min(topk, vocab_size)
+        top_k_logits, top_k_indices = torch.topk(next_token_logits, k=top_k, dim=-1)
+        probs = torch.softmax(top_k_logits, dim=-1)
+        next_token_indices = torch.multinomial(probs, num_samples=1).squeeze(-1)
+        next_tokens = top_k_indices.gather(
+            -1, next_token_indices.unsqueeze(-1)
+        ).squeeze(-1)
+    else:
+        # Argmax (deterministic)
+        next_tokens = torch.argmax(next_token_logits, dim=-1)
+
+    logger.info(f"{color.yellow}Next tokens: {color.blue}{next_tokens}{color.reset}")
+    return next_tokens
 
 
 def _update_padded_sequence(
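
For reference, below is a minimal, self-contained sketch of the new decoding path: the vectorized position gather that replaces the old per-prompt Python loop, plus the top-k/temperature sampling branch. The tensor shapes, sample values, and the omission of the repo's `logger`/`color` helpers are illustrative assumptions, not part of the commit.

import torch

# Illustrative shapes and values (assumptions, not from the repo).
batch_size, seq_len, vocab_size = 2, 8, 32
output = torch.randn(batch_size, seq_len, vocab_size)  # stand-in for model logits
pos = [5, 7]  # per-prompt lengths; the last filled position is pos[i] - 1

# Vectorized gather: one advanced-indexing op picks, for each prompt i,
# the logits at sequence position pos[i] - 1 (replaces the old Python loop).
next_token_logits = output[torch.arange(batch_size), torch.tensor(pos) - 1]

temperature, topk = 0.8, 10
if temperature != 1.0:
    # Scale logits, keep only the top-k candidates, and sample among them.
    next_token_logits = next_token_logits / temperature
    top_k_logits, top_k_indices = torch.topk(
        next_token_logits, k=min(topk, vocab_size), dim=-1
    )
    probs = torch.softmax(top_k_logits, dim=-1)
    sampled = torch.multinomial(probs, num_samples=1).squeeze(-1)
    # Map sampled positions (0..k-1) back to vocabulary ids.
    next_tokens = top_k_indices.gather(-1, sampled.unsqueeze(-1)).squeeze(-1)
else:
    # temperature == 1.0 keeps the original deterministic argmax.
    next_tokens = torch.argmax(next_token_logits, dim=-1)

print(next_tokens.shape)  # torch.Size([2]) -- one token id per prompt

Sampling from only the top-k logits keeps `torch.multinomial` working over a k-sized distribution and cuts off the low-probability tail; with `temperature == 1.0` the behavior stays identical to the old deterministic argmax path.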
