diff --git a/dist_run.py b/dist_run.py
index 8bcbd5e4d..fc580ea2a 100644
--- a/dist_run.py
+++ b/dist_run.py
@@ -187,8 +187,8 @@ def _create_padded_prompts(
 
 def _batch_decode_next_tokens(
     output: torch.Tensor,
-    prompt_lengths: List[int],
     tokenizer,
+    prompt_lengths: Optional[List[int]] = None,
 ) -> List[Tuple[int, str]]:
     """
     Decode the next token for each prompt in the batch.
@@ -201,7 +201,8 @@ def _batch_decode_next_tokens(
     results = []
 
     for i in range(batch_size):
-        next_token_logits = output[i, prompt_lengths[i] - 1, :]
+        pos = prompt_lengths[i] - 1 if prompt_lengths is not None else 0
+        next_token_logits = output[i, pos, :]
 
         # Argmax (deterministic) TODO: add temperature
         next_token = torch.argmax(next_token_logits, dim=-1)
@@ -276,6 +277,10 @@ def main(args):
     tp_group_size = tp_group.size()
     logger.info(f"{pp_group_size=}, {tp_group_size=}")
 
+    # Convenience variables
+    first_pp_rank = 0
+    last_pp_rank = pp_group_size - 1
+
     # Assuming same number of GPUs per node
     device = torch.device(f"cuda:{rank % torch.cuda.device_count()}")
 
@@ -293,29 +298,23 @@ def main(args):
     logger.info(f"Model: {model}")
 
     mbs = 1  # number of micro-batches
-    mb_size = 5  # micro-batch size
+    mb_size = 4  # micro-batch size
     batch_size = mbs * mb_size  # total batch size
-    seqlen = 4096  # sequence length
+    seqlen_prefill = 1024  # sequence length
     dim = 4096  # embedding dimension
 
     # Setup KV caches (after model distribution)
     # TODO: the setting below only works for 1 micro-batch case. To support
     # multiple micro-batches, we need the KV cache in the model to be aware of
     # the number of micro-batches and the current micro-batch index.
-    model.setup_caches(mb_size, seqlen)
-
-    mb_ids = torch.randint(0, config.vocab_size, (mb_size, seqlen), device=device)
-    activation = torch.rand(
-        mb_size, seqlen, dim, device=device, dtype=model_dtype
-    )
-    example_args = mb_ids if pp_rank == 0 else activation
+    model.setup_caches(mb_size, seqlen_prefill)
 
     # Load weights
     logger.info(f"Loading weights for {pp_rank=} on {device=}")
-
     with CUDATrackTime() as timer:
         _load_model_weights(model, distribution, device=device, model_config=config)
+        model.to(device)
 
     logger.info(
         f"{color.green}Total weight loading time: {timer.get_time()} {timer.unit} for stage {rank}{color.reset}"
     )
@@ -330,53 +329,47 @@ def main(args):
     )
 
     # Setup input position (input_pos) for prefill: a list of increasing integers from 0 to seqlen
-    input_pos = torch.arange(seqlen, device=device)
+    input_pos = torch.arange(seqlen_prefill, device=device)
     model.setup_input_pos(input_pos)
     model.eval()
 
-    logger.info(f"Creating pipeline stage {pp_rank=}, {pp_degree=}")
-    stage = PipelineStage(
+    # Helper function to get example inputs and outputs for the stages.
+    def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
+        mb_ids = torch.randint(0, config.vocab_size, (mb_size, seqlen), device=device)
+        activation = torch.rand(
+            mb_size, seqlen, dim, device=device, dtype=model_dtype
+        )
+        logits = torch.rand(
+            mb_size, seqlen, config.vocab_size, device=device, dtype=model_dtype
+        )
+        example_inputs = (mb_ids if pp_rank == first_pp_rank else activation,)
+        example_outputs = (logits if pp_rank == last_pp_rank else activation,)
+        return example_inputs, example_outputs
+
+    # Create prefill stage
+    logger.info(f"Creating pipeline stage for prefill {pp_rank=}, {pp_degree=}")
+    example_inputs, example_outputs = get_example_ins_outs(seqlen_prefill)
+    prefill_stage = PipelineStage(
         model,
         pp_rank,
         pp_degree,
         device,
-        input_args=(example_args,),
+        input_args=example_inputs,
+        output_args=example_outputs,
         group=pp_group,
     )
+    # create schedule
+    prefill_schedule = ScheduleGPipe(prefill_stage, mbs)
 
     prompt = [
-        "What is snow?",
-        "Where does Santa Claus live?",
-        "What is PyTorch?",
-        "Write a poem about the beauty of the night sky.",
-        "What is the capital of France, Germany and Switzerland?",
-    ]
-
-    """
-    "What is the capital of France?",
-    "What is your name?",
-    "What is the capital of Japan?",
-    "When is Christmas?",
-    "Where does Santa Claus live?",
-    "What is the capital of the United States?",
-    "What is the capital of China?",
-    "What is the capital of Russia?",
-    "What is PyTorch?",
-    "What is the capital of India?",
-    "What is an LLM?",
-    "What is the capital of Brazil?",
-    "What is the capital of Mexico?",
-    "What is the capital of Argentina?",
-    "What is the capital of Canada?",
+        "What is a computer?",
+        "Where does Santa live?",
+        "Who is Abraham Lincoln?",
+        "How are models trained?",
     ]
-    """
 
     start_pos = 0
 
-    # pipeline comms setup
-    first_pp_rank = 0
-    last_pp_rank = pp_group_size - 1
-
     # Need these global ids due to the API definition of dist.send and recv
     first_pp_rank_global_id = dist.get_global_rank(pp_group, first_pp_rank)
     last_pp_rank_global_id = dist.get_global_rank(pp_group, last_pp_rank)
@@ -388,15 +381,14 @@ def main(args):
 
     # create a padded tensor for the input prompt
     padded_sequence, prompt_lengths = _create_padded_prompts(
-        input_ids, tokenizer, seqlen, start_pos, device
+        input_ids, tokenizer, seqlen_prefill, start_pos, device
     )
-
-    # create schedule
-    schedule = ScheduleGPipe(stage, mbs)
+    # TODO: figure out how to set input_pos for each prompt in the batch then we
+    # can remove this limitation.
+    s = set(prompt_lengths)
+    assert len(s) == 1, f"prompt_lengths should be the same, got {s}"
 
     # with CUDATrackTime() as timer:
-    first_pp_rank = 0
-    last_pp_rank = pp_group_size - 1
     # Need these global ids due to the API definition of dist.send and recv
     first_pp_rank_global_id = dist.get_global_rank(pp_group, first_pp_rank)
     last_pp_rank_global_id = dist.get_global_rank(pp_group, last_pp_rank)
@@ -408,25 +400,87 @@ def main(args):
     res = [[] for _ in range(total_prompts)]
     num_tokens = 40
 
+    # Prefill phase
+    # Run context input through pipeline, in 1 step
+    with torch.no_grad():
+        if pp_rank == first_pp_rank:
+            output = prefill_schedule.step(padded_sequence)
+        elif pp_rank == last_pp_rank:
+            output = prefill_schedule.step()
+        else:  # middle pp ranks
+            prefill_schedule.step()
+
+    # Decode the output -- first generated token
+    if pp_rank == last_pp_rank:
+        decode_results = _batch_decode_next_tokens(
+            output=output,
+            tokenizer=tokenizer,
+            prompt_lengths=prompt_lengths,
+        )
+        for i in range(len(decode_results)):
+            new_token[i, 0] = torch.tensor(
+                [decode_results[i][0]], device=device
+            )  # token_id in int form
+        if tp_rank == 0:
+            logger.info(
+                f"{color.green} {'* Prefill *'} "
+                f"responses ====>>>> {color.blue} {decode_results=}{color.reset}"
+            )
+
+    # seqlen = 1 now
+    seqlen_decode = 1
+    input_pos = torch.tensor([prompt_lengths[0]], device=device)
+    model.setup_input_pos(input_pos)
+
+    # Create decode stage
+    logger.info(f"Creating pipeline stage for decode {pp_rank=}, {pp_degree=}")
+    example_inputs, example_outputs = get_example_ins_outs(seqlen_decode)
+    decode_stage = PipelineStage(
+        model,
+        pp_rank,
+        pp_degree,
+        device,
+        input_args=example_inputs,
+        output_args=example_outputs,
+        group=pp_group,
+    )
+    # create schedule
+    decode_schedule = ScheduleGPipe(decode_stage, mbs)
+
     # Decoding
     with torch.no_grad():
-        for step in range(num_tokens):
+        for step in range(num_tokens - 1):
+            # sendrecv between last and first ranks, only if:
+            # first_pp_rank != last_pp_rank.
+            if pp_rank == last_pp_rank and pp_rank != first_pp_rank:
+                dist.send(
+                    new_token,
+                    dst=first_pp_rank_global_id,
+                    group=pp_group,
+                )
+            elif pp_rank == first_pp_rank and pp_rank != last_pp_rank:
+                dist.recv(
+                    new_token,
+                    src=last_pp_rank_global_id,
+                    group=pp_group,
+                )
+
             # Run data through pipeline
             if pp_rank == first_pp_rank:
-                output = schedule.step(padded_sequence)
+                output = decode_schedule.step(new_token)
            elif pp_rank == last_pp_rank:
-                output = schedule.step()
+                output = decode_schedule.step()
             else:  # middle pp ranks
-                schedule.step()
+                decode_schedule.step()
 
             # Decode the output
             if pp_rank == last_pp_rank:
                 decode_results = _batch_decode_next_tokens(
-                    output=output, prompt_lengths=prompt_lengths, tokenizer=tokenizer
+                    output=output, tokenizer=tokenizer
                 )
                 if tp_rank == 0:
                     logger.info(
-                        f"{color.green} {'Prefill' if step == 0 else '* Decode *'} "
+                        f"{color.green} {'* Decode *'} "
                         f"responses ====>>>> {color.blue} {decode_results=}{color.reset}"
                     )
                 # decode results returns both token_id (int) and token_str (readable), hence [0] and [1]
@@ -436,28 +490,8 @@ def main(args):
                         [decode_results[i][0]], device=device
                     )  # decode_results[i][0]
 
-            # sendrecv between last and first ranks, only if:
-            # first_pp_rank != last_pp_rank.
-            if pp_rank == last_pp_rank and pp_rank != first_pp_rank:
-                dist.send(
-                    new_token,
-                    dst=first_pp_rank_global_id,
-                    group=pp_group,
-                )
-            elif pp_rank == first_pp_rank and pp_rank != last_pp_rank:
-                dist.recv(
-                    new_token,
-                    src=last_pp_rank_global_id,
-                    group=pp_group,
-                )
-
-            # Update input sequence with new token
-            if pp_rank == first_pp_rank:
-                _update_padded_sequence(padded_sequence, new_token, prompt_lengths)
-
-            # increment prompt lengths for next token
-            for i in range(len(prompt_lengths)):
-                prompt_lengths[i] += 1
+            input_pos += 1
+            model.setup_input_pos(input_pos)
 
     # Display the decoding results
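
Note on the `_batch_decode_next_tokens` change above: `prompt_lengths` becomes optional because the two phases index the logits differently. During prefill the last stage returns logits for the whole padded sequence, so the next token for prompt `i` sits at position `prompt_lengths[i] - 1`; during decode each stage only processes the single new token, so the logits tensor has sequence length 1 and position 0 is always correct (the `prompt_lengths=None` case). The sketch below is not part of the patch; it is a minimal, self-contained illustration of that indexing with dummy tensors and plain argmax, no tokenizer or pipeline involved, and all shapes chosen arbitrarily.

import torch

vocab_size, seqlen, batch_size = 32, 16, 2
prompt_lengths = [5, 5]  # padded prompt lengths (the patch asserts they are equal)

# Prefill: logits cover the full padded sequence, so pick position prompt_lengths[i] - 1
prefill_logits = torch.rand(batch_size, seqlen, vocab_size)
prefill_tokens = [
    torch.argmax(prefill_logits[i, prompt_lengths[i] - 1, :], dim=-1).item()
    for i in range(batch_size)
]

# Decode: the stage only saw one token, so the logits have sequence length 1 and
# position 0 is the right slot for every prompt
decode_logits = torch.rand(batch_size, 1, vocab_size)
decode_tokens = [
    torch.argmax(decode_logits[i, 0, :], dim=-1).item() for i in range(batch_size)
]

print(prefill_tokens, decode_tokens)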