[Distributed] Implement universal batch_decode & decode_in_flight for llama2 & llama3, with deterministic or multinomial (topk) decoding (handle both sentencepiece (llama2) and tiktoken (llama3)) #1234
@@ -16,6 +16,8 @@
import torch
import torch.distributed as dist
from torch.distributed.pipelining import PipelineStage, ScheduleGPipe
from torchchat.cli.builder import _initialize_tokenizer, TokenizerArgs

from torchchat.distributed.logging_utils import SingletonLogger
@@ -33,8 +35,6 @@
    get_num_params,
    GPUMemoryMonitor,
)
from torch.distributed.pipelining import PipelineStage, ScheduleGPipe
from torchchat.cli.builder import _initialize_tokenizer, TokenizerArgs
from torchchat.model import ModelArgs, Transformer, TransformerArgs
from torchchat.utils.build_utils import set_precision
@@ -189,23 +189,49 @@ def _create_padded_prompts(

def _batch_decode_next_tokens(
    output: torch.Tensor,
    pos: int,
    pos: List[int],
    step: int = -1,
    temperature: float = 1.0,
    topk: int = 10,
) -> torch.Tensor:
    """
    Decode the next token for each prompt in the batch.
    Decode the next token for each prompt in the batch. Adds temperature option for non-deterministic decoding.

    Args:
        output (torch.Tensor): The output tensor to decode.
        pos: the position of the `output` to decode in the sequence length dimension.
        pos (List[int]): The positions of the `output` to decode in the sequence length dimension.
        step (int): Step indicator. If -1, use positions from `pos`. Otherwise, use the first token.
        temperature (float): Sampling temperature for non-deterministic decoding.

    Returns:
        Decoded token ids.
        torch.Tensor: Decoded token ids.
    """
    # Take the next token logits for each prompt
    next_token_logits = output[:, pos, :]
    # Argmax (deterministic) TODO: add temperature
    next_token = torch.argmax(next_token_logits, dim=-1)
    # Token ids in int tensor form
    return next_token
    batch_size, seq_len, vocab_size = output.shape

    if step != -1:
Contributor: nit: can …
        next_token_logits = output[:, 0, :]
    else:
        # get the logits for each prompt at the specified positions
        next_token_logits = output[torch.arange(batch_size), torch.tensor(pos) - 1]
Contributor: Hmm, why "-1"?
    if temperature != 1.0:
        next_token_logits = next_token_logits / temperature
Comment on lines +217 to +218
Contributor: nit: can we do the division unconditionally?
    # Uses top-k sampling if temperature is not 1.0, otherwise use argmax
    if temperature != 1.0:
        top_k = min(topk, vocab_size)  # Ensure top-k is not greater than vocab size
        top_k_logits, top_k_indices = torch.topk(next_token_logits, k=top_k, dim=-1)
        probs = torch.softmax(top_k_logits, dim=-1)
        next_token_indices = torch.multinomial(probs, num_samples=1).squeeze(-1)
        next_tokens = top_k_indices.gather(
            -1, next_token_indices.unsqueeze(-1)
        ).squeeze(-1)
Comment on lines +222 to +228
Contributor: nit: do you mind adding more comments here for the …
    else:
        # Argmax (deterministic)
        next_tokens = torch.argmax(next_token_logits, dim=-1)

    logger.info(f"{color.yellow}Next tokens: {color.blue}{next_tokens}{color.reset}")
    return next_tokens


def _update_padded_sequence(
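For reviewers following the new sampling path, here is a small self-contained sketch of the same gather-then-sample flow (random logits and made-up shapes, not the actual model output): the logits at index `length - 1` are the ones that predict the token at index `length`, and when `temperature != 1.0` the next token is drawn from a softmax over the top-k logits rather than taken by argmax.

```python
import torch

# Illustrative shapes only -- not the real model output.
batch_size, seq_len, vocab_size = 3, 8, 32
fake_output = torch.randn(batch_size, seq_len, vocab_size)
prompt_lengths = [5, 7, 6]  # per-prompt lengths, as passed via `pos`

# Prefill step (step == -1): gather the logits of each prompt's last real token;
# the logits at index `length - 1` predict the token at index `length`.
next_token_logits = fake_output[
    torch.arange(batch_size), torch.tensor(prompt_lengths) - 1
]

# Deterministic decoding: argmax over the vocab dimension.
greedy = torch.argmax(next_token_logits, dim=-1)  # shape: (batch_size,)

# Non-deterministic decoding: temperature scaling + top-k multinomial sampling.
temperature, topk = 0.8, 10
scaled = next_token_logits / temperature
top_k_logits, top_k_indices = torch.topk(scaled, k=min(topk, vocab_size), dim=-1)
probs = torch.softmax(top_k_logits, dim=-1)             # (batch_size, topk)
sampled_pos = torch.multinomial(probs, num_samples=1)   # index into the top-k set
sampled = top_k_indices.gather(-1, sampled_pos).squeeze(-1)

print(greedy.shape, sampled.shape)  # both (batch_size,)
```

Restricting the multinomial draw to the top-k logits keeps sampling away from the long low-probability tail while still adding diversity.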
@@ -218,11 +244,32 @@ def _update_padded_sequence(
            # logger.info(f"updated prompt {i} with new token {new_token[i, 0]}")


# Decode token id into string and print it
def _decode_in_flight(token, tokenizer, tp_rank):
    """decode token ids for all prompts in the batch and log them"""
    token_str = tokenizer.decode(token.tolist())
    # print the token string on tp rank 0
    if tp_rank == 0:
        logger.info(
            f"{color.green} responses ====>>>> "
            f"{color.blue} {token_str} {color.reset}"
        )


def _cleanup():
    dist.barrier()
    dist.destroy_process_group()


prompt = [
    "What is Snow?",
    "Who is Santa Claus?",
    "Where does Santa live?",
    # "Who is Abraham Lincoln?",
    # "How are models trained?",
]
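As a rough usage illustration of `_decode_in_flight` above (the tokenizer below is a hypothetical stub; the real script passes the sentencepiece or tiktoken wrapper built by `_initialize_tokenizer` and logs via `logger`/`color` instead of `print`):

```python
import torch

class FakeTokenizer:
    """Hypothetical stand-in for the sentencepiece/tiktoken wrappers used in the script."""
    def decode(self, ids):
        # Accepts a list of token-id lists (one row per prompt), returns printable strings.
        return [" ".join(f"<{i}>" for i in row) for row in ids]

def decode_in_flight_sketch(token: torch.Tensor, tokenizer, tp_rank: int) -> None:
    # `token` has shape (batch_size, 1): one newly generated id per prompt.
    token_str = tokenizer.decode(token.tolist())
    # Only tensor-parallel rank 0 prints, to avoid duplicate logs per TP group.
    if tp_rank == 0:
        print("responses ====>>>>", token_str)

new_token = torch.tensor([[101], [202], [303]])  # pretend batch of 3 prompts
decode_in_flight_sketch(new_token, FakeTokenizer(), tp_rank=0)
```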
def main(args):
    model_name = args.model_name
    pp_degree = args.pp

@@ -293,7 +340,7 @@ def main(args):
    # Batch size. Since we push batches dynamically through the pipeline rather
    # than chunking them, this is effectively micro-batch size in pipeline
    # sense. Thus it is interchangeable with micro-batch size below.
    batch_size = 4
    batch_size = len(prompt)
    seqlen_prefill = 1024  # sequence length
    dim = 4096  # embedding dimension
@@ -331,7 +378,9 @@

    # Helper function to get example inputs and outputs for the stages.
    def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
        mb_ids = torch.randint(0, config.vocab_size, (batch_size, seqlen), device=device)
        mb_ids = torch.randint(
            0, config.vocab_size, (batch_size, seqlen), device=device
        )
        activation = torch.rand(
            batch_size, seqlen, dim, device=device, dtype=model_dtype
        )
@@ -362,13 +411,6 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
    # pipelining effect.
    prefiller = ScheduleGPipe(prefill_stage, 1)

    prompt = [
        "What is a computer?",
        "Where does Santa live?",
        "Who is Abraham Lincoln?",
        "How are models trained?",
    ]

    start_pos = 0

    # Need these global ids due to the API definition of dist.send and recv
@@ -384,10 +426,6 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
    padded_sequence, prompt_lengths = _create_padded_prompts(
        input_ids, tokenizer, seqlen_prefill, start_pos, device
    )
    # TODO: figure out how to set input_pos for each prompt in the batch then we
    # can remove this limitation.
    s = set(prompt_lengths)
    assert len(s) == 1, f"prompt_lengths should be the same, got {s}"

    # Need these global ids due to the API definition of dist.send and recv
    first_pp_rank_global_id = dist.get_global_rank(pp_group, first_pp_rank)
@@ -396,6 +434,7 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
    # New token generated each iteration
    # need a row dimension for each prompt in the batch
    new_token = torch.zeros(batch_size, 1, device=device, dtype=torch.int64)
    logger.info(f"{color.green}{new_token.shape=}, {new_token=}{color.reset}")
Contributor: nit: for debugging only?
    # Store the generated tokens
    res = []
@@ -416,23 +455,13 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
        f"{color.green}Prefilling time: {timer.get_time()} {timer.unit} for rank {rank}{color.reset}"
    )

    # Decode token id into string and print it
    def decode_in_flight(token):
        # Make a 2D tensor with ids on row dimension
        unsqueezed = torch.unsqueeze(token, 1)
        token_str = tokenizer.decode(unsqueezed.tolist())
        if tp_rank == 0:
            logger.info(
                f"{color.green} responses ====>>>> "
                f"{color.blue} {token_str} {color.reset}"
            )

    # Decode the output -- first generated token
    if pp_rank == last_pp_rank:
        new_token = _batch_decode_next_tokens(output, prompt_lengths[0] - 1)
        logger.info(f"{color.green}Decoding...{prompt_lengths=}{color.reset}")
        new_token = _batch_decode_next_tokens(output, prompt_lengths)
        res.append(new_token)
        if not args.disable_in_flight_decode:
            decode_in_flight(new_token)
            _decode_in_flight(new_token, tokenizer, tp_rank)
Comment on lines 463 to +464
Contributor: nit: put …
    # seqlen = 1 now
    seqlen_decode = 1
@@ -482,10 +511,11 @@ def decode_in_flight(token):

            # Decode the output
            if pp_rank == last_pp_rank:
                new_token = _batch_decode_next_tokens(output, 0)
                # logger.info(f"{color.red}Decoding...{output.shape=}{color.reset}")
Contributor: nit: remove log?
                new_token = _batch_decode_next_tokens(output, prompt_lengths, step)
                res.append(new_token)
                if not args.disable_in_flight_decode:
                    decode_in_flight(new_token)
                    _decode_in_flight(new_token, tokenizer, tp_rank)

            # Increment input position
            input_pos += 1
@@ -499,12 +529,17 @@ def decode_in_flight(token):
    # output formatted response via last pp group and tp rank 0
    if pp_rank == last_pp_rank and tp_rank == 0:
        # `res` is a list of tensors, each being a batch of generated token ids
        res = torch.stack(res, dim=1)
        res_list = res.tolist()
        response = tokenizer.decode(res_list)
        for i in range(len(response)):
            logger.info(f"Prompt: {color.green}{prompt[i]} {color.reset}")
            logger.info(f"Response: {color.red}{response[i]} {color.reset}")
        res_stacked = torch.stack(res, dim=1)
        res_list = res_stacked.tolist()

        # Decode the output as comprehension instead of loop
        responses = [tokenizer.decode(sequence) for sequence in res_list]
Comment on lines +536 to +537
Contributor: Hmm, did the previous code not work in case of variable length? Just curious.
        # Show prompts and responses
        for prompt_text, response_text in zip(prompt, responses):
            logger.info(f"Prompt: {color.green}{prompt_text} {color.reset}")
            logger.info(f"Response: {color.red}{response_text} {color.reset}")

    # Cleanup
    _cleanup()
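A shape-only sketch of the final aggregation above, with dummy ids and a stub decode (assumption: the real code decodes each row with the model tokenizer): `res` holds one `(batch_size,)` tensor per generated step, so stacking along `dim=1` gives a `(batch_size, num_steps)` tensor whose rows can be decoded independently, which is what the per-sequence comprehension relies on.

```python
import torch

num_steps, batch_size = 4, 3
# One (batch_size,) tensor of token ids per decode step, as accumulated in `res`.
res = [torch.randint(0, 100, (batch_size,)) for _ in range(num_steps)]

res_stacked = torch.stack(res, dim=1)   # shape: (batch_size, num_steps)
res_list = res_stacked.tolist()         # list of per-prompt id lists

# Stub decode: the real script calls tokenizer.decode(sequence) per row.
responses = [" ".join(str(t) for t in sequence) for sequence in res_list]
print(len(responses), responses[0])
```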
I wonder if torchchat's generate also has the `temperature` option? Shall we think about how to connect with generate in next steps?