diff --git a/dist_run.py b/dist_run.py
index 5282cb6ad..a46dd81f4 100644
--- a/dist_run.py
+++ b/dist_run.py
@@ -135,8 +135,8 @@ def _load_model_weights(stage_module, distribution, device, model_config):
 def _encode_strings(
     strings: List[str],
     tokenizer,
-    bos: bool = True,
-    device: torch.device = "cuda:0",
+    bos: bool,
+    device: torch.device,
     dtype=torch.int64,
 ) -> List[torch.Tensor]:
     """Encode a list of prompt strings into a list of tensor token ids."""
@@ -216,13 +216,13 @@ def _batch_decode_next_tokens(
 
 def _update_padded_sequence(
     padded_sequence: torch.Tensor,
-    x_recv: torch.Tensor,
-    res,
+    new_token: torch.Tensor,
     prompt_lengths: List[int],
 ) -> None:
+    # TODO: this is a hacky way to update the padded sequence: when there is
+    # more than one prompt, the for loop and the assignment are incompatible.
     for i in range(len(prompt_lengths)):
-        prompt_lengths[i] += 1
-        padded_sequence[i, prompt_lengths[i] - 1] = x_recv
+        padded_sequence[i, prompt_lengths[i]] = new_token
 
 
 def _cleanup():
@@ -267,19 +267,15 @@ def main(args):
     device_mesh = _create_device_mesh(mesh_dimensions)
     tp_mesh = device_mesh["tp"]
     pp_mesh = device_mesh["pp"]
+    logger.info(f"Created device mesh: {device_mesh}\n{tp_mesh=}, {pp_mesh=}")
+
     tp_rank = tp_mesh.get_local_rank()
     pp_rank = pp_mesh.get_local_rank()
     tp_group = tp_mesh.get_group()
     pp_group = pp_mesh.get_group()
-
-    logger.info(f"review: {pp_group=}, {tp_group= }")
-
-    logger.info(f"Created device mesh: {device_mesh}\n {tp_mesh=}, {pp_mesh=}\n")
-    # TODO - this assumes 1D mesh, need to update for 2D+ mesh
-    pp_group_size = pp_mesh.size()
-    tp_group_size = tp_mesh.size()
-
-    logger.info(f"pp_group_size: {pp_group_size}, tp_group_size: {tp_group_size}")
+    pp_group_size = pp_group.size()
+    tp_group_size = tp_group.size()
+    logger.info(f"{pp_group_size=}, {tp_group_size=}")
 
     # Assuming same number of GPUs per node
     device = torch.device(f"cuda:{rank % torch.cuda.device_count()}")
@@ -316,7 +312,7 @@ def main(args):
     logger.info(f"Loading weights for {pp_rank=} on {device=}")
     with CUDATrackTime() as timer:
-        _load_model_weights(model, hf_model_name, device=device, model_config=config)
+        _load_model_weights(model, distribution, device=device, model_config=config)
 
     logger.info(
         f"{color.green}Total weight loading time: {timer.get_time()} {timer.unit} for stage {rank}{color.reset}"
     )
@@ -327,7 +323,7 @@ def main(args):
     stage_size_formatted = bytes_to_readable(stage_size)
     stage_num_params = get_num_params(model)
     logger.info(
-        f"Stage {rank} has {color.blue}{stage_num_params} params{color.reset}, Size: {color.blue}{stage_size_formatted}{color.reset}\n"
+        f"Stage {rank} has {color.blue}{stage_num_params} params{color.reset}, Size: {color.blue}{stage_size_formatted}{color.reset}"
     )
 
     # Setup input position (input_pos) for prefill: a list of increasing integers from 0 to seqlen
@@ -342,15 +338,9 @@ def main(args):
         pp_degree,
         device,
         input_args=(example_args,),
-        group=pp_mesh.get_group(),
+        group=pp_group,
     )
 
-    # this check confirms that there are no cpu tensors in the model..we expect this to be true.
-    cpu_tensors = find_cpu_tensors(stage.submod)
-    # logger.info(f"Found {len(cpu_tensors)} cpu tensors: {cpu_tensors}")
-    if len(cpu_tensors) > 0:
-        raise ValueError("Found cpu tensors in stage")
-
     prompt = [
         "What is snow?",
     ]
@@ -374,7 +364,6 @@ def main(args):
     ]
     """
 
-    start_pos = 0
 
     # encode the prompt
@@ -388,88 +377,74 @@ def main(args):
         input_ids, tokenizer, seqlen, start_pos, device
     )
     logger.info(f"{prompt_lengths=}")
-    logger.info(f"first prompt {padded_sequence[0, :prompt_lengths[0]+1]=}")
-    if len(prompt_lengths) > 1:
-        logger.info(f"second prompt {padded_sequence[1, :prompt_lengths[1]+1]=}")
+
     # create schedule
     schedule = ScheduleGPipe(stage, mbs)
-    logger.info(f"Created schedule: {schedule}")
 
     # with CUDATrackTime() as timer:
-    first_pp_group = 0
-    last_pp_group = pp_group_size - 1
-
-    x_recv = torch.zeros(1, device=device, dtype=torch.int64)
-    logger.info(f"{x_recv.shape=}")
+    first_pp_rank = 0
+    last_pp_rank = pp_group_size - 1
 
-    last_global_rank = world_size - 1
+    # New token generated each iteration
+    new_token = torch.zeros(1, device=device, dtype=torch.int64)
     res = []
-    dst = None
-    src = None
-
-    if pp_rank == last_pp_group:
-        dst = dist.get_global_rank(pp_group, 0)
-    elif pp_rank == 0:
-        src = dist.get_global_rank(pp_group, last_pp_group)
-
-    # Decoding
     num_tokens = 40
 
+    # Decoding
     with torch.no_grad():
         for step in range(num_tokens):
-            # first
-            if pp_rank == 0:
-                schedule.step(padded_sequence)
-                # only receive if not last step
-                if step < num_tokens - 1:
-                    dist.recv(
-                        x_recv,
-                        src,
-                        group=pp_group,
-                    )
-                    _update_padded_sequence(
-                        padded_sequence, x_recv, res, prompt_lengths
-                    )
-
-            # last
-            elif pp_rank == last_pp_group:
+            # Run data through pipeline
+            if pp_rank == first_pp_rank:
+                output = schedule.step(padded_sequence)
+            elif pp_rank == last_pp_rank:
                 output = schedule.step()
-                # need to decode the output
+            else:  # middle pp ranks
+                schedule.step()
+
+            # Decode the output
+            if pp_rank == last_pp_rank:
                 decode_results = _batch_decode_next_tokens(
                     output=output, prompt_lengths=prompt_lengths, tokenizer=tokenizer
                 )
                 if tp_rank == 0:
                     logger.info(
-                        f"\n\n{color.green} {'Prefill' if step == 0 else '* Decode *'} responses ====>>>> {color.blue} {decode_results=} \n{color.reset}"
+                        f"{color.green} {'Prefill' if step == 0 else '* Decode *'} "
+                        f"responses ====>>>> {color.blue} {decode_results=}{color.reset}"
                     )
-
-                next_token = torch.tensor([decode_results[0][0]], device=device)
+                # decode_results returns both token_id (int) and token_str (readable), hence [0] and [1]
+                new_token = torch.tensor([decode_results[0][0]], device=device)
                 res.append(decode_results[0][1])
-                # increment prompt lengths for next token
-                for i in range(len(prompt_lengths)):
-                    prompt_lengths[i] += 1
-                    # logger.info(
-                    #     f"output review {prompt_lengths[i]=}, {padded_sequence[i, prompt_lengths[i]-1]=}"
-                    # )
-
-                # only send if not last step
-                if step < (num_tokens - 1):
-                    dist.send(
-                        next_token,
-                        dst,
-                        pp_group,
-                    )
+            # Send/recv the new token between the last and first pp ranks;
+            # only needed when first_pp_rank != last_pp_rank.
+            if pp_rank == last_pp_rank and pp_rank != first_pp_rank:
+                dist.send(
+                    new_token,
+                    dst=dist.get_global_rank(pp_group, first_pp_rank),
+                    group=pp_group,
+                )
+            elif pp_rank == first_pp_rank and pp_rank != last_pp_rank:
+                dist.recv(
+                    new_token,
+                    src=dist.get_global_rank(pp_group, last_pp_rank),
+                    group=pp_group,
+                )
 
-            # middle pp ranks
-            else:
-                schedule.step()
+            # Update input sequence with new token
+            if pp_rank == first_pp_rank:
+                _update_padded_sequence(
+                    padded_sequence, new_token, prompt_lengths
+                )
+
+            # increment prompt lengths for next token
+            for i in range(len(prompt_lengths)):
+                prompt_lengths[i] += 1
 
     # output formatted response via last pp group and tp rank 0
-    if pp_rank == last_pp_group and tp_rank == 0:
-        logger.info(f"\nPrompt:{color.green} {prompt[0]} {color.reset}")
-        formatted_response = "".join(res)
-        logger.info(f"$$$$$$ {color.blue}{formatted_response}\n{color.reset} $$$$$")
+    if pp_rank == last_pp_rank and tp_rank == 0:
+        logger.info(f"Prompt:{color.green} {prompt[0]} {color.reset}")
        formatted_response = " ".join(res)
+        logger.info(f"$$$$$$ {color.blue}{formatted_response} {color.reset} $$$$$")
 
     logger.info(
         f"{color.green}Success{color.white} - {color.blue}Rank {rank} has completed.{color.reset}"
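
To make the new decode-loop control flow easier to follow outside the full script, here is a minimal, self-contained sketch of the token handoff this diff introduces: each step, the last pipeline rank sends the freshly decoded token back to the first rank, which writes it into the padded sequence and then bumps the prompt lengths. This is a sketch under stated assumptions, not dist_run.py itself: it runs two CPU processes with the gloo backend, and the names _decode_step and _worker, the dummy token id, and the toy tensor shapes are all illustrative.

import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def _update_padded_sequence(padded_sequence, new_token, prompt_lengths):
    # Mirrors the patched helper: write the new token into the first free slot
    # of each row (hacky for >1 prompt, as the TODO in the diff notes).
    for i in range(len(prompt_lengths)):
        padded_sequence[i, prompt_lengths[i]] = new_token


def _decode_step(pp_rank: int, pp_group_size: int) -> None:
    first_pp_rank = 0
    last_pp_rank = pp_group_size - 1

    # New token generated (or received) this iteration
    new_token = torch.zeros(1, dtype=torch.int64)

    # Send/recv between the last and first ranks, only when they differ
    if pp_rank == last_pp_rank and pp_rank != first_pp_rank:
        new_token[0] = 42  # stand-in for a decoded token id
        dist.send(new_token, dst=first_pp_rank)
    elif pp_rank == first_pp_rank and pp_rank != last_pp_rank:
        dist.recv(new_token, src=last_pp_rank)

    # First rank updates its input sequence with the received token
    if pp_rank == first_pp_rank:
        padded_sequence = torch.zeros(1, 8, dtype=torch.int64)
        prompt_lengths = [3]  # pretend the prompt filled slots 0..2
        _update_padded_sequence(padded_sequence, new_token, prompt_lengths)
        prompt_lengths[0] += 1
        print(f"rank {pp_rank}: {padded_sequence.tolist()}, {prompt_lengths=}")


def _worker(rank: int, world_size: int) -> None:
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29501"
    dist.init_process_group("gloo", rank=rank, world_size=world_size)
    _decode_step(pp_rank=rank, pp_group_size=world_size)
    dist.destroy_process_group()


if __name__ == "__main__":
    world_size = 2  # two pipeline stages
    mp.spawn(_worker, args=(world_size,), nprocs=world_size)

One simplification to note: dist_run.py resolves the peer via dist.get_global_rank(pp_group, ...) because the pp group may be a sub-group of a 2D (pp x tp) device mesh, while this sketch addresses ranks directly in a two-process world.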