[Distributed] Implement universal batch_decode & decode_in_flight for llama2 & llama3, with deterministic or multinomial (topk) decoding (handle both sentencepiece (llama2) and tiktoken (llama3)) #1234
Conversation
✅ No failures as of commit 1e2eec0 with merge base 7ad9ba2. See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchchat/1234. Note: links to docs will display an error until the docs builds have completed. (This comment was automatically generated by Dr. CI and updates every 15 minutes.)
Rebase to fix failing torchao_experimental check
kwen2501 left a comment
Thanks for the new feature! LGTM. Just some minor comments.
```diff
 ) -> torch.Tensor:
     """
-    Decode the next token for each prompt in the batch.
+    Decode the next token for each prompt in the batch. Adds temperature option for non-deterministic decoding.
```
I wonder if torchchat's generate also has the temperature option? Shall we think about how to connect with generate in the next steps?
```python
if temperature != 1.0:
    next_token_logits = next_token_logits / temperature
```
nit: can we do the division unconditionally?
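The suggestion in code form, using the variable names from the diff:

```python
# Dividing unconditionally is a no-op when temperature == 1.0, so the branch can be dropped.
next_token_logits = next_token_logits / temperature
```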
```python
return next_token

batch_size, seq_len, vocab_size = output.shape

if step != -1:
```
nit: can step == -1 be represented by pos = [] or pos = [0, 0, ...]? (saving one argument)
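A small sketch of what that could look like; the helper name and the empty-list convention here are hypothetical, purely to illustrate the suggestion:

```python
import torch

def _select_next_token_logits(output: torch.Tensor, pos: list[int]) -> torch.Tensor:
    """Hypothetical variant: an empty `pos` plays the role of `step != -1`."""
    batch_size = output.shape[0]
    if not pos:
        # decode phase: output has seq_len == 1, so position 0 is the only choice
        return output[:, 0, :]
    # prefill phase: gather the logits at each prompt's last token
    idx = torch.tensor(pos, device=output.device) - 1
    return output[torch.arange(batch_size, device=output.device), idx]
```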
```python
top_k = min(topk, vocab_size)  # Ensure top-k is not greater than vocab size
top_k_logits, top_k_indices = torch.topk(next_token_logits, k=top_k, dim=-1)
probs = torch.softmax(top_k_logits, dim=-1)
next_token_indices = torch.multinomial(probs, num_samples=1).squeeze(-1)
next_tokens = top_k_indices.gather(
    -1, next_token_indices.unsqueeze(-1)
).squeeze(-1)
```
nit: do you mind adding more comments here for the multinomial, gather, squeeze and unsqueeze ops?
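In the spirit of this nit, a commented restatement of the sampling block above (variable names taken from the diff; illustrative rather than the PR's final code):

```python
top_k = min(topk, vocab_size)  # never request more candidates than the vocabulary holds
# values and vocabulary ids of the k largest logits per row: both (batch_size, top_k)
top_k_logits, top_k_indices = torch.topk(next_token_logits, k=top_k, dim=-1)
probs = torch.softmax(top_k_logits, dim=-1)  # renormalize over the k candidates only
# multinomial draws one sample per row -> (batch_size, 1); squeeze(-1) -> (batch_size,)
next_token_indices = torch.multinomial(probs, num_samples=1).squeeze(-1)
# the samples are positions *within* the top-k list, so map them back to vocabulary ids;
# gather needs an index of shape (batch_size, 1), hence the unsqueeze/squeeze pair
next_tokens = top_k_indices.gather(-1, next_token_indices.unsqueeze(-1)).squeeze(-1)
```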
```python
# New token generated each iteration
# need a row dimension for each prompt in the batch
new_token = torch.zeros(batch_size, 1, device=device, dtype=torch.int64)
logger.info(f"{color.green}{new_token.shape=}, {new_token=}{color.reset}")
```
nit: for debugging only?
```diff
 if not args.disable_in_flight_decode:
-    decode_in_flight(new_token)
+    _decode_in_flight(new_token, tokenizer, tp_rank)
```
nit: put tp_rank into the if condition?
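A sketch of the suggested call site; assuming rank 0 is the rank meant to print, which is an assumption rather than the PR's code:

```python
# Hypothetical rearrangement: gate on the rank in the condition itself so the
# helper does not need to re-check tp_rank internally.
if not args.disable_in_flight_decode and tp_rank == 0:
    _decode_in_flight(new_token, tokenizer, tp_rank)
```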
```python
# Decode the output
if pp_rank == last_pp_rank:
    new_token = _batch_decode_next_tokens(output, 0)
    # logger.info(f"{color.red}Decoding...{output.shape=}{color.reset}")
```
nit: remove log?
```python
# Decode the output as comprehension instead of loop
responses = [tokenizer.decode(sequence) for sequence in res_list]
```
Hmm, did the previous code not work in case of variable length? Just curious.
response = tokenizer.decode(res_list)
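A hedged sketch of the contrast, assuming the SentencePiece wrapper accepts a list of token-id lists while tiktoken's decode expects a flat list of ints:

```python
# Previous approach: one call over the whole batch; works with SentencePiece,
# which accepts a list of token-id lists.
# responses = tokenizer.decode(res_list)

# Universal approach: decode each sequence separately; also satisfies tiktoken's
# decode(list[int]) signature and naturally handles rows of different lengths.
responses = [tokenizer.decode(sequence) for sequence in res_list]
```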
```python
    next_token_logits = output[:, 0, :]
else:
    # get the logits for each prompt at the specified positions
    next_token_logits = output[torch.arange(batch_size), torch.tensor(pos) - 1]
```
Hmm, why "-1"?
From this function's perspective, if the caller has given the position, should it just faithfully decode that position?
(I understand this works correctly as long as the callsite provides prompt_length rather than prompt_length - 1.)
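For reference, a small sketch of the convention the "-1" implements, assuming pos carries prompt lengths (the helper name here is hypothetical):

```python
import torch

# For a prompt of length L, output[b, L-1, :] holds the logits that predict token L
# (the first generated token). Whether the caller passes L and the helper subtracts 1,
# or the caller passes L-1 directly, is purely a convention choice.
def logits_at_last_prompt_token(output: torch.Tensor, prompt_lengths: list[int]) -> torch.Tensor:
    batch_size = output.shape[0]
    idx = torch.tensor(prompt_lengths, device=output.device) - 1
    return output[torch.arange(batch_size, device=output.device), idx]
```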
This PR:
1 - updates batch_decode_next_tokens so it handles both llama2 and llama3 with their respective tokenizers (sentencepiece and tiktoken), giving us a single universal decoding path for in-flight decoding.
2 - adds a temperature option to enable non-deterministic (creative) decoding, using top-k and multinomial selection (see the sketch after this list).
3 - updates decode_in_flight, again to be compatible with both llama2 and llama3.
4 - minor tweaks: uses zip for the final display and renames decode_in_flight to _decode_in_flight, moving it alongside the other module-level functions for ease of reference.
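A minimal, self-contained sketch of the decoding flow described above. The function name, signature, defaults, and greedy fallback here are illustrative assumptions rather than the exact API added by this PR: greedy argmax when no topk is given, otherwise temperature-scaled top-k multinomial sampling over the last prompt position of each sequence.

```python
import torch

def batch_decode_next_tokens_sketch(
    output: torch.Tensor,          # (batch, seq_len, vocab) logits from the model
    pos: list[int] | None = None,  # per-prompt lengths during prefill; None during decode
    temperature: float = 1.0,
    topk: int | None = None,       # None => deterministic (greedy) decoding
) -> torch.Tensor:
    batch_size, seq_len, vocab_size = output.shape
    if pos is None:
        logits = output[:, 0, :]   # decode phase: seq_len == 1
    else:
        # prefill phase: logits at each prompt's last token predict the first new token
        idx = torch.tensor(pos, device=output.device) - 1
        logits = output[torch.arange(batch_size, device=output.device), idx]
    logits = logits / temperature  # no-op when temperature == 1.0
    if topk is None or topk <= 1:
        return torch.argmax(logits, dim=-1)            # deterministic path
    k = min(topk, vocab_size)
    top_logits, top_idx = torch.topk(logits, k=k, dim=-1)
    probs = torch.softmax(top_logits, dim=-1)
    sampled = torch.multinomial(probs, num_samples=1)  # positions within the top-k list
    return top_idx.gather(-1, sampled).squeeze(-1)     # map back to vocabulary ids
```

For example, with a prefill batch of two prompts of lengths 5 and 7, this would be called as batch_decode_next_tokens_sketch(output, pos=[5, 7], temperature=0.8, topk=10).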
Tested with both llama2 and llama3:
Example multi-prompt run with llama2: