Merged (changes from all commits)
```diff
@@ -135,8 +135,8 @@ def _load_model_weights(stage_module, distribution, device, model_config):
 def _encode_strings(
     strings: List[str],
     tokenizer,
-    bos: bool = True,
-    device: torch.device = "cuda:0",
+    bos: bool,
+    device: torch.device,
     dtype=torch.int64,
 ) -> List[torch.Tensor]:
     """Encode a list of prompt strings into a list of tensor token ids."""
```
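With the defaults gone, callers must pass `bos` and a per-rank device explicitly instead of silently landing on `cuda:0`. A hypothetical call site (the `rank` and `prompt` names mirror the rest of this diff):

```python
# Hypothetical call site: each rank encodes on its own GPU rather than the
# removed "cuda:0" default.
device = torch.device(f"cuda:{rank % torch.cuda.device_count()}")
input_ids = _encode_strings(prompt, tokenizer, bos=True, device=device)
```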
```diff
@@ -216,13 +216,13 @@ def _batch_decode_next_tokens(

 def _update_padded_sequence(
     padded_sequence: torch.Tensor,
-    x_recv: torch.Tensor,
-    res,
+    new_token: torch.Tensor,
     prompt_lengths: List[int],
 ) -> None:
+    # TODO: this is a hacky way to update the padded sequence: when there is
+    # more than one prompt, the for loop and the assignment is incompatible.
     for i in range(len(prompt_lengths)):
-        prompt_lengths[i] += 1
-        padded_sequence[i, prompt_lengths[i] - 1] = x_recv
+        padded_sequence[i, prompt_lengths[i]] = new_token


 def _cleanup():
```
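The helper now only writes the freshly generated token at each prompt's current end; advancing `prompt_lengths` moves out to the decode loop. Note that later in this file `new_token` is built from `decode_results[0][0]`, i.e. only the first prompt's token, which is what the TODO above is flagging. A toy illustration (shapes and values are made up):

```python
import torch

# Two prompts padded into one batch of max length 8 (toy values).
padded_sequence = torch.zeros(2, 8, dtype=torch.int64)
prompt_lengths = [3, 5]              # number of valid tokens currently in each row
new_token = torch.tensor([42])       # token produced by the current decode step

# What the updated helper does: place the new token right after each prompt's
# existing tokens.
for i in range(len(prompt_lengths)):
    padded_sequence[i, prompt_lengths[i]] = new_token

# The caller (not the helper) then advances the lengths for the next step.
for i in range(len(prompt_lengths)):
    prompt_lengths[i] += 1
```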
```diff
@@ -267,19 +267,15 @@ def main(args):
     device_mesh = _create_device_mesh(mesh_dimensions)
     tp_mesh = device_mesh["tp"]
     pp_mesh = device_mesh["pp"]
+    logger.info(f"Created device mesh: {device_mesh}\n{tp_mesh=}, {pp_mesh=}")

     tp_rank = tp_mesh.get_local_rank()
     pp_rank = pp_mesh.get_local_rank()
     tp_group = tp_mesh.get_group()
     pp_group = pp_mesh.get_group()
-
-    logger.info(f"review: {pp_group=}, {tp_group= }")
-    logger.info(f"Created device mesh: {device_mesh}\n {tp_mesh=}, {pp_mesh=}\n")
-    # TODO - this assumes 1D mesh, need to update for 2D+ mesh
-    pp_group_size = pp_mesh.size()
-    tp_group_size = tp_mesh.size()
-
-    logger.info(f"pp_group_size: {pp_group_size}, tp_group_size: {tp_group_size}")
+    pp_group_size = pp_group.size()
+    tp_group_size = tp_group.size()
+    logger.info(f"{pp_group_size=}, {tp_group_size=}")

     # Assuming same number of GPUs per node
     device = torch.device(f"cuda:{rank % torch.cuda.device_count()}")
```
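For reference, a minimal sketch of how a 2-D pp/tp device mesh can be created and queried with `torch.distributed.device_mesh`; the repo's `_create_device_mesh` helper is not shown in this diff, so the construction and the degrees below are assumptions:

```python
from torch.distributed.device_mesh import init_device_mesh

# Assumption: pp_degree * tp_degree == world size, one GPU per rank (run under torchrun).
pp_degree, tp_degree = 2, 4
device_mesh = init_device_mesh(
    "cuda", (pp_degree, tp_degree), mesh_dim_names=("pp", "tp")
)

pp_mesh = device_mesh["pp"]
tp_mesh = device_mesh["tp"]

pp_rank = pp_mesh.get_local_rank()   # this rank's coordinate along the pp dimension
tp_rank = tp_mesh.get_local_rank()   # this rank's coordinate along the tp dimension
pp_group = pp_mesh.get_group()       # ProcessGroup spanning the pp dimension
tp_group = tp_mesh.get_group()

# The new code logs the per-dimension sizes from the groups rather than the mesh.
pp_group_size = pp_group.size()
tp_group_size = tp_group.size()
```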
```diff
@@ -316,7 +312,7 @@ def main(args):
     logger.info(f"Loading weights for {pp_rank=} on {device=}")

     with CUDATrackTime() as timer:
-        _load_model_weights(model, hf_model_name, device=device, model_config=config)
+        _load_model_weights(model, distribution, device=device, model_config=config)

     logger.info(
         f"{color.green}Total weight loading time: {timer.get_time()} {timer.unit} for stage {rank}{color.reset}"
```
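`CUDATrackTime` comes from this repo's utilities and is not part of the diff; as a rough sketch, a CUDA-event-based timer used the same way could look like the following (an assumption, not the repo's implementation):

```python
import torch

class CudaTimerSketch:
    """Hypothetical stand-in for CUDATrackTime: times GPU work between enter and exit."""

    unit = "ms"

    def __enter__(self):
        self._start = torch.cuda.Event(enable_timing=True)
        self._end = torch.cuda.Event(enable_timing=True)
        self._start.record()
        return self

    def __exit__(self, *exc):
        self._end.record()
        # Wait for queued kernels so the recorded interval is meaningful.
        torch.cuda.synchronize()
        self._elapsed_ms = self._start.elapsed_time(self._end)
        return False

    def get_time(self):
        return self._elapsed_ms
```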
```diff
@@ -327,7 +323,7 @@ def main(args):
     stage_size_formatted = bytes_to_readable(stage_size)
     stage_num_params = get_num_params(model)
     logger.info(
-        f"Stage {rank} has {color.blue}{stage_num_params} params{color.reset}, Size: {color.blue}{stage_size_formatted}{color.reset}\n"
+        f"Stage {rank} has {color.blue}{stage_num_params} params{color.reset}, Size: {color.blue}{stage_size_formatted}{color.reset}"
     )

     # Setup input position (input_pos) for prefill: a list of increasing integers from 0 to seqlen
```
```diff
@@ -342,15 +338,9 @@ def main(args):
         pp_degree,
         device,
         input_args=(example_args,),
-        group=pp_mesh.get_group(),
+        group=pp_group,
     )

-    # this check confirms that there are no cpu tensors in the model..we expect this to be true.
-    cpu_tensors = find_cpu_tensors(stage.submod)
-    # logger.info(f"Found {len(cpu_tensors)} cpu tensors: {cpu_tensors}")
-    if len(cpu_tensors) > 0:
-        raise ValueError("Found cpu tensors in stage")
-
     prompt = [
         "What is snow?",
     ]
```
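The lines above are the tail of the pipeline-stage construction; the leading arguments fall outside this hunk, so the sketch below fills them in with assumed names (`model`, `pp_rank`, `example_args`, and `mbs` all appear elsewhere in this file, but the exact call may differ by PyTorch version):

```python
from torch.distributed.pipelining import PipelineStage, ScheduleGPipe

# Sketch of the stage construction this hunk is editing (assumed leading args).
stage = PipelineStage(
    model,                        # this rank's slice of the model
    pp_rank,                      # index of this pipeline stage
    pp_degree,                    # total number of pipeline stages
    device,
    input_args=(example_args,),   # example input used to infer stage I/O shapes
    group=pp_group,               # process group along the pp mesh dimension
)

# One schedule per rank; mbs is the number of micro-batches per step.
schedule = ScheduleGPipe(stage, mbs)
```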
```diff
@@ -374,7 +364,6 @@ def main(args):
     ]
     """

-
     start_pos = 0

     # encode the prompt
```
```diff
@@ -388,88 +377,74 @@ def main(args):
         input_ids, tokenizer, seqlen, start_pos, device
     )
     logger.info(f"{prompt_lengths=}")
-    logger.info(f"first prompt {padded_sequence[0, :prompt_lengths[0]+1]=}")
-    if len(prompt_lengths) > 1:
-        logger.info(f"second prompt {padded_sequence[1, :prompt_lengths[1]+1]=}")

     # create schedule
     schedule = ScheduleGPipe(stage, mbs)
     logger.info(f"Created schedule: {schedule}")

     # with CUDATrackTime() as timer:
-    first_pp_group = 0
-    last_pp_group = pp_group_size - 1
-
-    x_recv = torch.zeros(1, device=device, dtype=torch.int64)
-    logger.info(f"{x_recv.shape=}")
+    first_pp_rank = 0
+    last_pp_rank = pp_group_size - 1

-    last_global_rank = world_size - 1
+    # New token generated each iteration
+    new_token = torch.zeros(1, device=device, dtype=torch.int64)
     res = []
-    dst = None
-    src = None
-
-    if pp_rank == last_pp_group:
-        dst = dist.get_global_rank(pp_group, 0)
-    elif pp_rank == 0:
-        src = dist.get_global_rank(pp_group, last_pp_group)

     # Decoding
     num_tokens = 40

+    # Decoding
     with torch.no_grad():
         for step in range(num_tokens):
-            # first
-            if pp_rank == 0:
-                schedule.step(padded_sequence)
-                # only receive if not last step
-                if step < num_tokens - 1:
-                    dist.recv(
-                        x_recv,
-                        src,
-                        group=pp_group,
-                    )
-                    _update_padded_sequence(
-                        padded_sequence, x_recv, res, prompt_lengths
-                    )
-
-            # last
-            elif pp_rank == last_pp_group:
+            # Run data through pipeline
+            if pp_rank == first_pp_rank:
+                output = schedule.step(padded_sequence)
+            elif pp_rank == last_pp_rank:
                 output = schedule.step()
-                # need to decode the output
+            else:  # middle pp ranks
+                schedule.step()
+
+            # Decode the output
+            if pp_rank == last_pp_rank:
                 decode_results = _batch_decode_next_tokens(
                     output=output, prompt_lengths=prompt_lengths, tokenizer=tokenizer
                 )
                 if tp_rank == 0:
                     logger.info(
-                        f"\n\n{color.green} {'Prefill' if step == 0 else '* Decode *'} responses ====>>>> {color.blue} {decode_results=} \n{color.reset}"
+                        f"{color.green} {'Prefill' if step == 0 else '* Decode *'} "
+                        f"responses ====>>>> {color.blue} {decode_results=}{color.reset}"
                     )

-                next_token = torch.tensor([decode_results[0][0]], device=device)
+                # decode results returns both token_id (int) and token_str (readable), hence [0] and [1]
+                new_token = torch.tensor([decode_results[0][0]], device=device)
                 res.append(decode_results[0][1])

-                # increment prompt lengths for next token
-                for i in range(len(prompt_lengths)):
-                    prompt_lengths[i] += 1
-                    # logger.info(
-                    #     f"output review {prompt_lengths[i]=}, {padded_sequence[i, prompt_lengths[i]-1]=}"
-                    # )
-
-                # only send if not last step
-                if step < (num_tokens - 1):
-                    dist.send(
-                        next_token,
-                        dst,
-                        pp_group,
-                    )
+            # sendrecv between last and first ranks, only if:
+            # first_pp_rank != last_pp_rank.
+            if pp_rank == last_pp_rank and pp_rank != first_pp_rank:
+                dist.send(
+                    new_token,
+                    dst=dist.get_global_rank(pp_group, first_pp_rank),
+                    group=pp_group,
+                )
+            elif pp_rank == first_pp_rank and pp_rank != last_pp_rank:
+                dist.recv(
+                    new_token,
+                    src=dist.get_global_rank(pp_group, last_pp_rank),
```

An inline review comment was left on the `src=dist.get_global_rank(...)` line above: same as above with `dst`, why do we call this API over and over to get `src` within the loop instead of once outside the loop?

```diff
+                    group=pp_group,
+                )

-            # middle pp ranks
-            else:
-                schedule.step()
+            # Update input sequence with new token
+            if pp_rank == first_pp_rank:
+                _update_padded_sequence(
+                    padded_sequence, new_token, prompt_lengths
+                )
+
+            # increment prompt lengths for next token
+            for i in range(len(prompt_lengths)):
+                prompt_lengths[i] += 1

     # output formatted response via last pp group and tp rank 0
-    if pp_rank == last_pp_group and tp_rank == 0:
-        logger.info(f"\nPrompt:{color.green} {prompt[0]} {color.reset}")
-        formatted_response = "".join(res)
-        logger.info(f"$$$$$$ {color.blue}{formatted_response}\n{color.reset} $$$$$")
+    if pp_rank == last_pp_rank and tp_rank == 0:
+        logger.info(f"Prompt:{color.green} {prompt[0]} {color.reset}")
+        formatted_response = " ".join(res)
+        logger.info(f"$$$$$$ {color.blue}{formatted_response} {color.reset} $$$$$")

     logger.info(
         f"{color.green}Success{color.white} - {color.blue}Rank {rank} has completed.{color.reset}"
```
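`_batch_decode_next_tokens` is only called here; its body is not in this diff. A hypothetical sketch of the contract it appears to satisfy (greedy pick at each prompt's last position, returning `(token_id, token_str)` pairs, which is why the loop reads `decode_results[0][0]` and `decode_results[0][1]`):

```python
from typing import List, Tuple

import torch

def _batch_decode_next_tokens_sketch(
    output: torch.Tensor,           # assumed logits of shape (batch, seqlen, vocab)
    prompt_lengths: List[int],
    tokenizer,
) -> List[Tuple[int, str]]:
    """Illustration only; the repo's implementation may differ."""
    results = []
    for i, length in enumerate(prompt_lengths):
        # Greedy next token from the logits at the last real position of prompt i.
        token_id = int(torch.argmax(output[i, length - 1, :], dim=-1))
        results.append((token_id, tokenizer.decode([token_id])))
    return results
```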
Review comment:
dst and src shouldn't be getting recreated every iteration, since they don't change from one iteration to the next. This is why I had moved them out of the loop previously. Not sure how expensive dist.get_global_rank is, but there's no need, IMO, to be calling it over and over here.
Reply:
Makes sense. Will upload a PR to improve it. Thanks!
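A minimal sketch of the suggested cleanup, computing the peer global ranks once before the decode loop; the variable names reuse those from the diff above, and `first_pp_global_rank` / `last_pp_global_rank` are illustrative names:

```python
import torch
import torch.distributed as dist

# Resolve the global ranks of the first and last pipeline stages once,
# instead of calling dist.get_global_rank on every decode step.
first_pp_global_rank = dist.get_global_rank(pp_group, first_pp_rank)
last_pp_global_rank = dist.get_global_rank(pp_group, last_pp_rank)

with torch.no_grad():
    for step in range(num_tokens):
        ...  # schedule.step / decode as in the diff above
        if pp_rank == last_pp_rank and pp_rank != first_pp_rank:
            dist.send(new_token, dst=first_pp_global_rank, group=pp_group)
        elif pp_rank == first_pp_rank and pp_rank != last_pp_rank:
            dist.recv(new_token, src=last_pp_global_rank, group=pp_group)
```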