perf(pipeline): implement auto-partition algorithm #2113
TXacs wants to merge 11 commits into pytorch:main
Conversation
1. Improve pipeline performance
2. Auto-partition modules
Hi @TXacs! Thank you for your pull request and welcome to our community.

Action Required: In order to merge any pull request (code, docs, etc.), we require contributors to sign our Contributor License Agreement, and we don't seem to have one on file for you.

Process: In order for us to review and merge your suggested changes, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (e.g. your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA. Once the CLA is signed, our tooling will perform checks and validations; afterwards, the pull request will be tagged.

If you have received this in error or have any questions, please contact us at cla@meta.com. Thanks!
tianyu-l
left a comment
Thanks. Is it true that the only "real" deltas are
- autopipe.cpp
- pipeline_parallel.py
- profiler.py
This looks interesting -- how much benefit would you get from a C++ implementation, compared with a Python one?
Yes, actually,
Thank you for signing our Contributor License Agreement. We can now accept your code for this (and any) Meta Open Source project. Thanks!
```python
parts = pipeline(
    mflops_list,
    [i * 3 for i in mflops_list],  # Assume backward is 3x forward
```
Similar to computation time, where backward takes roughly twice as long as forward, we assume by default that the floating-point operations (FLOPs) of the backward pass are twice those of the forward pass. The modeling also accounts for recomputation, which inserts an additional forward pass before the backward pass. Consequently, the default backward FLOPs are set to three times the forward FLOPs.
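A minimal sketch of this heuristic (the helper name is hypothetical, not code from the PR): backward costs roughly 2x forward FLOPs, plus 1x more when recomputation replays the forward pass before the backward pass.

```python
# Illustrative sketch, not the PR's implementation.
def backward_mflops(forward_mflops, recompute=True):
    """Backward is ~2x forward; activation recomputation replays the
    forward pass before backward, adding another 1x (hence 3x total)."""
    factor = 3 if recompute else 2
    return [f * factor for f in forward_mflops]
```

For example, `backward_mflops([10, 20])` gives `[30, 60]`, while disabling recomputation gives the conventional 2x multiplier.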
```python
# Profile each layer's FLOPS
mflops_list = []
for _, layer in enumerate(model):
    prof = FlopsProfiler(layer)
```
I guess the FlopsProfiler does not estimate the backward flops?
You are correct that the FlopsProfiler does not estimate backward flops. Instead, we use a heuristic rule by default: given the adoption of recomputation techniques in the backward computation, backward flops are set to three times those of the forward computation.
curious why three times? IIUC the convention was two times.
During the modeling process, we employ the recomputation technique, which introduces an additional forward pass during backpropagation. To control the computational overhead of backpropagation, we can adjust the implementation by setting a scaling factor—typically 2 or 3—depending on whether recomputation is enabled. This adjustment is orthogonal to the pipeline partitioning algorithm.
> I guess the FlopsProfiler does not estimate the backward flops?
We have completely rebuilt the FLOPs acquisition logic: backward FLOPs for each layer can now be computed according to the ACConfig strategy, and the AutoPipe algorithm has been re-implemented in pure Python.
Following the previous review, we have further refined this PR with the following key updates:
- Refactored the core algorithm from C++ to pure Python, enhancing maintainability and better aligning with the PyTorch ecosystem.
- Integrated PyTorch's official torch.utils.flop_counter.FlopCounterMode, replacing the third-party profiler to improve integration and reduce external dependencies.
- Refined the FLOPs calculation logic to dynamically determine backward pass costs based on the activation checkpointing configuration, moving beyond a fixed multiplier for greater accuracy.
We welcome any additional suggestions to move this forward. Thanks again for your review.
…figuration Based on TorchTitan's configuration, which applies recomputation only to transformer layers, the algorithm has been modified. With ac_config.selective_ac_option = 'nlayers', performance improves by 3% to 17% compared with the previous algorithm, in which all layers defaulted to 3× forward FLOPs.
1. Re-implemented AutoPipe logic in Python and removed the legacy C++ codebase 2. Replaced the DeepSpeed profiler with torch.utils.flop_counter.FlopCounterMode 3. Added layer-wise backward-FLOPs computation based on ACConfig's recomputation policy
```
### Compile

First, we need to compile `autopipe.cpp`.
```
Is this comment outdated? I do not see any .cpp files in the PR anymore.
Yes, this comment is outdated. The PR has migrated AutoPipe logic from C++ to pure Python, so there are no .cpp files involved anymore. We should remove or update this line to reflect the current implementation. Sorry for the confusion caused by the outdated documentation!
```python
)
seq_modules = _build_module_for_profile(copied_model, flatten_module_names)

parts = autopipe_partition(seq_modules, parallel_dims.pp, job_config)
```
Why do we do the autopipe_partition after the splitting is already done? I propose there should be 3 different ways of doing things:
- layer-wise static partitioning (default)
- config based (user sets layers per stage)
- autopipe partition (this new implementation)

The option to choose should be set by configuration. Further, if the autopipe way is determined to be strictly better, then it should replace the default.
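One hedged sketch of how such a configuration switch might look (`choose_partition` and its mode names are placeholders, not torchtitan APIs; the autopipe branch is a stub standing in for this PR's algorithm):

```python
def choose_partition(mode, num_layers, num_stages, layers_per_stage=None):
    """Return the number of layers per stage under the chosen strategy.
    Illustrative only; names and modes are hypothetical."""
    if mode == "static":
        # layer-wise static partitioning (default): near-equal layer counts
        base, extra = divmod(num_layers, num_stages)
        return [base + (1 if i < extra else 0) for i in range(num_stages)]
    if mode == "config":
        # user sets layers per stage explicitly
        assert layers_per_stage and sum(layers_per_stage) == num_layers
        return list(layers_per_stage)
    if mode == "autopipe":
        # would delegate to the cost-based autopipe algorithm in this PR
        raise NotImplementedError("autopipe partitioning goes here")
    raise ValueError(f"unknown partition mode: {mode}")
```

For example, `choose_partition("static", 10, 4)` yields `[3, 3, 2, 2]`.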
```python
logger.info(f"Autopipe partitioning with mflops: {mflops_fwd}, {mflops_bwd}")

# Partition layers by forward and backward flops
parts = pipeline(
```
nit: this pipeline name conflicts with naming provided in PyTorch under torch.distributed.pipelining, so it should maybe have a different name
```python
tokenizer = build_hf_tokenizer(job_config)

# build dataloader
dataloader = build_text_dataloader(
```
it's not ideal that we have to create a new dataloader for each rank. Can we not use dummy data as the example inputs instead?
Thanks for reviewing the code. We considered that different models have different data formats, which is why using the original data loader can help us avoid many issues.
why data format matters for FLOP counting? dummy data should be sufficient as long as it has the correct shape and dtype?
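Supporting the point above, a minimal sketch: for FLOP counting only the shape and dtype of the inputs matter, so random token ids can stand in for a real batch (the `batch`/`seq_len`/`vocab_size` values here are illustrative, not torchtitan defaults).

```python
import torch

def make_dummy_batch(batch=2, seq_len=128, vocab_size=32000):
    """Random token ids with the right shape/dtype; the values themselves
    are irrelevant for FLOP counting. Sizes are illustrative."""
    return torch.randint(0, vocab_size, (batch, seq_len), dtype=torch.long)
```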
```python
for layer_idx, layer in enumerate(model):
    # forward
    with FlopCounterMode(display=False) as mode:
```
Hm, this may be an inherent limitation. IIUC FlopCounterMode only counts flops for registered pytorch ops. So it would not capture things like FlexAttention which uses fused kernels, it may also not cover torch.compile related optimizations. Do you happen to know @sanketpurandare? How did you do things for your runtime estimators?
Thanks for the insightful comment! You're absolutely right. FlopCounterMode has limitations with fused kernels and torch.compile related optimizations. However, our approach is modular and extensible, so it can integrate more accurate cost models as they emerge. We treat FLOPs as a baseline, not the final answer. We're open to incorporating better estimators. Would love to hear your thoughts on how you bridge this gap!
H-Huang
left a comment
I think the PR should be split up into two parts:
- move autopipe.py to pytorch/pytorch under torch/distributed/pipelining/_utils.py
- Update pipeline_parallel.py to import this function from pytorch and perform the split (based on configuration)

The rest of the changes in this PR can be removed.
```python
return h


# C++ VectorEqual
def _eq(a, b):
```
Thank you very much for your careful review! This function is currently unused, and its presence was an oversight on our part. We sincerely apologize for the confusion and appreciate you pointing it out. We’ll remove the unused function in the next update.
```python
# Hash func: C++ VectorHash
@functools.cache
def _hash(p):
```
Python has its own built-in hash in the standard library that you should use
Thank you for the feedback! We will update the code accordingly to use the official implementation.
```python
import functools
from collections import deque


COMM_OVERHEAD = 0
```
We set the communication overhead to 0 because the asynchronous communication is fully overlapped with computation. This effectively hides the communication latency, making the observed overhead negligible in our experimental setup.
```python
):
    """C++ calculate_stage_times"""
    num_stages = len(partition)
    num_micro = num_stages * 2
```
num_stages and num_micro should both be hyperparameters; this assumes num_micro is always 2 * stages?
The num_stages used in this context is a local variable, while the global num_stages is indeed already treated as a hyperparameter. As for num_micro, we currently default it to 2 * stages because this configuration ensures the pipeline naturally forms three distinct phases: warmup, steady, and cooldown. We agree that num_micro should be configurable as a hyperparameter to support more flexible pipeline configurations. We’ll update the code accordingly.
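As a sketch of the evaluation idea (not the PR's actual calculate_stage_times): with per-stage forward and backward costs, the slowest stage bounds steady-state 1F1B throughput, which is what a balanced partition minimizes. Function names and the cost values are illustrative; COMM_OVERHEAD mirrors the constant shown above.

```python
COMM_OVERHEAD = 0  # communication assumed fully overlapped, per the reply above

def stage_times(partition, fwd_costs, bwd_costs):
    """partition: list of per-stage layer-index lists.
    Returns each stage's total forward+backward cost."""
    times = []
    for layers in partition:
        fwd = sum(fwd_costs[i] for i in layers)
        bwd = sum(bwd_costs[i] for i in layers)
        times.append(fwd + bwd + COMM_OVERHEAD)
    return times

def bottleneck(partition, fwd_costs, bwd_costs):
    """The slowest stage dominates steady-state pipeline time."""
    return max(stage_times(partition, fwd_costs, bwd_costs))
```

With toy costs `fwd = [1, 2, 3, 4]` and `bwd = [3, 6, 9, 12]`, splitting as `[[0, 1], [2, 3]]` gives a bottleneck of 28, while the less obvious `[[0, 1, 2], [3]]` gives 24, which is exactly the kind of imbalance a cost-based search exploits.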
```python
return flush, critical


# ---------- 最优分区搜索 (optimal partition search) ----------
```
can the chinese comments be translated to english for consistency?
There was a problem hiding this comment.
We’ll translate all Chinese comments into English to maintain consistency across the codebase. Thanks for pointing this out!
Thank you for the helpful suggestion! Could you kindly clarify the expected steps for moving autopipe.py into the PyTorch repository?
…h function with Python's built-in hash(); 3. Translate Chinese to English
You will need to clean up the file to be more generalizable.
I don't think these all need to be solved at once, but if you can get a proof of concept upstreamed to PyTorch that is configurable (remove some of these hardcoded assumptions), then torchtitan can just perform the plumbing to import this and use it. So concretely, you can upstream the autopipe algorithm to
```python
    nn.Sequential: A sequential module containing the specified modules
    """
    module_seq = nn.Sequential()
    base_model = copy.deepcopy(model)
```
this seems redundant to L149?
```python
def _build_stage_from_modules(
    stage_idx: int, module_names: list[str], num_stages: int
) -> tuple[PipelineStage, nn.Module]:
    model = copy.deepcopy(whole_model)

    # Create a set of modules to keep for faster lookup
    modules_to_keep = set(module_names)
    for module_name, module_value in model.named_children():
        # Handle layer-like structures (e.g., "layers.0", "layers.1")
        if isinstance(module_value, (nn.ModuleDict, nn.ModuleList)):
            layers_to_keep = {
                name.split(".", 1)[1]
                for name in modules_to_keep
                if name.startswith(f"{module_name}.")
            }
            if layers_to_keep:
                # Keep only specified layers
                if isinstance(module_value, nn.ModuleDict):
                    for layer_name in list(module_value.keys()):
                        if layer_name not in layers_to_keep:
                            del module_value[layer_name]
                elif isinstance(module_value, nn.ModuleList):
                    indices_to_keep = {
                        int(idx) for idx in layers_to_keep if idx.isdigit()
                    }
                    new_layers = nn.ModuleList(
                        [
                            layer
                            for i, layer in enumerate(module_value)
                            if i in indices_to_keep
                        ]
                    )
                    setattr(model, module_name, new_layers)
            else:
                # No layers from this structure needed, set to empty structure
                if isinstance(module_value, nn.ModuleDict):
                    setattr(model, module_name, nn.ModuleDict())
                elif isinstance(module_value, nn.ModuleList):
                    setattr(model, module_name, nn.ModuleList())
        # Handle simple module attributes (e.g., "linear", "norm")
        elif module_name not in modules_to_keep:
            # Replace with None
            setattr(model, module_name, None)
```
can we reuse logic in _build_module_for_profile? or extract common logic as helper function with detailed comments of what it's doing?
```python
# This is used in the train loop to determine whether to pass in the input_ids and labels
has_first_stage = False
has_last_stage = False
for stage in stages:
    if stage.is_first:
        has_first_stage = True
    if stage.is_last:
        has_last_stage = True
```
```python
has_first_stage = any(stage.is_first for stage in stages)
has_last_stage = any(stage.is_last for stage in stages)
```
```python
)

if pp_schedule_csv:
    assert schedule_class in [
```
why do we assert this given L352?
```python
extra_layers = num_effective_layers % num_stages

# Feasibility check: Ensure at least 1 layer in each PP stage
if layers_per_stage == 0:
```
this check becomes unnecessary as we already checked in L434.
```python
)

# Balance check: Ensure weights don't exceed minimum layers per stage
if input_weight > layers_per_stage:
```
also raise an error when input_weight < 0 or output_weight < 0
```python
if num_stages == 1:
    # Single stage gets everything
    layer_names = [f"layers.{i}" for i in range(num_layers)]
    return [["tok_embeddings"] + layer_names + ["norm", "output"]]
```
the naming tok_embeddings, norm, output seems tied to specific model, can we generalize?
Thank you very much for the detailed guidance! We'll proceed by refactoring autopipe.py to make it more general-purpose and aligned with the design of torch.distributed.pipelining. We'll also add a new module (tentatively named _partition.py) under
```python
for module_name in flatten_module_names:
    # Create a copy of the base model for each module
    model_copy = copy.deepcopy(base_model)
```
I saw quite a few deepcopys inside for loops that copy the entire model and then prune it to only the target layers; will this cause OOM for larger models? Can we think about other ways to do this?
> I saw quite a few deepcopys inside for loops that copy the entire model and then prune it to only the target layers; will this cause OOM for larger models? Can we think about other ways to do this?
This is a good question. We need to obtain the complete forward and backward inputs and outputs at the model layer level. However, extracting a single layer from the model's attributes cannot provide its full computational graph. Therefore, we adopt TorchTitan's pipeline_module_split approach: by capturing each layer's complete computational graph before initializing the model parameters, we can accurately compute FLOPs for each layer. Within this function, parameters remain uninitialized, preventing the OOM issues that would otherwise occur with large models.
Auto-Partition in torchtitan
Overview
This PR provides an automatic partitioning method that considers the computation cost of embedding layers. The method calculates the floating-point operations (FLOPs) of the embedding layers and constructs an array that incorporates the FLOPs of both the transformer and embedding layers; a heuristic algorithm then identifies a balanced pipeline partition.
Solution Architecture
Performance
Hardware configuration: 4x RTX 3090 24GB, pipeline parallelism dimension is 4.
llama3 configuration comparison
deepseekv3 (without MoE) configuration comparison