[Distributed] Add lanes to KV cache #1174
Merged

Changes from all commits (6 commits)
```diff
@@ -273,13 +273,11 @@ def main(args):
     pp_rank = pp_mesh.get_local_rank()
     tp_group = tp_mesh.get_group()
     pp_group = pp_mesh.get_group()
-    pp_group_size = pp_group.size()
-    tp_group_size = tp_group.size()
-    logger.info(f"{pp_group_size=}, {tp_group_size=}")
+    logger.info(f"{pp_degree=}, {tp_degree=}")

     # Convenience variables
     first_pp_rank = 0
-    last_pp_rank = pp_group_size - 1
+    last_pp_rank = pp_degree - 1

     # Assuming same number of GPUs per node
     device = torch.device(f"cuda:{rank % torch.cuda.device_count()}")
```
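For orientation, `pp_degree` and `tp_degree` are the pipeline- and tensor-parallel sizes chosen earlier in the script, which is why the redundant `*_group_size` variables can go. A hypothetical sketch (assumed degrees and mesh construction, not this file's exact code) of where the pp/tp meshes, ranks, and groups come from:

```python
# Hypothetical sketch, not this file's exact code: build the 2-D device mesh
# that the pp/tp ranks and groups above are read from.
from torch.distributed.device_mesh import init_device_mesh

pp_degree, tp_degree = 2, 4  # assumed degrees; the script derives them from args
mesh = init_device_mesh("cuda", (pp_degree, tp_degree), mesh_dim_names=("pp", "tp"))
pp_mesh, tp_mesh = mesh["pp"], mesh["tp"]

pp_rank = pp_mesh.get_local_rank()  # this process's position along the pipeline dim
tp_group = tp_mesh.get_group()      # process group used for tensor parallelism
pp_group = pp_mesh.get_group()      # process group used for pipeline parallelism
```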
```diff
@@ -297,18 +295,22 @@ def main(args):
     if rank == 0:
         logger.info(f"Model: {model}")

-    mbs = 1 # number of micro-batches
-    mb_size = 4 # micro-batch size
-    batch_size = mbs * mb_size # total batch size
+    # Batch size. Since we push batches dynamically through the pipeline rather
+    # than chunking them, this is effectively micro-batch size in pipeline
+    # sense. Thus it is interchangeable with micro-batch size below.
+    batch_size = 4
     seqlen_prefill = 1024 # sequence length
     dim = 4096 # embedding dimension

     # Setup KV caches (after model distribution)
-    # TODO: the setting below only works for 1 micro-batch case. To support
-    # multiple micro-batches, we need the KV cache in the model to be aware of
-    # the number of micro-batches and the current micro-batch index.
-    model.setup_caches(mb_size, seqlen_prefill)
+    # The number of cache lanes is the same as the maximum number of
+    # micro-batches that can be "in flight" in parallel -- imagine each
+    # micro-batch takes 1 "pipeline lane," they need distinct KV cache spaces.
+    # When decoding is done for certain micro-batches, we can reuse the KV cache
+    # lanes.
+    # TODO: bump up the lane count
+    pipeline_lanes = 1
+    model.setup_caches(batch_size, seqlen_prefill, cache_lanes=pipeline_lanes)

     # Load weights
     logger.info(f"Loading weights for {pp_rank=} on {device=}")
```
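The new `cache_lanes` argument is the heart of this PR. Purely as an illustration (a hypothetical `LanedKVCache`, not the PR's actual implementation), one way to get the behavior described in the comment above is to give the cache buffers a leading lane dimension and index it with `cache_lane`:

```python
# Hypothetical sketch only (not the PR's actual implementation): a leading lane
# dimension selected by cache_lane, so concurrently in-flight micro-batches
# never overwrite each other's K/V entries.
import torch

class LanedKVCache(torch.nn.Module):
    def __init__(self, lanes, batch_size, n_heads, max_seq_len, head_dim,
                 dtype=torch.bfloat16):
        super().__init__()
        shape = (lanes, batch_size, n_heads, max_seq_len, head_dim)
        self.register_buffer("k_cache", torch.zeros(shape, dtype=dtype))
        self.register_buffer("v_cache", torch.zeros(shape, dtype=dtype))

    def update(self, input_pos, k_val, v_val, cache_lane=0):
        # Write this step's keys/values into the chosen lane at input_pos and
        # return the full (batch, heads, max_seq_len, head_dim) views for attention.
        k_out, v_out = self.k_cache[cache_lane], self.v_cache[cache_lane]
        k_out[:, :, input_pos] = k_val
        v_out[:, :, input_pos] = v_val
        return k_out, v_out
```

With `pipeline_lanes = 1` the lane dimension is effectively a no-op; raising the lane count (the TODO above) is what would let several in-flight micro-batches keep disjoint cache space.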
```diff
@@ -317,7 +319,7 @@ def main(args):
     model.to(device)

     logger.info(
-        f"{color.green}Total weight loading time: {timer.get_time()} {timer.unit} for stage {rank}{color.reset}"
+        f"{color.green}Total weight loading time: {timer.get_time()} {timer.unit} for rank {rank}{color.reset}"
     )

     # info on stage size and params
```
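`CUDATrackTime` here is the repository's own timing helper; as an assumption-labeled illustration of the interface the log lines rely on (`get_time()` and `unit`), a CUDA-event-based timer could look like:

```python
# Illustrative only -- CUDATrackTime is the repository's own helper; this sketch
# just shows how a CUDA-event timer exposing get_time() and unit can be built.
import torch

class CudaEventTimer:
    unit = "ms"

    def __enter__(self):
        self._start = torch.cuda.Event(enable_timing=True)
        self._end = torch.cuda.Event(enable_timing=True)
        self._start.record()
        return self

    def __exit__(self, *exc):
        self._end.record()
        torch.cuda.synchronize()  # wait for both events before reading the delta
        self._elapsed_ms = self._start.elapsed_time(self._end)
        return False

    def get_time(self):
        return round(self._elapsed_ms, 2)
```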
```diff
@@ -330,17 +332,16 @@ def main(args):

     # Setup input position (input_pos) for prefill: a list of increasing integers from 0 to seqlen
     input_pos = torch.arange(seqlen_prefill, device=device)
-    model.setup_input_pos(input_pos)
     model.eval()

     # Helper function to get example inputs and outputs for the stages.
     def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
-        mb_ids = torch.randint(0, config.vocab_size, (mb_size, seqlen), device=device)
+        mb_ids = torch.randint(0, config.vocab_size, (batch_size, seqlen), device=device)
         activation = torch.rand(
-            mb_size, seqlen, dim, device=device, dtype=model_dtype
+            batch_size, seqlen, dim, device=device, dtype=model_dtype
         )
         logits = torch.rand(
-            mb_size, seqlen, config.vocab_size, device=device, dtype=model_dtype
+            batch_size, seqlen, config.vocab_size, device=device, dtype=model_dtype
         )
         example_inputs = (mb_ids if pp_rank == first_pp_rank else activation,)
         example_outputs = (logits if pp_rank == last_pp_rank else activation,)
```
```diff
@@ -358,8 +359,13 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
         output_args=example_outputs,
         group=pp_group,
     )
-    # create schedule
-    prefill_schedule = ScheduleGPipe(prefill_stage, mbs)
+
+    # Create schedule
+    # Number of micro-batches for the schedule is 1, because each step() call we
+    # only push 1 micro-batch into the pipeline. But we can continuously push
+    # new micro-batches into the pipeline as they arrive, achieving same
+    # pipelining effect.
+    prefiller = ScheduleGPipe(prefill_stage, 1)

     prompt = [
         "What is a computer?",
```
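To make the "continuously push new micro-batches" idea concrete, here is a hypothetical sketch (not part of this PR) where each in-flight micro-batch reuses the same 1-micro-batch schedule but gets its own cache lane; `padded_sequence_0`/`padded_sequence_1` are assumed prompt tensors and the caches are assumed to have been set up with `cache_lanes=2`:

```python
# Hypothetical sketch (not part of this PR): pushing several micro-batches
# through the same 1-micro-batch schedule, each on its own KV cache lane.
micro_batches = [padded_sequence_0, padded_sequence_1]  # assumed inputs

for lane, mb in enumerate(micro_batches):
    kwargs = {"input_pos": input_pos, "cache_lane": lane}
    if pp_rank == first_pp_rank:
        output = prefiller.step(mb, **kwargs)
    elif pp_rank == last_pp_rank:
        output = prefiller.step(**kwargs)
    else:  # middle pp ranks
        prefiller.step(**kwargs)
```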
```diff
@@ -388,7 +394,6 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
     s = set(prompt_lengths)
     assert len(s) == 1, f"prompt_lengths should be the same, got {s}"

-    # with CUDATrackTime() as timer:
     # Need these global ids due to the API definition of dist.send and recv
     first_pp_rank_global_id = dist.get_global_rank(pp_group, first_pp_rank)
     last_pp_rank_global_id = dist.get_global_rank(pp_group, last_pp_rank)
```
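`dist.send`/`dist.recv` address peers by global rank, which is why the group-local ranks are translated above. A simplified sketch of the last-to-first token hand-off that the decode loop below performs (tensor shape and dtype are assumptions):

```python
# Simplified sketch of the last-to-first token hand-off during decoding
# (shape/dtype are assumptions; dist.send/recv address peers by global rank).
import torch
import torch.distributed as dist

if pp_rank == last_pp_rank and pp_rank != first_pp_rank:
    # The last stage just produced new_token; ship it to the first stage.
    dist.send(new_token, dst=first_pp_rank_global_id, group=pp_group)
elif pp_rank == first_pp_rank and pp_rank != last_pp_rank:
    # The first stage receives the token it must embed on the next step.
    new_token = torch.zeros(batch_size, 1, dtype=torch.int64, device=device)
    dist.recv(new_token, src=last_pp_rank_global_id, group=pp_group)
```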
```diff
@@ -401,14 +406,21 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
     num_tokens = 40

     # Prefill phase
-    # Run context input through pipeline, in 1 step
-    with torch.no_grad():
+    # Run context input through pipeline
+    # TODO: we need to pass `input_pos` and `cache_lane` to each stage.
+    lane = 0
+    kwargs = {"input_pos": input_pos, "cache_lane": lane}
+    with torch.no_grad(), CUDATrackTime() as timer:
         if pp_rank == first_pp_rank:
-            output = prefill_schedule.step(padded_sequence)
+            output = prefiller.step(padded_sequence, **kwargs)
         elif pp_rank == last_pp_rank:
-            output = prefill_schedule.step()
+            output = prefiller.step(**kwargs)
         else: # middle pp ranks
-            prefill_schedule.step()
+            prefiller.step(**kwargs)
+
+    logger.info(
+        f"{color.green}Prefilling time: {timer.get_time()} {timer.unit} for rank {rank}{color.reset}"
+    )

     # Decode the output -- first generated token
     if pp_rank == last_pp_rank:
```
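On the last rank, `output` holds the prefill logits. A minimal greedy-decoding illustration of the "Decode the output" step (the actual script may sample differently; variable names are assumptions):

```python
# Illustrative greedy pick of the first generated token on the last rank (the
# actual script may sample instead; variable names are assumptions).
if pp_rank == last_pp_rank:
    # output: (batch_size, seqlen_prefill, vocab_size) prefill logits. Take the
    # logits at the last prompt position and keep the most likely token id.
    last_pos_logits = output[:, prompt_lengths[0] - 1, :]
    new_token = torch.argmax(last_pos_logits, dim=-1, keepdim=True)  # (batch_size, 1)
```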
```diff
@@ -430,7 +442,6 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
     # seqlen = 1 now
     seqlen_decode = 1
     input_pos = torch.tensor([prompt_lengths[0]], device=device)
-    model.setup_input_pos(input_pos)

     # Create decode stage
     logger.info(f"Creating pipeline stage for decode {pp_rank=}, {pp_degree=}")
```
```diff
@@ -445,11 +456,12 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
         group=pp_group,
     )
     # create schedule
-    decode_schedule = ScheduleGPipe(decode_stage, mbs)
+    decorder = ScheduleGPipe(decode_stage, 1)

     # Decoding
-    with torch.no_grad():
+    with torch.no_grad(), CUDATrackTime() as timer:
         for step in range(num_tokens - 1):
+            kwargs = {"input_pos": input_pos, "cache_lane": lane}
             # sendrecv between last and first ranks, only if:
             # first_pp_rank != last_pp_rank.
             if pp_rank == last_pp_rank and pp_rank != first_pp_rank:
```
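The `kwargs` built here are forwarded by the schedule into the stage module's `forward`. Continuing the hypothetical `LanedKVCache` sketch above, a toy attention block shows how `input_pos` and `cache_lane` could be routed down to the cache; the signature, names, and shapes are assumptions, not the PR's code, and it attends over the whole cache without a causal mask for brevity:

```python
# Hypothetical continuation of the LanedKVCache sketch above: a toy attention
# block routing the step() kwargs (input_pos, cache_lane) down to the cache.
import torch
import torch.nn.functional as F

class ToyAttention(torch.nn.Module):
    def __init__(self, dim=64, n_heads=4, max_seq_len=128, lanes=1, batch_size=1):
        super().__init__()
        self.n_heads, self.head_dim = n_heads, dim // n_heads
        self.qkv = torch.nn.Linear(dim, 3 * dim, bias=False)
        self.proj = torch.nn.Linear(dim, dim, bias=False)
        self.kv_cache = LanedKVCache(lanes, batch_size, n_heads, max_seq_len,
                                     self.head_dim, dtype=torch.float32)

    def forward(self, x, input_pos, cache_lane=0):
        b, s, d = x.shape
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        # (b, s, d) -> (b, n_heads, s, head_dim)
        q, k, v = (t.view(b, s, self.n_heads, self.head_dim).transpose(1, 2)
                   for t in (q, k, v))
        # Write K/V into the selected lane, then attend over that lane's cache.
        k, v = self.kv_cache.update(input_pos, k, v, cache_lane=cache_lane)
        y = F.scaled_dot_product_attention(q, k, v)
        return self.proj(y.transpose(1, 2).reshape(b, s, d))
```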
```diff
@@ -467,11 +479,11 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:

             # Run data through pipeline
             if pp_rank == first_pp_rank:
-                output = decode_schedule.step(new_token)
+                output = decorder.step(new_token, **kwargs)
```

Review comment: same, syntax error - this should be 'decoder' and not 'decorder'.

```diff
             elif pp_rank == last_pp_rank:
-                output = decode_schedule.step()
+                output = decorder.step(**kwargs)
```

Review comment: same, syntax error - this should be 'decoder' and not 'decorder'.

```diff
             else: # middle pp ranks
-                decode_schedule.step()
+                decorder.step(**kwargs)

             # Decode the output
             if pp_rank == last_pp_rank:
```

Review comment: last one, syntax error - this should be 'decoder' and not 'decorder'.
```diff
@@ -491,7 +503,10 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
                 )  # decode_results[i][0]

             input_pos += 1
-            model.setup_input_pos(input_pos)
+
+    logger.info(
+        f"{color.green}Decoding time: {timer.get_time()} {timer.unit} for rank {rank}{color.reset}"
+    )

     # Display the decoding results
```
Review comment (on the `decorder` naming above): syntax error - this should be 'decoder' and not 'decorder'.