diff --git a/dist_run.py b/dist_run.py
index 79a3d2f84..5282cb6ad 100644
--- a/dist_run.py
+++ b/dist_run.py
@@ -8,14 +8,12 @@
 import os
 from pathlib import Path
 from types import SimpleNamespace
-from typing import Any, Dict, Optional
+from typing import Any, Dict, List, Optional, Tuple

 # Run command:
 # torchrun --nproc-per-node 4 dist_run.py

 import torch
 import torch.distributed as dist
-from torch.distributed.pipelining import PipelineStage, ScheduleGPipe
-
 from distributed.logging_utils import SingletonLogger
@@ -25,19 +23,17 @@
     get_hf_weight_map_and_path,
     load_safetensor_weights,
 )
-
 from distributed.utils import (
+    bytes_to_readable,
     Color as color,
-    GPUMemoryMonitor,
+    CUDATrackTime,
     get_module_size,
     get_num_params,
-    bytes_to_readable,
-    TrackTime,
-    CUDATrackTime,
+    GPUMemoryMonitor,
 )
-
 from distributed.verification_utils import find_cpu_tensors
-from torchchat.cli.builder import TokenizerArgs, _initialize_tokenizer
+from torch.distributed.pipelining import PipelineStage, ScheduleGPipe
+from torchchat.cli.builder import _initialize_tokenizer, TokenizerArgs
 from torchchat.model import ModelArgs, Transformer
 from torchchat.utils.build_utils import set_precision
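Note on the utils import swap above: the wall-clock TrackTime is replaced by CUDATrackTime. The patch only reveals the interface main() uses later (`with CUDATrackTime() as timer:`, `timer.get_time()`, `timer.unit`); a minimal sketch of a CUDA-event timer with that shape, purely illustrative and not the torchchat implementation:

    import torch

    class CudaEventTimer:
        """Hypothetical stand-in matching the CUDATrackTime interface used below."""
        unit = "ms"

        def __enter__(self):
            self.start_evt = torch.cuda.Event(enable_timing=True)
            self.end_evt = torch.cuda.Event(enable_timing=True)
            self.start_evt.record()
            return self

        def __exit__(self, *exc):
            self.end_evt.record()
            torch.cuda.synchronize()  # ensure elapsed_time() sees finished events
            return False

        def get_time(self):
            return round(self.start_evt.elapsed_time(self.end_evt), 2)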
+ """ + batch_size = output.shape[0] + results = [] + + for i in range(batch_size): + next_token_logits = output[i, prompt_lengths[i] - 1, :] + + # Argmax (deterministic) TODO: add temperature + next_token = torch.argmax(next_token_logits, dim=-1) + + next_token_decoded = tokenizer.decode([next_token.item()]) + results.append((next_token.item(), next_token_decoded)) + + return results + + +def _update_padded_sequence( + padded_sequence: torch.Tensor, + x_recv: torch.Tensor, + res, + prompt_lengths: List[int], +) -> None: + for i in range(len(prompt_lengths)): + prompt_lengths[i] += 1 + padded_sequence[i, prompt_lengths[i] - 1] = x_recv + + def _cleanup(): dist.barrier() dist.destroy_process_group() @@ -180,6 +269,17 @@ def main(args): pp_mesh = device_mesh["pp"] tp_rank = tp_mesh.get_local_rank() pp_rank = pp_mesh.get_local_rank() + tp_group = tp_mesh.get_group() + pp_group = pp_mesh.get_group() + + logger.info(f"review: {pp_group=}, {tp_group= }") + + logger.info(f"Created device mesh: {device_mesh}\n {tp_mesh=}, {pp_mesh=}\n") + # TODO - this assumes 1D mesh, need to update for 2D+ mesh + pp_group_size = pp_mesh.size() + tp_group_size = tp_mesh.size() + + logger.info(f"pp_group_size: {pp_group_size}, tp_group_size: {tp_group_size}") # Assuming same number of GPUs per node device = torch.device(f"cuda:{rank % torch.cuda.device_count()}") @@ -198,9 +298,10 @@ def main(args): if rank == 0: logger.info(f"Model: {model}") - mbs = 2 # number of micro-batches + mbs = 1 # number of micro-batches mb_size = 1 # micro-batch size batch_size = mbs * mb_size # total batch size + seqlen = 4096 # sequence length dim = 4096 # embedding dimension assert seqlen % sp_degree == 0 @@ -213,8 +314,10 @@ def main(args): # Load weights logger.info(f"Loading weights for {pp_rank=} on {device=}") - with TrackTime("cuda") as timer: - _load_model_weights(model, distribution, device=device, model_config=config) + + with CUDATrackTime() as timer: + _load_model_weights(model, hf_model_name, device=device, model_config=config) + logger.info( f"{color.green}Total weight loading time: {timer.get_time()} {timer.unit} for stage {rank}{color.reset}" ) @@ -226,9 +329,8 @@ def main(args): logger.info( f"Stage {rank} has {color.blue}{stage_num_params} params{color.reset}, Size: {color.blue}{stage_size_formatted}{color.reset}\n" ) - - # Setup input position - # input_pos for prefill: a list of increasing integers from 0 to seqlen + + # Setup input position (input_pos) for prefill: a list of increasing integers from 0 to seqlen input_pos = torch.arange(seqlen, device=device) model.setup_input_pos(input_pos) model.eval() @@ -249,41 +351,129 @@ def main(args): if len(cpu_tensors) > 0: raise ValueError("Found cpu tensors in stage") - # TODO: this can likely be removed after we prove out a few more models - # verify dtypes for model - expect all to be model_dtype except for bool causal_mask atm. 
@@ -249,41 +351,129 @@
     if len(cpu_tensors) > 0:
         raise ValueError("Found cpu tensors in stage")

-    # TODO: this can likely be removed after we prove out a few more models
-    # verify dtypes for model - expect all to be model_dtype except for bool causal_mask atm.
-    # dtype_count, dtype_locations, fp32_locations = record_module_dtypes(stage.submod)
-    # logger.info(
-    #     f"Stage Dtypes - Found {len(dtype_count)} dtypes: {dtype_count.items()}"
-    # )
-    # assert (
-    #     len(dtype_count) == 2
-    # ), f"Expected 2 dtypes in model after checkpoint loading: {model_dtype} and {torch.bool}"
+    prompt = [
+        "What is snow?",
+    ]

-    input_ids = torch.randint(0, config.vocab_size, (batch_size, seqlen), device=device)
-    logger.info(f"Input: {input_ids.dtype=}, {input_ids.shape=}, {input_ids.device=}")
+    """
+    "What is the capital of France?",
+    "What is your name?",
+    "What is the capital of Japan?",
+    "When is Christmas?",
+    "Where does Santa Claus live?",
+    "What is the capital of the United States?",
+    "What is the capital of China?",
+    "What is the capital of Russia?",
+    "What is PyTorch?",
+    "What is the capital of India?",
+    "What is an LLM?",
+    "What is the capital of Brazil?",
+    "What is the capital of Mexico?",
+    "What is the capital of Argentina?",
+    "What is the capital of Canada?",
+    ]
+    """

-    schedule = ScheduleGPipe(stage, mbs)
-    logger.info(f"Created schedule: {schedule}")
-
-    with torch.no_grad():  # .inference_mode():
-        if pp_rank == 0:
-            output = schedule.step(input_ids)
-        else:
-            output = schedule.step()
+    start_pos = 0

-    if pp_rank == pp_degree - 1 and tp_rank == 0:
-        logger.info(f"Output: {output}")
+    # encode the prompt
+    input_ids = _encode_strings(
+        prompt, tokenizer, bos=True, device=device, dtype=torch.int64
+    )
+    logger.info(f"{input_ids[0:8]=}")

-    # show peak memory stats for this stage
-    res_mem_gib, res_mem_pct = gpu_memory_monitor.get_peak_stats()
-    logger.info(
-        f"{color.blue} Memory used: {color.green}{res_mem_pct:.3f} %, {color.magenta}{res_mem_gib:.3f} GB{color.reset}"
+    # create a padded tensor for the input prompt
+    padded_sequence, prompt_lengths = _create_padded_prompts(
+        input_ids, tokenizer, seqlen, start_pos, device
     )
+    logger.info(f"{prompt_lengths=}")
+    logger.info(f"first prompt {padded_sequence[0, :prompt_lengths[0]+1]=}")
+    if len(prompt_lengths) > 1:
+        logger.info(f"second prompt {padded_sequence[1, :prompt_lengths[1]+1]=}")
+
+    schedule = ScheduleGPipe(stage, mbs)
+    logger.info(f"Created schedule: {schedule}")
+
+    # with CUDATrackTime() as timer:
+    first_pp_group = 0
+    last_pp_group = pp_group_size - 1
+
+    x_recv = torch.zeros(1, device=device, dtype=torch.int64)
+    logger.info(f"{x_recv.shape=}")
+
+    last_global_rank = world_size - 1
+    res = []
+    dst = None
+    src = None
+
+    if pp_rank == last_pp_group:
+        dst = dist.get_global_rank(pp_group, 0)
+    elif pp_rank == 0:
+        src = dist.get_global_rank(pp_group, last_pp_group)
+
+    # Decoding
+    num_tokens = 40
+
+    with torch.no_grad():
+        for step in range(num_tokens):
+            # first
+            if pp_rank == 0:
+                schedule.step(padded_sequence)
+                # only receive if not last step
+                if step < num_tokens - 1:
+                    dist.recv(
+                        x_recv,
+                        src,
+                        group=pp_group,
+                    )
+                    _update_padded_sequence(
+                        padded_sequence, x_recv, res, prompt_lengths
+                    )
+
+            # last
+            elif pp_rank == last_pp_group:
+                output = schedule.step()
+                # need to decode the output
+                decode_results = _batch_decode_next_tokens(
+                    output=output, prompt_lengths=prompt_lengths, tokenizer=tokenizer
+                )
+                if tp_rank == 0:
+                    logger.info(
+                        f"\n\n{color.green} {'Prefill' if step == 0 else '* Decode *'} responses ====>>>> {color.blue} {decode_results=} \n{color.reset}"
+                    )
+
+                next_token = torch.tensor([decode_results[0][0]], device=device)
+                res.append(decode_results[0][1])
+
+                # increment prompt lengths for next token
+                for i in range(len(prompt_lengths)):
+                    prompt_lengths[i] += 1
+                    # logger.info(
+                    #     f"output review {prompt_lengths[i]=}, {padded_sequence[i, prompt_lengths[i]-1]=}"
+                    # )
+
+                # only send if not last step
+                if step < (num_tokens - 1):
+                    dist.send(
+                        next_token,
+                        dst,
+                        pp_group,
+                    )
+
+            # middle pp ranks
+            else:
+                schedule.step()
+
+    # output formatted response via last pp group and tp rank 0
+    if pp_rank == last_pp_group and tp_rank == 0:
+        logger.info(f"\nPrompt:{color.green} {prompt[0]} {color.reset}")
+        formatted_response = "".join(res)
+        logger.info(f"$$$$$ {color.blue}{formatted_response}\n{color.reset} $$$$$")

     logger.info(
         f"{color.green}Success{color.white} - {color.blue}Rank {rank} has completed.{color.reset}"
     )
-
     _cleanup()
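The generation loop above is a simple token ring: on every step the last pipeline stage argmax-decodes one token and dist.send()s it to stage 0, which dist.recv()s it into x_recv and writes it into the padded sequence before the next schedule.step(). A stripped-down sketch of just that handshake (standalone and hypothetical, not part of the patch; gloo backend so it runs on CPU, launched with e.g. torchrun --nproc-per-node 2 toy_ring.py):

    import torch
    import torch.distributed as dist

    dist.init_process_group("gloo")  # torchrun supplies the rank/world-size env vars
    rank, world = dist.get_rank(), dist.get_world_size()
    token = torch.zeros(1, dtype=torch.int64)

    for step in range(3):
        if rank == world - 1:            # "last stage": produce a token id
            token = torch.tensor([100 + step], dtype=torch.int64)
            dist.send(token, dst=0)
        elif rank == 0:                  # "first stage": receive it for the next step
            dist.recv(token, src=world - 1)
            print(f"step {step}: received token {token.item()}")

    dist.destroy_process_group()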
diff --git a/distributed/force_download.py b/distributed/force_download.py
index f69ce5ba3..76dba8d0c 100644
--- a/distributed/force_download.py
+++ b/distributed/force_download.py
@@ -1,5 +1,5 @@
-from transformers import AutoTokenizer, AutoModelForCausalLM
+from transformers import AutoModelForCausalLM, AutoTokenizer

-tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
-model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
-print("Model weights downloaded")
+tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
+model = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
+print("Model weights and tokenizer downloaded")
diff --git a/run_dist.sh b/run_dist.sh
index a80c1b614..ed087d4e5 100644
--- a/run_dist.sh
+++ b/run_dist.sh
@@ -1,4 +1,4 @@
-export CUDA_VISIBLE_DEVICES=4,5,6,7
+#export CUDA_VISIBLE_DEVICES=4,5,6,7
 PORT=${1:-29501}
 NGPU=${NGPU:-"4"}
 LOG_RANK=${LOG_RANK:-0,1,2,3}
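run_dist.sh takes the rendezvous port as its first argument (default 29501) and reads NGPU / LOG_RANK from the environment; the CUDA_VISIBLE_DEVICES pin is now commented out so all visible GPUs are used. After force_download.py has populated the cache, the tokenizer can be sanity-checked against the demo prompt before a full distributed run (illustrative only; assumes access to the gated meta-llama repo is already authorized):

    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
    ids = tok.encode("What is snow?")   # token ids the HF tokenizer produces
    print(ids)
    print(tok.decode(ids))              # round-trips to the prompt, plus any special tokens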