From 511b5b8d59dcd0ea0ca8fcf79eda4733debead36 Mon Sep 17 00:00:00 2001 From: lessw2020 Date: Sun, 15 Sep 2024 12:36:28 -0700 Subject: [PATCH 1/7] enable batch decoding, optimize dst/src creation outside of decoding loop --- dist_run.py | 74 ++++++++++++++++++++++++++++++++++------------------- run_dist.sh | 2 +- 2 files changed, 48 insertions(+), 28 deletions(-) diff --git a/dist_run.py b/dist_run.py index a46dd81f4..5f2590d54 100644 --- a/dist_run.py +++ b/dist_run.py @@ -219,10 +219,11 @@ def _update_padded_sequence( new_token: torch.Tensor, prompt_lengths: List[int], ) -> None: - # TODO: this is a hacky way to update the padded sequence: when there is - # more than one prompt, the for loop and the assignment is incompatible. + for i in range(len(prompt_lengths)): - padded_sequence[i, prompt_lengths[i]] = new_token + prompt_lengths[i] += 1 + padded_sequence[i, prompt_lengths[i] - 1] = new_token[i, 0] + logger.info(f"updated prompt {i} with new token {new_token[i, 0]}") def _cleanup(): @@ -242,7 +243,7 @@ def main(args): distribution, model_dtype = NAME_TO_DISTRIBUTION_AND_DTYPE[model_name] logger.info(f"Using HF model weights from {distribution} and dtype {model_dtype}") - config = ModelArgs.from_name(distribution).transformer_args['text'] + config = ModelArgs.from_name(distribution).transformer_args["text"] logger.info(f"Chat Model Config: {config}") tokenizer = _build_chat_tokenizer(model_name) @@ -295,7 +296,7 @@ def main(args): logger.info(f"Model: {model}") mbs = 1 # number of micro-batches - mb_size = 1 # micro-batch size + mb_size = 2 # micro-batch size batch_size = mbs * mb_size # total batch size seqlen = 4096 # sequence length @@ -343,6 +344,7 @@ def main(args): prompt = [ "What is snow?", + "Where does Santa Claus live?", ] """ @@ -366,29 +368,36 @@ def main(args): start_pos = 0 + # pipeline comms setup + first_pp_rank = 0 + last_pp_rank = pp_group_size - 1 + + send_destination = dist.get_global_rank(pp_group, first_pp_rank) + recv_source = dist.get_global_rank(pp_group, last_pp_rank) + # encode the prompt input_ids = _encode_strings( prompt, tokenizer, bos=True, device=device, dtype=torch.int64 ) - logger.info(f"{input_ids[0:8]=}") + logger.info(f"{input_ids[0][0:8]=}") # create a padded tensor for the input prompt padded_sequence, prompt_lengths = _create_padded_prompts( input_ids, tokenizer, seqlen, start_pos, device ) - logger.info(f"{prompt_lengths=}") + logger.info(f"length of each prompt in the batch: {prompt_lengths=}") # create schedule schedule = ScheduleGPipe(stage, mbs) # with CUDATrackTime() as timer: - first_pp_rank = 0 - last_pp_rank = pp_group_size - 1 # New token generated each iteration - new_token = torch.zeros(1, device=device, dtype=torch.int64) - res = [] - num_tokens = 40 + total_prompts = len(prompt_lengths) + # need a new token dimension (row) for each prompt in the batch + new_token = torch.zeros(total_prompts, 1, device=device, dtype=torch.int64) + res = [[] for _ in range(total_prompts)] + num_tokens = 20 # Decoding with torch.no_grad(): @@ -412,39 +421,45 @@ def main(args): f"responses ====>>>> {color.blue} {decode_results=}{color.reset}" ) # decode results returns both token_id (int) and token_str (readable), hence [0] and [1] - new_token = torch.tensor([decode_results[0][0]], device=device) - res.append(decode_results[0][1]) + for i in range(len(decode_results)): + res[i].append(decode_results[i][1]) + new_token[i, 0] = torch.tensor( + [decode_results[i][0]], device=device + ) # decode_results[i][0] + + # increment prompt lengths for next token + 
for i in range(len(prompt_lengths)): + prompt_lengths[i] += 1 # sendrecv between last and first ranks, only if: # first_pp_rank != last_pp_rank. if pp_rank == last_pp_rank and pp_rank != first_pp_rank: dist.send( new_token, - dst=dist.get_global_rank(pp_group, first_pp_rank), + dst=send_destination, group=pp_group, ) elif pp_rank == first_pp_rank and pp_rank != last_pp_rank: dist.recv( new_token, - src=dist.get_global_rank(pp_group, last_pp_rank), + src=recv_source, group=pp_group, ) # Update input sequence with new token if pp_rank == first_pp_rank: - _update_padded_sequence( - padded_sequence, new_token, prompt_lengths - ) - - # increment prompt lengths for next token - for i in range(len(prompt_lengths)): - prompt_lengths[i] += 1 + _update_padded_sequence(padded_sequence, new_token, prompt_lengths) + for i in range(len(prompt_lengths)): + logger.info( + f"next submission: {padded_sequence[i, prompt_lengths[i]-4:prompt_lengths[i]+4]}" + ) # output formatted response via last pp group and tp rank 0 if pp_rank == last_pp_rank and tp_rank == 0: - logger.info(f"Prompt:{color.green} {prompt[0]} {color.reset}") - formatted_response = " ".join(res) - logger.info(f"$$$$$$ {color.blue}{formatted_response} {color.reset} $$$$$") + for i in range(len(prompt_lengths)): + logger.info(f"Prompt:{color.green} {prompt[i]} {color.reset}") + formatted_response = "".join(res[i]) + logger.info(f"$$ {color.red}{formatted_response} {color.reset} $$") logger.info( f"{color.green}Success{color.white} - {color.blue}Rank {rank} has completed.{color.reset}" @@ -454,7 +469,12 @@ def main(args): if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument("model_name", type=str, help="Name of the model to load", choices=NAME_TO_DISTRIBUTION_AND_DTYPE.keys()) + parser.add_argument( + "model_name", + type=str, + help="Name of the model to load", + choices=NAME_TO_DISTRIBUTION_AND_DTYPE.keys(), + ) parser.add_argument("--pp", type=int, default=1, help="Pipeline parallel degree") args = parser.parse_args() diff --git a/run_dist.sh b/run_dist.sh index ed087d4e5..e6e3bb133 100644 --- a/run_dist.sh +++ b/run_dist.sh @@ -4,4 +4,4 @@ NGPU=${NGPU:-"4"} LOG_RANK=${LOG_RANK:-0,1,2,3} torchrun --nproc-per-node=$NGPU --master_port=$PORT \ --local-ranks-filter ${LOG_RANK} --role rank --tee 3 \ -dist_run.py +dist_run.py --pp 2 llama3 From 221ea94b221066e1370b1111c3b5be7360c9c586 Mon Sep 17 00:00:00 2001 From: lessw2020 Date: Sun, 15 Sep 2024 12:45:52 -0700 Subject: [PATCH 2/7] remove logging, update formatting for display --- dist_run.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/dist_run.py b/dist_run.py index 5f2590d54..04270d9ca 100644 --- a/dist_run.py +++ b/dist_run.py @@ -223,7 +223,7 @@ def _update_padded_sequence( for i in range(len(prompt_lengths)): prompt_lengths[i] += 1 padded_sequence[i, prompt_lengths[i] - 1] = new_token[i, 0] - logger.info(f"updated prompt {i} with new token {new_token[i, 0]}") + # logger.info(f"updated prompt {i} with new token {new_token[i, 0]}") def _cleanup(): @@ -296,7 +296,7 @@ def main(args): logger.info(f"Model: {model}") mbs = 1 # number of micro-batches - mb_size = 2 # micro-batch size + mb_size = 5 # micro-batch size batch_size = mbs * mb_size # total batch size seqlen = 4096 # sequence length @@ -345,6 +345,9 @@ def main(args): prompt = [ "What is snow?", "Where does Santa Claus live?", + "What is PyTorch?", + "Write a poem about the beauty of the night sky.", + "What is the capital of France, Germany and Switzerland?", ] 
""" @@ -379,13 +382,11 @@ def main(args): input_ids = _encode_strings( prompt, tokenizer, bos=True, device=device, dtype=torch.int64 ) - logger.info(f"{input_ids[0][0:8]=}") # create a padded tensor for the input prompt padded_sequence, prompt_lengths = _create_padded_prompts( input_ids, tokenizer, seqlen, start_pos, device ) - logger.info(f"length of each prompt in the batch: {prompt_lengths=}") # create schedule schedule = ScheduleGPipe(stage, mbs) @@ -397,7 +398,7 @@ def main(args): # need a new token dimension (row) for each prompt in the batch new_token = torch.zeros(total_prompts, 1, device=device, dtype=torch.int64) res = [[] for _ in range(total_prompts)] - num_tokens = 20 + num_tokens = 40 # Decoding with torch.no_grad(): @@ -449,18 +450,17 @@ def main(args): # Update input sequence with new token if pp_rank == first_pp_rank: _update_padded_sequence(padded_sequence, new_token, prompt_lengths) - for i in range(len(prompt_lengths)): - logger.info( - f"next submission: {padded_sequence[i, prompt_lengths[i]-4:prompt_lengths[i]+4]}" - ) + + # Display the decoding results # output formatted response via last pp group and tp rank 0 if pp_rank == last_pp_rank and tp_rank == 0: for i in range(len(prompt_lengths)): - logger.info(f"Prompt:{color.green} {prompt[i]} {color.reset}") + logger.info(f"\nPrompt:{color.green} {prompt[i]} {color.reset}") formatted_response = "".join(res[i]) - logger.info(f"$$ {color.red}{formatted_response} {color.reset} $$") + logger.info(f"$$ {color.red}{formatted_response} {color.reset} $$\n") + # Cleanup logger.info( f"{color.green}Success{color.white} - {color.blue}Rank {rank} has completed.{color.reset}" ) From 8b41ae2be604b45ed46b2572040e790167f6d84b Mon Sep 17 00:00:00 2001 From: lessw2020 Date: Sun, 15 Sep 2024 12:52:27 -0700 Subject: [PATCH 3/7] ruff formatting --- dist_run.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/dist_run.py b/dist_run.py index 04270d9ca..5e98ee0c9 100644 --- a/dist_run.py +++ b/dist_run.py @@ -31,7 +31,6 @@ get_num_params, GPUMemoryMonitor, ) -from distributed.verification_utils import find_cpu_tensors from torch.distributed.pipelining import PipelineStage, ScheduleGPipe from torchchat.cli.builder import _initialize_tokenizer, TokenizerArgs from torchchat.model import ModelArgs, Transformer @@ -219,7 +218,6 @@ def _update_padded_sequence( new_token: torch.Tensor, prompt_lengths: List[int], ) -> None: - for i in range(len(prompt_lengths)): prompt_lengths[i] += 1 padded_sequence[i, prompt_lengths[i] - 1] = new_token[i, 0] From dfe6fe29ff1dbdeab16583bb48594f5804c40d65 Mon Sep 17 00:00:00 2001 From: lessw2020 Date: Sun, 15 Sep 2024 12:59:38 -0700 Subject: [PATCH 4/7] use Ke's variable names for send/rcv --- dist_run.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/dist_run.py b/dist_run.py index 5e98ee0c9..97bc55e56 100644 --- a/dist_run.py +++ b/dist_run.py @@ -373,8 +373,9 @@ def main(args): first_pp_rank = 0 last_pp_rank = pp_group_size - 1 - send_destination = dist.get_global_rank(pp_group, first_pp_rank) - recv_source = dist.get_global_rank(pp_group, last_pp_rank) + # Need these global ids due to the API definition of dist.send and recv + first_pp_rank_global_id = dist.get_global_rank(pp_group, first_pp_rank) + last_pp_rank_global_id = dist.get_global_rank(pp_group, last_pp_rank) # encode the prompt input_ids = _encode_strings( @@ -435,13 +436,13 @@ def main(args): if pp_rank == last_pp_rank and pp_rank != first_pp_rank: dist.send( new_token, - dst=send_destination, + dst=first_pp_rank_global_id, 
group=pp_group, ) elif pp_rank == first_pp_rank and pp_rank != last_pp_rank: dist.recv( new_token, - src=recv_source, + src=last_pp_rank_global_id, group=pp_group, ) From 54eeecee523daad8f66595ad6d5dccc9ab31b852 Mon Sep 17 00:00:00 2001 From: lessw2020 Date: Sun, 15 Sep 2024 15:43:19 -0700 Subject: [PATCH 5/7] add formatting exception for llama2 "".res --- dist_run.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/dist_run.py b/dist_run.py index 97bc55e56..3d5947186 100644 --- a/dist_run.py +++ b/dist_run.py @@ -456,7 +456,12 @@ def main(args): if pp_rank == last_pp_rank and tp_rank == 0: for i in range(len(prompt_lengths)): logger.info(f"\nPrompt:{color.green} {prompt[i]} {color.reset}") - formatted_response = "".join(res[i]) + + # TODO: resolve issue with llama2-7b-chat model and "".join + if model_name != "llama2-7b-chat": + formatted_response = "".join(res[i]) + else: + formatted_response = " ".join(res[i]) logger.info(f"$$ {color.red}{formatted_response} {color.reset} $$\n") # Cleanup From 39a974c6a026cf13f195bcb57197b5f480a803f1 Mon Sep 17 00:00:00 2001 From: lessw2020 Date: Sun, 15 Sep 2024 15:50:42 -0700 Subject: [PATCH 6/7] fix prompt incrementing add formatting exception for llama2 "".res --- dist_run.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/dist_run.py b/dist_run.py index 3d5947186..e69626449 100644 --- a/dist_run.py +++ b/dist_run.py @@ -219,7 +219,6 @@ def _update_padded_sequence( prompt_lengths: List[int], ) -> None: for i in range(len(prompt_lengths)): - prompt_lengths[i] += 1 padded_sequence[i, prompt_lengths[i] - 1] = new_token[i, 0] # logger.info(f"updated prompt {i} with new token {new_token[i, 0]}") @@ -427,10 +426,6 @@ def main(args): [decode_results[i][0]], device=device ) # decode_results[i][0] - # increment prompt lengths for next token - for i in range(len(prompt_lengths)): - prompt_lengths[i] += 1 - # sendrecv between last and first ranks, only if: # first_pp_rank != last_pp_rank. if pp_rank == last_pp_rank and pp_rank != first_pp_rank: @@ -446,6 +441,10 @@ def main(args): group=pp_group, ) + # increment prompt lengths for next token + for i in range(len(prompt_lengths)): + prompt_lengths[i] += 1 + # Update input sequence with new token if pp_rank == first_pp_rank: _update_padded_sequence(padded_sequence, new_token, prompt_lengths) From 5abb4444c13b4d9a529003014492ab0f02a158f6 Mon Sep 17 00:00:00 2001 From: lessw2020 Date: Sun, 15 Sep 2024 15:56:20 -0700 Subject: [PATCH 7/7] revert prompt incrementing to pp=1 state --- dist_run.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/dist_run.py b/dist_run.py index e69626449..4650e3df1 100644 --- a/dist_run.py +++ b/dist_run.py @@ -219,7 +219,7 @@ def _update_padded_sequence( prompt_lengths: List[int], ) -> None: for i in range(len(prompt_lengths)): - padded_sequence[i, prompt_lengths[i] - 1] = new_token[i, 0] + padded_sequence[i, prompt_lengths[i]] = new_token[i, 0] # logger.info(f"updated prompt {i} with new token {new_token[i, 0]}") @@ -441,14 +441,14 @@ def main(args): group=pp_group, ) - # increment prompt lengths for next token - for i in range(len(prompt_lengths)): - prompt_lengths[i] += 1 - # Update input sequence with new token if pp_rank == first_pp_rank: _update_padded_sequence(padded_sequence, new_token, prompt_lengths) + # increment prompt lengths for next token + for i in range(len(prompt_lengths)): + prompt_lengths[i] += 1 + # Display the decoding results # output formatted response via last pp group and tp rank 0
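
A minimal, CPU-only sketch of the batched-decoding bookkeeping the series lands on after PATCH 7/7: new_token carries one row per prompt, the padded sequence gets each row's token written at index prompt_lengths[i], and the lengths are incremented only after that update. The toy token ids, pad id, and the update_padded_sequence helper below are illustrative stand-ins (not part of the patches); only the indexing and increment ordering mirror dist_run.py, and the pipeline send/recv, tokenizer, and scheduling pieces are omitted.

    import torch

    def update_padded_sequence(padded_sequence, new_token, prompt_lengths):
        # same indexing as _update_padded_sequence after PATCH 7/7: write the new
        # token into the first pad slot of each row; lengths are bumped afterwards
        for i in range(len(prompt_lengths)):
            padded_sequence[i, prompt_lengths[i]] = new_token[i, 0]

    seqlen, pad_id = 16, 0
    prompts = [[5, 7, 9], [3, 4]]              # toy token ids for a 2-prompt batch
    prompt_lengths = [len(p) for p in prompts]
    padded_sequence = torch.full((len(prompts), seqlen), pad_id, dtype=torch.int64)
    for i, p in enumerate(prompts):
        padded_sequence[i, : len(p)] = torch.tensor(p, dtype=torch.int64)

    for step in range(3):                      # pretend decode steps
        # one new token per prompt, shaped (batch, 1) like new_token in dist_run.py
        new_token = torch.randint(10, 20, (len(prompts), 1), dtype=torch.int64)
        update_padded_sequence(padded_sequence, new_token, prompt_lengths)
        # increment after the sequence update, matching the final patch ordering
        for i in range(len(prompt_lengths)):
            prompt_lengths[i] += 1

    print(padded_sequence[:, :8])

Keeping the increment outside the update helper lets every rank advance prompt_lengths in lockstep while only the first pipeline rank mutates the padded input, which is the behavior the final patch restores.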