This repository was archived by the owner on Sep 10, 2025. It is now read-only.
[distributed] prefill (single and multi-prompt) and single prompt generation and decoding #1133
Merged
Changes from all commits (21 commits):
- 28598e7 add encode string and create_padded_input functions (lessw2020)
- 13bdcb3 update prefill functions (lessw2020)
- 54d895b ensure 8B is default (lessw2020)
- 4e9771c add typing to added functions (lessw2020)
- 50d451a ruff formatting (lessw2020)
- 1ea7960 enable multi-batch prefill (lessw2020)
- 0f28976 Merge branch 'main' into lessw2020/prefill (lessw2020)
- ef0a03d decoding start (lessw2020)
- 57edb47 decoding comms working (lessw2020)
- 0db7c60 decoding comms working next token send/receive (lessw2020)
- 13fbee6 first decoded token (lessw2020)
- e3fe1bf second decoded token (lessw2020)
- fe9dae9 single prompt prefill + decoding all working (lessw2020)
- 33b2549 add _update_padded_sequence (lessw2020)
- 2bf85bf Merge branch 'lessw2020/prefill' of github.com:pytorch/torchchat into… (lessw2020)
- ab9a24d add refined output, update force_download to 3-8B (lessw2020)
- 30f70b8 Merge branch 'main' into lessw2020/prefill (lessw2020)
- 2a8ea19 pr_feedback, ruff formatting (lessw2020)
- 89eda74 remove debug related logging (lessw2020)
- 6bb2725 Merge branch 'main' into lessw2020/prefill (Jack-Khuu)
- 9242dfb Merge branch 'main' into lessw2020/prefill (lessw2020)
Changes to dist_run.py (the file is identified by the torchrun command in the diff below):
```diff
@@ -8,14 +8,12 @@
 import os
 from pathlib import Path
 from types import SimpleNamespace
-from typing import Any, Dict, Optional
+from typing import Any, Dict, List, Optional, Tuple

 # Run command:
 # torchrun --nproc-per-node 4 dist_run.py
 import torch
 import torch.distributed as dist
-from torch.distributed.pipelining import PipelineStage, ScheduleGPipe

 from distributed.logging_utils import SingletonLogger

@@ -25,19 +23,17 @@
     get_hf_weight_map_and_path,
     load_safetensor_weights,
 )

 from distributed.utils import (
+    bytes_to_readable,
     Color as color,
-    GPUMemoryMonitor,
+    CUDATrackTime,
     get_module_size,
     get_num_params,
-    bytes_to_readable,
-    TrackTime,
-    CUDATrackTime,
+    GPUMemoryMonitor,
 )

 from distributed.verification_utils import find_cpu_tensors
-from torchchat.cli.builder import TokenizerArgs, _initialize_tokenizer
+from torch.distributed.pipelining import PipelineStage, ScheduleGPipe
+from torchchat.cli.builder import _initialize_tokenizer, TokenizerArgs
 from torchchat.model import ModelArgs, Transformer
 from torchchat.utils.build_utils import set_precision
```
|
```diff
@@ -136,6 +132,99 @@ def _load_model_weights(stage_module, distribution, device, model_config):
         raise ValueError(f"Missing {num_missing_weights} weights")


+def _encode_strings(
+    strings: List[str],
+    tokenizer,
+    bos: bool = True,
+    device: torch.device = "cuda:0",
+    dtype=torch.int64,
+) -> List[torch.Tensor]:
+    """Encode a list of prompt strings into a list of tensor token ids."""
+    encoded_list = []
+    for string in strings:
+        tokens = tokenizer.encode(string)
+        if bos:
+            tokens = [tokenizer.bos_id()] + tokens
+        encoded_list.append(torch.tensor(tokens, dtype=dtype, device=device))
+    return encoded_list
+
+
+def _create_padded_prompts(
+    input_ids_list: List[torch.Tensor],
+    tokenizer,
+    seqlen: int,
+    start_pos: int,
+    device: torch.device,
+    pad_token_id: Optional[int] = None,
+) -> Tuple[torch.Tensor, List[int]]:
+    """
+    Create a padded tensor for multiple encoded input prompts.
+
+    Returns:
+        Tuple[torch.Tensor, List[int]]: A tuple containing the padded tensor and a list of prompt lengths.
+    """
+    pad_token_id = pad_token_id if pad_token_id is not None else tokenizer.eos_id()
+
+    # Find the maximum prompt length
+    max_prompt_len = max(ids.size(0) for ids in input_ids_list)
+
+    # Calculate the buffer size
+    max_new_tokens = max(0, min(seqlen - start_pos, seqlen - max_prompt_len))
+    token_buffer_size = max_prompt_len + max_new_tokens
+
+    # Create the padded batch tensor
+    batch_size = len(input_ids_list)
+    batch_seq = torch.full(
+        (batch_size, token_buffer_size), pad_token_id, dtype=torch.int64, device=device
+    )
+
+    prompt_lengths = []
+    for i, input_ids in enumerate(input_ids_list):
+        prompt_len = input_ids.size(0)
+        batch_seq[i, :prompt_len] = input_ids
+        prompt_lengths.append(prompt_len)
+
+    return batch_seq, prompt_lengths
+
+
+def _batch_decode_next_tokens(
+    output: torch.Tensor,
+    prompt_lengths: List[int],
+    tokenizer,
+) -> List[Tuple[int, str]]:
+    """
+    Decode the next token for each prompt in the batch.
+
+    Returns:
+        List[Tuple[int, str]]: List of tuples containing the next token id and its
+        decoded string for each prompt in the batch.
+    """
+    batch_size = output.shape[0]
+    results = []
+
+    for i in range(batch_size):
+        next_token_logits = output[i, prompt_lengths[i] - 1, :]
+
+        # Argmax (deterministic) TODO: add temperature
+        next_token = torch.argmax(next_token_logits, dim=-1)
+
+        next_token_decoded = tokenizer.decode([next_token.item()])
+        results.append((next_token.item(), next_token_decoded))
+
+    return results
+
+
+def _update_padded_sequence(
+    padded_sequence: torch.Tensor,
+    x_recv: torch.Tensor,
+    res,
+    prompt_lengths: List[int],
+) -> None:
+    for i in range(len(prompt_lengths)):
+        prompt_lengths[i] += 1
+        padded_sequence[i, prompt_lengths[i] - 1] = x_recv
+
+
 def _cleanup():
     dist.barrier()
     dist.destroy_process_group()
```
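Not part of the diff: a minimal usage sketch of the helpers added above, run on CPU with a toy stand-in tokenizer (the real one comes from _initialize_tokenizer). The ToyTokenizer and the commented temperature-sampling variant for the TODO are assumptions for illustration, not code from this PR.

```python
from typing import List

import torch

# Assumes _encode_strings, _create_padded_prompts, and _batch_decode_next_tokens
# from the diff above are importable (e.g. from dist_run).


class ToyTokenizer:
    """Stand-in exposing the interface the helpers expect: encode/decode/bos_id/eos_id."""

    def encode(self, s: str) -> List[int]:
        return [ord(c) % 256 for c in s]

    def decode(self, ids: List[int]) -> str:
        return "".join(chr(i) for i in ids)

    def bos_id(self) -> int:
        return 1

    def eos_id(self) -> int:
        return 0


tok = ToyTokenizer()
device = torch.device("cpu")  # dist_run uses cuda:<local rank> per process

# Encode two prompts, then pack them into a single right-padded batch tensor.
input_ids = _encode_strings(["What is snow?", "What is PyTorch?"], tok, bos=True, device=device)
padded_sequence, prompt_lengths = _create_padded_prompts(
    input_ids, tok, seqlen=32, start_pos=0, device=device
)
print(padded_sequence.shape, prompt_lengths)  # torch.Size([2, 32]) [14, 17] with this toy tokenizer

# Pretend these logits came out of the last pipeline stage: (batch, seq, vocab).
logits = torch.randn(padded_sequence.size(0), padded_sequence.size(1), 256)
results = _batch_decode_next_tokens(output=logits, prompt_lengths=prompt_lengths, tokenizer=tok)
print(results)  # [(next_token_id, decoded_str), ...], greedy argmax per prompt

# Hypothetical shape of the temperature sampling mentioned in the TODO:
# probs = torch.softmax(next_token_logits / temperature, dim=-1)
# next_token = torch.multinomial(probs, num_samples=1)
```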
```diff
@@ -180,6 +269,17 @@ def main(args):
     pp_mesh = device_mesh["pp"]
     tp_rank = tp_mesh.get_local_rank()
     pp_rank = pp_mesh.get_local_rank()
+    tp_group = tp_mesh.get_group()
+    pp_group = pp_mesh.get_group()
+
+    logger.info(f"review: {pp_group=}, {tp_group= }")

     logger.info(f"Created device mesh: {device_mesh}\n {tp_mesh=}, {pp_mesh=}\n")
+    # TODO - this assumes 1D mesh, need to update for 2D+ mesh
+    pp_group_size = pp_mesh.size()
+    tp_group_size = tp_mesh.size()
+
+    logger.info(f"pp_group_size: {pp_group_size}, tp_group_size: {tp_group_size}")

     # Assuming same number of GPUs per node
     device = torch.device(f"cuda:{rank % torch.cuda.device_count()}")
```
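The tp/pp wiring above assumes a 2-D device mesh with named "pp" and "tp" dimensions. How the mesh itself is created is not shown in this hunk; a hedged sketch of the usual construction follows (the 2x2 shape is an assumption matching the 4-process torchrun command, and this only runs under torchrun with distributed initialized):

```python
from torch.distributed.device_mesh import init_device_mesh

# Assumed shape: 2 pipeline stages x 2 tensor-parallel ranks = 4 GPUs total,
# matching `torchrun --nproc-per-node 4 dist_run.py`.
device_mesh = init_device_mesh("cuda", (2, 2), mesh_dim_names=("pp", "tp"))

pp_mesh = device_mesh["pp"]          # 1-D sub-mesh along the pipeline dimension
tp_mesh = device_mesh["tp"]          # 1-D sub-mesh along the tensor-parallel dimension

pp_rank = pp_mesh.get_local_rank()   # this rank's position within its pipeline group
tp_rank = tp_mesh.get_local_rank()
pp_group = pp_mesh.get_group()       # ProcessGroup handle used by dist.send/recv below
tp_group = tp_mesh.get_group()

print(pp_mesh.size(), tp_mesh.size())  # 2, 2 under the assumed shape
```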
```diff
@@ -198,9 +298,10 @@
     if rank == 0:
         logger.info(f"Model: {model}")

-    mbs = 2  # number of micro-batches
+    mbs = 1  # number of micro-batches
     mb_size = 1  # micro-batch size
     batch_size = mbs * mb_size  # total batch size

     seqlen = 4096  # sequence length
     dim = 4096  # embedding dimension
     assert seqlen % sp_degree == 0

@@ -213,8 +314,10 @@
     # Load weights
     logger.info(f"Loading weights for {pp_rank=} on {device=}")
-    with TrackTime("cuda") as timer:
-        _load_model_weights(model, distribution, device=device, model_config=config)
+
+    with CUDATrackTime() as timer:
+        _load_model_weights(model, hf_model_name, device=device, model_config=config)

     logger.info(
         f"{color.green}Total weight loading time: {timer.get_time()} {timer.unit} for stage {rank}{color.reset}"
     )
```
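CUDATrackTime comes from distributed/utils.py and its body is not part of this diff. Below is a rough sketch of a CUDA-event timer with the same get_time()/unit surface, written as an assumption rather than the torchchat implementation:

```python
import torch


class CudaEventTimer:
    """Hypothetical stand-in for CUDATrackTime: times a block using CUDA events."""

    unit = "seconds"

    def __enter__(self):
        self.start = torch.cuda.Event(enable_timing=True)
        self.end = torch.cuda.Event(enable_timing=True)
        self.start.record()
        return self

    def __exit__(self, *exc):
        self.end.record()
        torch.cuda.synchronize()  # wait for all queued GPU work before reading the timer
        self.elapsed = self.start.elapsed_time(self.end) / 1000.0  # ms -> s
        return False

    def get_time(self):
        return f"{self.elapsed:.2f}"
```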
```diff
@@ -226,9 +329,8 @@
     logger.info(
         f"Stage {rank} has {color.blue}{stage_num_params} params{color.reset}, Size: {color.blue}{stage_size_formatted}{color.reset}\n"
     )

-    # Setup input position
-    # input_pos for prefill: a list of increasing integers from 0 to seqlen
+    # Setup input position (input_pos) for prefill: a list of increasing integers from 0 to seqlen
     input_pos = torch.arange(seqlen, device=device)
     model.setup_input_pos(input_pos)
     model.eval()
```
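Since input_pos is set once to the full position range, prefill (and, as far as this diff goes, each later decode step over the growing padded sequence) runs with every position of the buffer available. A tiny illustration:

```python
import torch

seqlen = 8  # dist_run uses 4096
input_pos = torch.arange(seqlen)
print(input_pos)  # tensor([0, 1, 2, 3, 4, 5, 6, 7]): one position index per slot in the padded buffer
```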
```diff
@@ -249,41 +351,129 @@
     if len(cpu_tensors) > 0:
         raise ValueError("Found cpu tensors in stage")

     # TODO: this can likely be removed after we prove out a few more models
     # verify dtypes for model - expect all to be model_dtype except for bool causal_mask atm.
     # dtype_count, dtype_locations, fp32_locations = record_module_dtypes(stage.submod)
     # logger.info(
     #     f"Stage Dtypes - Found {len(dtype_count)} dtypes: {dtype_count.items()}"
     # )
     # assert (
     #     len(dtype_count) == 2
     # ), f"Expected 2 dtypes in model after checkpoint loading: {model_dtype} and {torch.bool}"
+    prompt = [
+        "What is snow?",
+    ]

-    input_ids = torch.randint(0, config.vocab_size, (batch_size, seqlen), device=device)
-    logger.info(f"Input: {input_ids.dtype=}, {input_ids.shape=}, {input_ids.device=}")
+    """
+    "What is the capital of France?",
+    "What is your name?",
+    "What is the capital of Japan?",
+    "When is Christmas?",
+    "Where does Santa Claus live?",
+    "What is the capital of the United States?",
+    "What is the capital of China?",
+    "What is the capital of Russia?",
+    "What is PyTorch?",
+    "What is the capital of India?",
+    "What is an LLM?",
+    "What is the capital of Brazil?",
+    "What is the capital of Mexico?",
+    "What is the capital of Argentina?",
+    "What is the capital of Canada?",
+    ]
+    """

-    schedule = ScheduleGPipe(stage, mbs)
-    logger.info(f"Created schedule: {schedule}")

-    with torch.no_grad():  # .inference_mode():
-        if pp_rank == 0:
-            output = schedule.step(input_ids)
-        else:
-            output = schedule.step()
+    start_pos = 0

-    if pp_rank == pp_degree - 1 and tp_rank == 0:
-        logger.info(f"Output: {output}")
+    # encode the prompt
+    input_ids = _encode_strings(
+        prompt, tokenizer, bos=True, device=device, dtype=torch.int64
+    )
+    logger.info(f"{input_ids[0:8]=}")

-    # show peak memory stats for this stage
-    res_mem_gib, res_mem_pct = gpu_memory_monitor.get_peak_stats()
-    logger.info(
-        f"{color.blue} Memory used: {color.green}{res_mem_pct:.3f} %, {color.magenta}{res_mem_gib:.3f} GB{color.reset}"
+    # create a padded tensor for the input prompt
+    padded_sequence, prompt_lengths = _create_padded_prompts(
+        input_ids, tokenizer, seqlen, start_pos, device
     )
+    logger.info(f"{prompt_lengths=}")
+    logger.info(f"first prompt {padded_sequence[0, :prompt_lengths[0]+1]=}")
+    if len(prompt_lengths) > 1:
+        logger.info(f"second prompt {padded_sequence[1, :prompt_lengths[1]+1]=}")
```

Review comment on lines +390 to +393: nit: can we clean up these logs a little bit? And why do we need the +1?
```diff
+    schedule = ScheduleGPipe(stage, mbs)
+    logger.info(f"Created schedule: {schedule}")

+    # with CUDATrackTime() as timer:
+    first_pp_group = 0
+    last_pp_group = pp_group_size - 1

+    x_recv = torch.zeros(1, device=device, dtype=torch.int64)
+    logger.info(f"{x_recv.shape=}")
```

Review comment on lines +399 to +400 (first_pp_group / last_pp_group): Why is "group" a number? Do you mean rank in PP group?

Review comment on the x_recv line: nit: can we add a comment on the "1" part for better readability? thx.

Review comment on the x_recv.shape log: nit: remove log if not super important?
```diff
+    last_global_rank = world_size - 1
+    res = []
+    dst = None
+    src = None

+    if pp_rank == last_pp_group:
+        dst = dist.get_global_rank(pp_group, 0)
+    elif pp_rank == 0:
+        src = dist.get_global_rank(pp_group, last_pp_group)

+    # Decoding
+    num_tokens = 40

+    with torch.no_grad():
+        for step in range(num_tokens):
+            # first
+            if pp_rank == 0:
+                schedule.step(padded_sequence)
+                # only receive if not last step
+                if step < num_tokens - 1:
+                    dist.recv(
+                        x_recv,
+                        src,
+                        group=pp_group,
+                    )
+                    _update_padded_sequence(
+                        padded_sequence, x_recv, res, prompt_lengths
+                    )

+            # last
+            elif pp_rank == last_pp_group:
+                output = schedule.step()
+                # need to decode the output
+                decode_results = _batch_decode_next_tokens(
+                    output=output, prompt_lengths=prompt_lengths, tokenizer=tokenizer
+                )
+                if tp_rank == 0:
+                    logger.info(
+                        f"\n\n{color.green} {'Prefill' if step == 0 else '* Decode *'} responses ====>>>> {color.blue} {decode_results=} \n{color.reset}"
+                    )

+                next_token = torch.tensor([decode_results[0][0]], device=device)
+                res.append(decode_results[0][1])

+                # increment prompt lengths for next token
+                for i in range(len(prompt_lengths)):
+                    prompt_lengths[i] += 1
+                    # logger.info(
+                    #     f"output review {prompt_lengths[i]=}, {padded_sequence[i, prompt_lengths[i]-1]=}"
+                    # )

+                # only send if not last step
+                if step < (num_tokens - 1):
+                    dist.send(
+                        next_token,
+                        dst,
+                        pp_group,
+                    )

+            # middle pp ranks
+            else:
+                schedule.step()

+    # output formatted response via last pp group and tp rank 0
+    if pp_rank == last_pp_group and tp_rank == 0:
+        logger.info(f"\nPrompt:{color.green} {prompt[0]} {color.reset}")
+        formatted_response = "".join(res)
+        logger.info(f"$$$$$$ {color.blue}{formatted_response}\n{color.reset} $$$$$")

     logger.info(
         f"{color.green}Success{color.white} - {color.blue}Rank {rank} has completed.{color.reset}"
     )

     _cleanup()
```
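Stripped of logging, the loop above is a ring between the first and last pipeline stages: the last stage picks the next token greedily and sends it back, the first stage writes it into the padded batch and runs the next step, and middle stages only relay activations. A condensed restatement of that pattern is sketched below; it reuses the variables set up earlier in main and is not runnable on its own.

```python
import torch
import torch.distributed as dist

# Assumes the setup from the diff: schedule, padded_sequence, prompt_lengths,
# pp_rank / last_pp_group, src / dst global ranks, pp_group, tokenizer, device.
num_tokens = 40
x_recv = torch.zeros(1, device=device, dtype=torch.int64)  # buffer for one next-token id
res = []

with torch.no_grad():
    for step in range(num_tokens):
        if pp_rank == 0:
            # First stage: feed the (growing) padded batch into the pipeline,
            # then wait for the last stage to send back the token it produced.
            schedule.step(padded_sequence)
            if step < num_tokens - 1:
                dist.recv(x_recv, src, group=pp_group)
                _update_padded_sequence(padded_sequence, x_recv, res, prompt_lengths)
        elif pp_rank == last_pp_group:
            # Last stage: run its part of the pipeline, decode greedily, send back.
            output = schedule.step()
            token_id, token_str = _batch_decode_next_tokens(
                output=output, prompt_lengths=prompt_lengths, tokenizer=tokenizer
            )[0]
            res.append(token_str)
            for i in range(len(prompt_lengths)):
                prompt_lengths[i] += 1  # the new token will occupy the next slot
            if step < num_tokens - 1:
                dist.send(torch.tensor([token_id], device=device), dst, pp_group)
        else:
            # Middle stages just relay activations.
            schedule.step()
```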
Changes to the weight download script (force_download, per commit ab9a24d):
```diff
@@ -1,5 +1,5 @@
-from transformers import AutoTokenizer, AutoModelForCausalLM
+from transformers import AutoModelForCausalLM, AutoTokenizer

-tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
-model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
-print("Model weights downloaded")
+tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
+model = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
+print("Model weights and tokenizer downloaded")
```