diff --git a/dist_run.py b/dist_run.py
index 3a8f16a0b..ca65a5b26 100644
--- a/dist_run.py
+++ b/dist_run.py
@@ -31,7 +31,6 @@
     get_num_params,
     GPUMemoryMonitor,
 )
-from distributed.verification_utils import find_cpu_tensors
 from torch.distributed.pipelining import PipelineStage, ScheduleGPipe
 from torchchat.cli.builder import _initialize_tokenizer, TokenizerArgs
 from torchchat.model import ModelArgs, Transformer
@@ -219,10 +218,9 @@ def _update_padded_sequence(
     new_token: torch.Tensor,
     prompt_lengths: List[int],
 ) -> None:
-    # TODO: this is a hacky way to update the padded sequence: when there is
-    # more than one prompt, the for loop and the assignment is incompatible.
     for i in range(len(prompt_lengths)):
-        padded_sequence[i, prompt_lengths[i]] = new_token
+        padded_sequence[i, prompt_lengths[i]] = new_token[i, 0]
+        # logger.info(f"updated prompt {i} with new token {new_token[i, 0]}")
 
 
 def _cleanup():
@@ -242,7 +240,7 @@ def main(args):
     distribution, model_dtype = NAME_TO_DISTRIBUTION_AND_DTYPE[model_name]
     logger.info(f"Using HF model weights from {distribution} and dtype {model_dtype}")
 
-    config = ModelArgs.from_name(distribution).transformer_args['text']
+    config = ModelArgs.from_name(distribution).transformer_args["text"]
     logger.info(f"Chat Model Config: {config}")
 
     tokenizer = _build_chat_tokenizer(model_name)
@@ -295,7 +293,7 @@ def main(args):
     logger.info(f"Model: {model}")
 
     mbs = 1  # number of micro-batches
-    mb_size = 1  # micro-batch size
+    mb_size = 5  # micro-batch size
     batch_size = mbs * mb_size  # total batch size
 
     seqlen = 4096  # sequence length
@@ -343,6 +341,10 @@ def main(args):
 
     prompt = [
         "What is snow?",
+        "Where does Santa Claus live?",
+        "What is PyTorch?",
+        "Write a poem about the beauty of the night sky.",
+        "What is the capital of France, Germany and Switzerland?",
     ]
 
     """
@@ -366,17 +368,23 @@ def main(args):
     start_pos = 0
 
+    # pipeline comms setup
+    first_pp_rank = 0
+    last_pp_rank = pp_group_size - 1
+
+    # Need these global ids due to the API definition of dist.send and recv
+    first_pp_rank_global_id = dist.get_global_rank(pp_group, first_pp_rank)
+    last_pp_rank_global_id = dist.get_global_rank(pp_group, last_pp_rank)
+
     # encode the prompt
     input_ids = _encode_strings(
         prompt, tokenizer, bos=True, device=device, dtype=torch.int64
     )
-    logger.info(f"{input_ids[0:8]=}")
 
     # create a padded tensor for the input prompt
     padded_sequence, prompt_lengths = _create_padded_prompts(
         input_ids, tokenizer, seqlen, start_pos, device
     )
-    logger.info(f"{prompt_lengths=}")
 
     # create schedule
     schedule = ScheduleGPipe(stage, mbs)
@@ -389,8 +397,10 @@ def main(args):
     last_pp_rank_global_id = dist.get_global_rank(pp_group, last_pp_rank)
 
     # New token generated each iteration
-    new_token = torch.zeros(1, device=device, dtype=torch.int64)
-    res = []
+    total_prompts = len(prompt_lengths)
+    # need a new token dimension (row) for each prompt in the batch
+    new_token = torch.zeros(total_prompts, 1, device=device, dtype=torch.int64)
+    res = [[] for _ in range(total_prompts)]
     num_tokens = 40
 
     # Decoding
@@ -415,8 +425,11 @@ def main(args):
                     f"responses ====>>>> {color.blue} {decode_results=}{color.reset}"
                 )
             # decode results returns both token_id (int) and token_str (readable), hence [0] and [1]
-            new_token = torch.tensor([decode_results[0][0]], device=device)
-            res.append(decode_results[0][1])
+            for i in range(len(decode_results)):
+                res[i].append(decode_results[i][1])
+                new_token[i, 0] = torch.tensor(
+                    [decode_results[i][0]], device=device
+                )  # decode_results[i][0]
 
         # sendrecv between last and first ranks, only if:
         # first_pp_rank != last_pp_rank.
@@ -435,20 +448,27 @@ def main(args):
         # Update input sequence with new token
         if pp_rank == first_pp_rank:
-            _update_padded_sequence(
-                padded_sequence, new_token, prompt_lengths
-            )
+            _update_padded_sequence(padded_sequence, new_token, prompt_lengths)
 
         # increment prompt lengths for next token
         for i in range(len(prompt_lengths)):
            prompt_lengths[i] += 1
 
+    # Display the decoding results
+
     # output formatted response via last pp group and tp rank 0
     if pp_rank == last_pp_rank and tp_rank == 0:
-        logger.info(f"Prompt:{color.green} {prompt[0]} {color.reset}")
-        formatted_response = " ".join(res)
-        logger.info(f"$$$$$$ {color.blue}{formatted_response} {color.reset} $$$$$")
+        for i in range(len(prompt_lengths)):
+            logger.info(f"\nPrompt:{color.green} {prompt[i]} {color.reset}")
+
+            # TODO: resolve issue with llama2-7b-chat model and "".join
+            if model_name != "llama2-7b-chat":
+                formatted_response = "".join(res[i])
+            else:
+                formatted_response = " ".join(res[i])
+            logger.info(f"$$ {color.red}{formatted_response} {color.reset} $$\n")
+
     # Cleanup
     logger.info(
         f"{color.green}Success{color.white} - {color.blue}Rank {rank} has completed.{color.reset}"
     )
@@ -457,7 +477,12 @@ if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument("model_name", type=str, help="Name of the model to load", choices=NAME_TO_DISTRIBUTION_AND_DTYPE.keys())
+    parser.add_argument(
+        "model_name",
+        type=str,
+        help="Name of the model to load",
+        choices=NAME_TO_DISTRIBUTION_AND_DTYPE.keys(),
+    )
     parser.add_argument("--pp", type=int, default=1, help="Pipeline parallel degree")
     args = parser.parse_args()
diff --git a/run_dist.sh b/run_dist.sh
index ed087d4e5..e6e3bb133 100644
--- a/run_dist.sh
+++ b/run_dist.sh
@@ -4,4 +4,4 @@ NGPU=${NGPU:-"4"}
 LOG_RANK=${LOG_RANK:-0,1,2,3}
 torchrun --nproc-per-node=$NGPU --master_port=$PORT \
 --local-ranks-filter ${LOG_RANK} --role rank --tee 3 \
-dist_run.py
+dist_run.py --pp 2 llama3
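
The core of this change is that new_token becomes a [num_prompts, 1] tensor, so each row of the padded batch advances at its own write position instead of every prompt sharing one scalar. The following is a minimal standalone sketch of that per-prompt update pattern outside the pipeline; the helper name and tensor shapes mirror the diff, while the toy sizes and random token ids are assumptions for illustration only.

# Sketch (not part of the patch): per-prompt token update.
# padded_sequence is [batch, seqlen], new_token is [batch, 1],
# prompt_lengths tracks each prompt's current write position.
import torch
from typing import List


def _update_padded_sequence(
    padded_sequence: torch.Tensor,
    new_token: torch.Tensor,
    prompt_lengths: List[int],
) -> None:
    # write each prompt's freshly decoded token at that prompt's own offset
    for i in range(len(prompt_lengths)):
        padded_sequence[i, prompt_lengths[i]] = new_token[i, 0]


batch, seqlen = 3, 16
padded_sequence = torch.zeros(batch, seqlen, dtype=torch.int64)
prompt_lengths = [4, 2, 7]  # prompts of different lengths in one padded batch
new_token = torch.zeros(batch, 1, dtype=torch.int64)

for _ in range(5):  # pretend to decode 5 tokens
    # stand-in for the real pipeline forward + batch decode step
    new_token[:, 0] = torch.randint(1, 100, (batch,))
    _update_padded_sequence(padded_sequence, new_token, prompt_lengths)
    for i in range(len(prompt_lengths)):
        prompt_lengths[i] += 1  # advance each prompt's write position

print(padded_sequence)

With the updated run_dist.sh, the same logic is launched under torchrun with pipeline parallel degree 2 for llama3.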