This repository was archived by the owner on Sep 10, 2025. It is now read-only.
[distributed] prefill (single and multi-prompt) and single prompt generation and decoding #1133
Merged
Changes from 7 commits

Commits (21):
- 28598e7 add encode string and create_padded_input functions (lessw2020)
- 13bdcb3 update prefill functions (lessw2020)
- 54d895b ensure 8B is default (lessw2020)
- 4e9771c add typing to added functions (lessw2020)
- 50d451a ruff formatting (lessw2020)
- 1ea7960 enable multi-batch prefill (lessw2020)
- 0f28976 Merge branch 'main' into lessw2020/prefill (lessw2020)
- ef0a03d decoding start (lessw2020)
- 57edb47 decoding comms working (lessw2020)
- 0db7c60 decoding comms working next token send/receive (lessw2020)
- 13fbee6 first decoded token (lessw2020)
- e3fe1bf second decoded token (lessw2020)
- fe9dae9 single prompt prefill + decoding all working (lessw2020)
- 33b2549 add _update_padded_sequence (lessw2020)
- 2bf85bf Merge branch 'lessw2020/prefill' of github.com:pytorch/torchchat into… (lessw2020)
- ab9a24d add refined output, update force_download to 3-8B (lessw2020)
- 30f70b8 Merge branch 'main' into lessw2020/prefill (lessw2020)
- 2a8ea19 pr_feedback, ruff formatting (lessw2020)
- 89eda74 remove debug related logging (lessw2020)
- 6bb2725 Merge branch 'main' into lessw2020/prefill (Jack-Khuu)
- 9242dfb Merge branch 'main' into lessw2020/prefill (lessw2020)
@@ -7,34 +7,23 @@
```python
import os
from pathlib import Path
from types import SimpleNamespace
from typing import Any, Dict
from typing import Any, Dict, List, Optional, Tuple

# Run command:
# torchrun --nproc-per-node 4 dist_run.py
import torch
import torch.distributed as dist
from torch.distributed.pipelining import PipelineStage, ScheduleGPipe

from distributed.logging_utils import SingletonLogger

# TODO - these are not distributed specific, consider moving to new package
from distributed.safetensor_utils import (
    get_hf_config_file,
    get_hf_weight_map_and_path,
    load_safetensor_weights,
)

from distributed.utils import (
    Color as color,
    GPUMemoryMonitor,
    get_module_size,
    get_num_params,
    bytes_to_readable,
    TrackTime,
    CUDATrackTime,
)

from distributed.safetensor_utils import (get_hf_config_file,
                                          get_hf_weight_map_and_path,
                                          load_safetensor_weights)
from distributed.utils import Color as color
from distributed.utils import (GPUMemoryMonitor, TrackTime,
                               bytes_to_readable, get_module_size,
                               get_num_params)
from distributed.verification_utils import find_cpu_tensors
from torchchat.cli.builder import TokenizerArgs, _initialize_tokenizer
from torchchat.model import ModelArgs, Transformer
```
@@ -52,7 +41,8 @@
```python
logger = SingletonLogger.get_logger()

MODEL_NAME = "Transformer-2-7b-chat-hf"
MODEL_NAME = "Meta-Llama-3-8B"

NAME_TO_HF_MODEL_ID_AND_DTYPE = {
    "Transformer-2-7b-chat-hf": ("meta-llama/Llama-2-7b-chat-hf", torch.float16),
    "Meta-Llama-3-8B": ("meta-llama/Meta-Llama-3-8B-Instruct", torch.bfloat16),
```
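Presumably this mapping is consumed by looking up `MODEL_NAME` to get the Hugging Face repo id and dtype; a hedged sketch of that lookup (the lookup site itself is outside this hunk):

```python
# Assumed usage of the mapping above; the actual lookup is not shown in this hunk.
hf_model_name, model_dtype = NAME_TO_HF_MODEL_ID_AND_DTYPE[MODEL_NAME]
# -> ("meta-llama/Meta-Llama-3-8B-Instruct", torch.bfloat16) with the new default
```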
@@ -122,6 +112,88 @@ def _load_model_weights(stage_module, hf_model_name, device, model_config):
```python
    raise ValueError(f"Missing {num_missing_weights} weights")


def _encode_strings(
    strings: List[str],
    tokenizer,
    bos: bool = True,
    device: str = "cuda",
    dtype=torch.int64,
) -> List[torch.Tensor]:
    """Encode a list of prompt strings into a list of tensor token ids."""
    encoded_list = []
    for string in strings:
        tokens = tokenizer.encode(string)
        if bos:
            tokens = [tokenizer.bos_id()] + tokens
        encoded_list.append(torch.tensor(tokens, dtype=dtype, device=device))
    return encoded_list


def _create_padded_prompts(
    input_ids_list: List[torch.Tensor],
    tokenizer,
    seqlen: int,
    start_pos: int,
    device: str,
    pad_token_id: Optional[int] = None,
) -> Tuple[torch.Tensor, List[int]]:
    """
    Create a padded tensor for multiple encoded input prompts.

    Returns:
        Tuple[torch.Tensor, List[int]]: A tuple containing the padded tensor and a list of prompt lengths.
    """
    pad_token_id = pad_token_id if pad_token_id is not None else tokenizer.eos_id()

    # Find the maximum prompt length
    max_prompt_len = max(ids.size(0) for ids in input_ids_list)

    # Calculate the buffer size
    max_new_tokens = max(0, min(seqlen - start_pos, seqlen - max_prompt_len))
    token_buffer_size = max_prompt_len + max_new_tokens

    # Create the padded batch tensor
    batch_size = len(input_ids_list)
    batch_seq = torch.full(
        (batch_size, token_buffer_size), pad_token_id, dtype=torch.int64, device=device
    )

    prompt_lengths = []
    for i, input_ids in enumerate(input_ids_list):
        prompt_len = input_ids.size(0)
        batch_seq[i, :prompt_len] = input_ids
        prompt_lengths.append(prompt_len)

    return batch_seq, prompt_lengths


def _batch_decode_next_tokens(
    output: torch.Tensor,
    prompt_lengths: List[int],
    tokenizer,
) -> List[Tuple[int, str]]:
    """
    Decode the next token for each prompt in the batch.

    Returns:
        List[Tuple[int, str]]: List of tuples containing the next token id and its
        decoded string for each prompt in the batch.
    """
    batch_size = output.shape[0]
    results = []

    for i in range(batch_size):
        next_token_logits = output[i, prompt_lengths[i] - 1, :]

        # Argmax (deterministic) TODO: add temperature
        next_token = torch.argmax(next_token_logits, dim=-1)

        next_token_decoded = tokenizer.decode([next_token.item()])
        results.append((next_token.item(), next_token_decoded))

    return results


def _cleanup():
    dist.barrier()
    dist.destroy_process_group()
```
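Taken together, these helpers implement the multi-prompt prefill path: encode each prompt, pack the encodings into one padded batch, run the batch through the pipeline, and greedily read off each prompt's next token. A minimal usage sketch, assuming the helpers above are in scope and substituting a toy tokenizer and random logits for torchchat's real tokenizer and pipeline output:

```python
# Minimal usage sketch of the three helpers above, with hypothetical stand-ins
# (ToyTokenizer and the random logits are illustrations, not torchchat's tokenizer or model).
from typing import List

import torch


class ToyTokenizer:
    """Tiny stand-in exposing just the interface the helpers rely on."""

    vocab_size = 1000

    def encode(self, text: str) -> List[int]:
        # Character-level ids kept inside the toy vocab; purely illustrative.
        return [2 + (ord(ch) % (self.vocab_size - 2)) for ch in text]

    def bos_id(self) -> int:
        return 0

    def eos_id(self) -> int:
        return 1

    def decode(self, ids: List[int]) -> str:
        return " ".join(f"<tok{i}>" for i in ids)


tokenizer = ToyTokenizer()
prompts = ["What is snow?", "What is the capital of France?"]
seqlen, start_pos = 64, 0

# 1) encode: list of 1-D token tensors, BOS prepended
input_ids = _encode_strings(prompts, tokenizer, bos=True, device="cpu")

# 2) pad into a single [batch, buffer] tensor, padding with eos_id
padded_sequence, prompt_lengths = _create_padded_prompts(
    input_ids, tokenizer, seqlen, start_pos, device="cpu"
)

# 3) stand-in for the last pipeline stage's output: [batch, buffer, vocab] logits
logits = torch.randn(padded_sequence.size(0), padded_sequence.size(1), tokenizer.vocab_size)

# 4) greedily pick each prompt's next token at the position of its last real token
print(_batch_decode_next_tokens(output=logits, prompt_lengths=prompt_lengths, tokenizer=tokenizer))
```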
@@ -133,8 +205,8 @@ def main():
```python
    gpu_memory_monitor = GPUMemoryMonitor("cuda")
    logger.info(f"{color.yellow} {gpu_memory_monitor.get_device_info()}{color.reset}")

    config = ModelArgs.from_name(MODEL_NAME).transformer_args['text']
    logger.info(f"Chat Model Config: {config}")
    config = ModelArgs.from_name(MODEL_NAME).transformer_args["text"]
    logger.info(f"Chat Model Name: {MODEL_NAME}\nModel Config: {config}")

    tokenizer = _build_chat_tokenizer()
    logger.info(f"built tokenizer {tokenizer=}")
```
@@ -182,11 +254,12 @@
```python
    # Distribute model on TP mesh
    model.distribute(tp_mesh)
    logger.info(f"Model: {model}")
    # logger.info(f"Model: {model}")

    mbs = 2  # number of micro-batches
    mbs = 4  # number of micro-batches
    mb_size = 1  # micro-batch size
    batch_size = mbs * mb_size  # total batch size

    seqlen = 4096  # sequence length
    dim = 4096  # embedding dimension
    assert seqlen % sp_degree == 0
```
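For orientation, a GPipe-style schedule splits this global batch along the leading dimension into `mbs` micro-batches of `mb_size` rows each, so bumping `mbs` from 2 to 4 lets each of the four prompts below ride in its own micro-batch. A rough sketch of that bookkeeping (not the actual `ScheduleGPipe` internals):

```python
import torch

mbs, mb_size = 4, 1          # values from the diff above
batch_size = mbs * mb_size   # 4 rows total, one prompt per micro-batch
seqlen = 4096

padded_sequence = torch.zeros(batch_size, seqlen, dtype=torch.int64)
micro_batches = torch.chunk(padded_sequence, mbs, dim=0)  # what the schedule feeds stage by stage
assert all(mb.shape == (mb_size, seqlen) for mb in micro_batches)
```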
@@ -199,7 +272,7 @@
```python
    # Load weights
    logger.info(f"Loading weights for {pp_rank=} on {device=}")
    with TrackTime("cuda") as timer:
    with TrackTime() as timer:
        _load_model_weights(model, hf_model_name, device=device, model_config=config)
    logger.info(
        f"{color.green}Total weight loading time: {timer.get_time()} {timer.unit} for stage {rank}{color.reset}"
```
@@ -212,9 +285,8 @@
```python
    logger.info(
        f"Stage {rank} has {color.blue}{stage_num_params} params{color.reset}, Size: {color.blue}{stage_size_formatted}{color.reset}\n"
    )

    # Setup input position
    # input_pos for prefill: a list of increasing integers from 0 to seqlen
    input_pos = torch.arange(seqlen, device=device)
    model.setup_input_pos(input_pos)
    model.eval()
```
@@ -235,30 +307,47 @@
```python
    if len(cpu_tensors) > 0:
        raise ValueError("Found cpu tensors in stage")

    # TODO: this can likely be removed after we prove out a few more models
    # verify dtypes for model - expect all to be model_dtype except for bool causal_mask atm.
    # dtype_count, dtype_locations, fp32_locations = record_module_dtypes(stage.submod)
    # logger.info(
    #     f"Stage Dtypes - Found {len(dtype_count)} dtypes: {dtype_count.items()}"
    # )
    # assert (
    #     len(dtype_count) == 2
    # ), f"Expected 2 dtypes in model after checkpoint loading: {model_dtype} and {torch.bool}"
    prompt = [
        "What is the capital of France?",
        "What is snow?",
        "What is your name?",
        "What is the capital of Japan?",
    ]
    start_pos = 0

    # encode the prompt
    input_ids = _encode_strings(
        prompt, tokenizer, bos=True, device=device, dtype=torch.int64
    )
    logger.info(f"{input_ids[0:8]=}")

    input_ids = torch.randint(0, config.vocab_size, (batch_size, seqlen), device=device)
    logger.info(f"Input: {input_ids.dtype=}, {input_ids.shape=}, {input_ids.device=}")
    # create a padded tensor for the input prompt
    padded_sequence, prompt_lengths = _create_padded_prompts(
        input_ids, tokenizer, seqlen, start_pos, device
    )
```

> Review comment: nit: (1) why padding the input_ids requires the tokenizer again? (2) maybe device can be inferred from input_ids?

```python
    logger.info(f"{prompt_lengths=}")
    logger.info(f"first prompt {padded_sequence[0, :prompt_lengths[0]+1]=}")
    if len(prompt_lengths) > 1:
        logger.info(f"second prompt {padded_sequence[1, :prompt_lengths[1]+1]=}")
```

> Review comment on lines +390 to +393: nit: can we clean up these logs a little bit? And why do we need the +1?

```python
    schedule = ScheduleGPipe(stage, mbs)
    logger.info(f"Created schedule: {schedule}")

    with torch.no_grad():  # .inference_mode():
        if pp_rank == 0:
            schedule.step(input_ids)
            schedule.step(padded_sequence)
        else:
            output = schedule.step()

    # Decoding
    if pp_rank == pp_degree - 1 and tp_rank == 0:
        logger.info(f"Output: {output}")
        decode_results = _batch_decode_next_tokens(
            output=output, prompt_lengths=prompt_lengths, tokenizer=tokenizer
        )

        logger.info(
            f"\n\n{color.green} Prefill responses ====>>>> {color.blue} {decode_results=} \n{color.reset}"
        )

    # show peak memory stats for this stage
    res_mem_gib, res_mem_pct = gpu_memory_monitor.get_peak_stats()
```
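The remaining commits in this PR (decoding start through `_update_padded_sequence` and single-prompt generation) extend this prefill step into token-by-token decoding, but those changes are not part of this 7-commit view. As a rough illustration of the shape such a loop takes on top of the helpers above, here is a hedged greedy-decode sketch; the `run_stage_step` callable and the loop structure are assumptions for illustration, not the PR's implementation:

```python
# Hypothetical greedy decode loop built on the helpers above (illustration only;
# the PR's actual decoding path, e.g. _update_padded_sequence, is not shown in this view).
from typing import List

import torch


def greedy_decode(run_stage_step, padded_sequence, prompt_lengths, tokenizer, max_new_tokens=32):
    """run_stage_step: callable mapping [batch, seq] token ids to [batch, seq, vocab] logits."""
    cur_lengths: List[int] = list(prompt_lengths)
    for _ in range(max_new_tokens):
        logits = run_stage_step(padded_sequence)
        # The next token for each row comes from the position of its last real (non-pad) token.
        next_tokens = [
            torch.argmax(logits[i, cur_lengths[i] - 1, :]).item()
            for i in range(padded_sequence.size(0))
        ]
        for i, tok in enumerate(next_tokens):
            if cur_lengths[i] < padded_sequence.size(1):
                padded_sequence[i, cur_lengths[i]] = tok  # fill the next padding slot
                cur_lengths[i] += 1
        if all(tok == tokenizer.eos_id() for tok in next_tokens):
            break
    return [
        tokenizer.decode(padded_sequence[i, : cur_lengths[i]].tolist())
        for i in range(padded_sequence.size(0))
    ]
```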