
perf(pipeline): implement auto-partition algorithm#2113

Open
TXacs wants to merge 11 commits into pytorch:main from McmillanTAC:autopartition

Conversation

@TXacs

@TXacs TXacs commented Dec 5, 2025

Auto-Partition in torchtitan

Overview

This PR adds an automatic partitioning method that accounts for the computation cost of embedding layers.
The method calculates the floating-point operations (FLOPs) of the embedding layers and builds an array combining the FLOPs of both the transformer layers and the embedding layers. A heuristic algorithm then searches this array for a balanced pipeline partition.

Solution Architecture

  1. Dynamic Cost Analysis
  2. Adaptive Partitioning Algorithm
  3. Workload Balancing
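As a rough illustration of the workload-balancing step, a contiguous FLOPs-balanced split can be sketched as below. This is a simplified greedy stand-in, not the actual heuristic in this PR; the function name and interface are illustrative.

```python
def balanced_partition(flops: list[float], num_stages: int) -> list[list[int]]:
    """Greedily split layers into contiguous stages with near-equal FLOPs.

    Simplified sketch (assumes len(flops) >= num_stages); the real AutoPipe
    algorithm performs a more elaborate search over partitions.
    """
    target = sum(flops) / num_stages
    stages: list[list[int]] = []
    current: list[int] = []
    acc = 0.0
    for i, cost in enumerate(flops):
        if current:
            # Close the stage when it would overshoot the per-stage target,
            # or when the remaining layers are just enough to give each
            # remaining stage at least one layer.
            over = acc + cost > target and len(stages) < num_stages - 1
            forced = (len(flops) - i - 1) < (num_stages - len(stages) - 1)
            if over or forced:
                stages.append(current)
                current, acc = [], 0.0
        current.append(i)
        acc += cost
    stages.append(current)
    return stages
```

For example, with per-layer costs `[1, 1, 1, 1, 4, 4]` and two stages, the split groups the four cheap layers together and the two expensive ones together.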

Performance

Hardware configuration: 4x RTX 3090 24GB; pipeline parallelism degree is 4.

llama3 configuration comparison

| hidden size | layers | autopipe TPS | default TPS | Speedup |
| --- | --- | --- | --- | --- |
| dim=256 | 6 | 31,094 | 29,549 | +5.2% |
| dim=256 | 12 | 21,803 | 21,923 | -0.5% |
| dim=2048 | 12 | 3,348 | 2,616 | +28.0% |
| dim=4096 | 12 | 981 | 761 | +28.9% |

deepseekv3 (without moe) configuration comparison

| hidden size | layers | autopipe TPS | default TPS | Speedup |
| --- | --- | --- | --- | --- |
| dim=256 | 6 | 13,373 | 13,059 | +2.4% |
| dim=256 | 12 | 7,714 | 6,859 | +12.5% |
| dim=2048 | 12 | 4,331 | 3,810 | +13.7% |
| dim=4096 | 12 | 2,888 | 2,561 | +12.8% |
| dim=4096 | 16 | 2,207 | 2,008 | +9.9% |
| dim=8192 | 16 | 4,331 | 3,935 | +10.1% |

1. Improve pipeline performance
2. Auto partition modules
@meta-cla

meta-cla bot commented Dec 5, 2025

Hi @TXacs!

Thank you for your pull request and welcome to our community.

Action Required

In order to merge any pull request (code, docs, etc.), we require contributors to sign our Contributor License Agreement, and we don't seem to have one on file for you.

Process

In order for us to review and merge your suggested changes, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (eg your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA.

Once the CLA is signed, our tooling will perform checks and validations. Afterwards, the pull request will be tagged with CLA signed. The tagging process may take up to 1 hour after signing. Please give it that time before contacting us about it.

If you have received this in error or have any questions, please contact us at cla@meta.com. Thanks!

Contributor

@tianyu-l tianyu-l left a comment

Thanks. Is it true that the only "real" deltas are

  • autopipe.cpp
  • pipeline_parallel.py
  • profiler.py

Contributor

This looks interesting -- how much benefit you'd get from having a c++ implementation, compared with a python one?

@TXacs
Author

TXacs commented Dec 5, 2025

Thanks. Is it true that the only "real" deltas are

  • autopipe.cpp
  • pipeline_parallel.py
  • profiler.py

Yes. Actually, profiler.py also uses a file from DeepSpeed. It would be even better if TorchTitan could provide a more authoritative FLOPs calculation method in the future, so that we could also adapt it for MoE models.

@tianyu-l tianyu-l requested a review from H-Huang December 5, 2025 01:59
@meta-cla

meta-cla bot commented Dec 5, 2025

Thank you for signing our Contributor License Agreement. We can now accept your code for this (and any) Meta Open Source project. Thanks!

@meta-cla meta-cla bot added the CLA Signed This label is managed by the Meta Open Source bot. label Dec 5, 2025
@tianyu-l tianyu-l added the enhancement New feature or request label Dec 9, 2025

parts = pipeline(
    mflops_list,
    [i * 3 for i in mflops_list],  # Assume backward is 3x forward
Member

why is it assumed to be 3x?

Copy link

@McmillanTAC McmillanTAC Dec 12, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Similar to computation time, where backward takes roughly twice as long as forward, we assume by default that backward FLOPs are twice the forward FLOPs. Our modeling also accounts for recomputation, which inserts an extra forward pass before the backward pass, so the default backward FLOPs are set to three times the forward FLOPs.
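The convention described above can be written out as a one-liner; the function name is illustrative:

```python
def backward_flops(forward_flops: float, recompute: bool) -> float:
    """Backward FLOPs estimate: ~2x forward by convention; recomputation
    replays the forward pass before the backward, giving ~3x forward."""
    return forward_flops * (3.0 if recompute else 2.0)
```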

# Profile each layer's FLOPs
mflops_list = []
for _, layer in enumerate(model):
    prof = FlopsProfiler(layer)
Member

I guess the FlopsProfiler does not estimate the backward flops?


You are correct that the FlopsProfiler does not estimate backward flops. Instead, we use a heuristic rule by default: given the adoption of recomputation techniques in the backward computation, backward flops are set to three times those of the forward computation.

Contributor

curious why three times? IIUC the convention was two times.


During the modeling process we employ the recomputation technique, which introduces an additional forward pass during backpropagation. To control the computational overhead of backpropagation, the implementation uses a scaling factor, typically 2 or 3, depending on whether recomputation is enabled. This adjustment is orthogonal to the pipeline partitioning algorithm.

Author

I guess the FlopsProfiler does not estimate the backward flops?

We have completely rebuilt the FLOPs acquisition logic: backward FLOPs for each layer can now be computed according to the ACConfig strategy, and the AutoPipe algorithm has been re-implemented in pure Python.


Following the previous review, we have further refined this PR with the following key updates:

  1. Refactored the core algorithm from C++ to pure Python, enhancing maintainability and better aligning with the PyTorch ecosystem.
  2. Integrated PyTorch's official torch.utils.flop_counter.FlopCounterMode, replacing the third-party profiler to improve integration and reduce external dependencies.
  3. Refined the FLOPs calculation logic to dynamically determine backward pass costs based on the activation checkpointing configuration, moving beyond a fixed multiplier for greater accuracy.

We welcome any additional suggestions to move this forward. Thanks again for your review.

…figuration

Based on TorchTitan's configuration that only applies recomputation to transformer layers, the algorithm has been modified. When using ac_config.selective_ac_option = 'nlayers', performance shows a 3% to 17% improvement compared to the previous algorithm that all layers defaulted to 3× forward FLOPs.
1. Re-implemented AutoPipe logic in Python, removed legacy C++ codebase
2. Replaced DeepSpeed profiler with torch.utils.flop_counter.FlopCounterMode
3. Added layer-wise backward-FLOPs computation based on ACConfig's recomputation policy
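The layer-wise backward-cost idea could be modeled roughly as follows. This is a sketch: how the set of recomputed layers is derived from torchtitan's ac_config is an assumption left to the caller, and the function name is illustrative.

```python
def backward_multipliers(num_layers: int, recomputed: set[int]) -> list[float]:
    """Per-layer backward/forward FLOPs ratio.

    A layer with activation recomputation replays its forward pass during
    the backward pass (~3x forward FLOPs); other layers cost ~2x forward.
    """
    return [3.0 if i in recomputed else 2.0 for i in range(num_layers)]
```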
@tianyu-l tianyu-l requested a review from H-Huang January 27, 2026 08:38

### Compile

First, we need to compile `autopipe.cpp`.
Contributor

is this comment outdated? i do not see any .cpp files in the PR anymore.

@McmillanTAC McmillanTAC Jan 28, 2026

Yes, this comment is outdated. The PR has migrated the AutoPipe logic from C++ to pure Python, so no .cpp files are involved anymore. We will update this line to reflect the current implementation; sorry for the confusion caused by the outdated documentation!

)
seq_modules = _build_module_for_profile(copied_model, flatten_module_names)

parts = autopipe_partition(seq_modules, parallel_dims.pp, job_config)
Member

why do we do the autopipe_partition after the splitting is already done? I propose there should be 3 different ways of doing things:

  1. layer-wise static partitioning (default)
  2. config based (user sets layers per stage)
  3. autopipe partition (this new implementation)

The option to choose should be set by configuration. Further, if the autopipe way is determined to be strictly better than it should replace the default.

Author

Yeah, it's already fixed in dc9ebab.

logger.info(f"Autopipe partitioning with mflops: {mflops_fwd}, {mflops_bwd}")

# Partition layers by forward and backward flops
parts = pipeline(
Member

nit: this pipeline name conflicts with the naming provided in PyTorch under torch.distributed.pipelining, so maybe it should have a different name

Author

We changed it in dc9ebab

tokenizer = build_hf_tokenizer(job_config)

# build dataloader
dataloader = build_text_dataloader(
Member

it's not ideal that we have to create a new dataloader for each rank. Can we not use dummy data as the example inputs instead?

Author

Thanks for reviewing the code. We considered that different models have different data formats, which is why using the original data loader can help us avoid many issues.

Contributor

why does the data format matter for FLOP counting? Dummy data should be sufficient as long as it has the correct shape and dtype?
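The dummy-data suggestion could look like this; names and shapes are illustrative, since FLOP counting depends only on shapes and dtypes:

```python
import torch


def make_dummy_batch(batch_size: int, seq_len: int, vocab_size: int):
    """Random token batch standing in for real data during FLOPs profiling."""
    input_ids = torch.randint(0, vocab_size, (batch_size, seq_len), dtype=torch.long)
    labels = input_ids.clone()
    return input_ids, labels
```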


for layer_idx, layer in enumerate(model):
    # forward
    with FlopCounterMode(display=False) as mode:
Member

Hm, this may be an inherent limitation. IIUC FlopCounterMode only counts flops for registered pytorch ops. So it would not capture things like FlexAttention which uses fused kernels, it may also not cover torch.compile related optimizations. Do you happen to know @sanketpurandare? How did you do things for your runtime estimators?


Thanks for the insightful comment! You're absolutely right. FlopCounterMode has limitations with fused kernels and torch.compile related optimizations. However, our approach is modular and extensible, so it can integrate more accurate cost models as they emerge. We treat FLOPs as a baseline, not the final answer. We're open to incorporating better estimators. Would love to hear your thoughts on how you bridge this gap!

Member

@H-Huang H-Huang left a comment

I think the PR should be split up into two parts:

  1. move autopipe.py to pytorch/pytorch under torch/distributed/pipelining/_utils.py

  2. Update pipeline_parallel.py to import this function from pytorch and perform the split (based on configuration)

The rest of the changes in this PR can be removed.

    return h

# C++ VectorEqual
def _eq(a, b):
Member

not used


Thank you for the careful review! This function is currently unused; leaving it in was an oversight on our part. We'll remove it in the next update.


# Hash func: C++ VectorHash
@functools.cache
def _hash(p):
Member

Python has its own hashing support in the standard library that you should use.


Thank you for the feedback! We will update the code accordingly to use the official implementation.

import functools
from collections import deque

COMM_OVERHEAD = 0
Member

why is this 0?


We set the communication overhead to 0 because the asynchronous communication is fully overlapped with computation. This effectively hides the communication latency, making the observed overhead negligible in our experimental setup.

):
    """C++ calculate_stage_times"""
    num_stages = len(partition)
    num_micro = num_stages * 2
Member

num_stages and num_micro should both be hyperparameters, this assumes num_micro is always 2 * stages?


The num_stages used in this context is a local variable, while the global num_stages is indeed already treated as a hyperparameter. As for num_micro, we currently default it to 2 * stages because this configuration ensures the pipeline naturally forms three distinct phases: warmup, steady, and cooldown. We agree that num_micro should be configurable as a hyperparameter to support more flexible pipeline configurations. We’ll update the code accordingly.
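As a sketch of the kind of model calculate_stage_times implements, a simplified 1F1B flush-time estimate (assuming zero, fully overlapped communication and num_micro defaulting to 2 * num_stages) might look like this. The formula is a rough lower-bound model, not the PR's exact cost function:

```python
def pipeline_flush_time(fwd, bwd, num_micro=None):
    """Rough 1F1B flush-time estimate for one partition.

    fwd/bwd: per-stage forward/backward times. Warmup fills the pipeline,
    the slowest stage bounds the steady state, and cooldown drains it.
    """
    num_stages = len(fwd)
    if num_micro is None:
        num_micro = 2 * num_stages  # default used in the PR's modeling
    bottleneck = max(f + b for f, b in zip(fwd, bwd))
    warmup = sum(fwd[:-1])    # forwards before the last stage starts
    cooldown = sum(bwd[:-1])  # backwards draining after the last microbatch
    return warmup + num_micro * bottleneck + cooldown
```

A partition search can then minimize this estimate over candidate partitions.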

    return flush, critical


# ---------- Optimal partition search ----------
Member

Can the Chinese comments be translated to English for consistency?


We’ll translate all Chinese comments into English to maintain consistency across the codebase. Thanks for pointing this out!

@McmillanTAC

I think the PR should be split up into two parts:

  1. move autopipe.py to pytorch/pytorch under torch/distributed/pipelining/_utils.py
  2. Update pipeline_parallel.py to import this function from pytorch and perform the split (based on configuration)

The rest of the changes in this PR can be removed.

Thank you for the helpful suggestion! Could you kindly clarify the expected steps for moving autopipe.py into the PyTorch repository?

…h function with Python's built-in hash(); 3. Translate Chinese to English
@H-Huang
Member

H-Huang commented Mar 5, 2026

Could you kindly clarify the expected steps for moving autopipe.py into the PyTorch repository?

You will need to clean up the file to make it more generalizable for torch.distributed.pipelining. The algorithm with the DP solver is fine (no model / torch dependency), but there are a few hardcoded assumptions, namely:

  • single stage schedule (1F1B)
  • fixed number of microbatches based on stages
  • COMM_OVERHEAD
  • no support for split backward schedules (zero bubble)

I don't think these all need to be solved at once, but if you can get a proof of concept upstreamed to PyTorch that is configurable (removing some of these hardcoded assumptions), then torchtitan can just perform the plumbing to import and use it.

So concretely, you can upstream the autopipe algorithm to pytorch/torch/distributed/pipelining (you can call it _partition.py) and also add a new test file for it to pytorch/test/distributed/pipelining to include unit tests.

@tianyu-l tianyu-l requested a review from acisseJZhong March 5, 2026 19:20

        nn.Sequential: A sequential module containing the specified modules
    """
    module_seq = nn.Sequential()
    base_model = copy.deepcopy(model)
Contributor

this seems redundant with L149?

Comment on lines +554 to +596
def _build_stage_from_modules(
    stage_idx: int, module_names: list[str], num_stages: int
) -> tuple[PipelineStage, nn.Module]:
    model = copy.deepcopy(whole_model)

    # Create a set of modules to keep for faster lookup
    modules_to_keep = set(module_names)
    for module_name, module_value in model.named_children():
        # Handle layer-like structures (e.g., "layers.0", "layers.1")
        if isinstance(module_value, (nn.ModuleDict, nn.ModuleList)):
            layers_to_keep = {
                name.split(".", 1)[1]
                for name in modules_to_keep
                if name.startswith(f"{module_name}.")
            }
            if layers_to_keep:
                # Keep only specified layers
                if isinstance(module_value, nn.ModuleDict):
                    for layer_name in list(module_value.keys()):
                        if layer_name not in layers_to_keep:
                            del module_value[layer_name]
                elif isinstance(module_value, nn.ModuleList):
                    indices_to_keep = {
                        int(idx) for idx in layers_to_keep if idx.isdigit()
                    }
                    new_layers = nn.ModuleList(
                        [
                            layer
                            for i, layer in enumerate(module_value)
                            if i in indices_to_keep
                        ]
                    )
                    setattr(model, module_name, new_layers)
            else:
                # No layers from this structure needed, set to empty structure
                if isinstance(module_value, nn.ModuleDict):
                    setattr(model, module_name, nn.ModuleDict())
                elif isinstance(module_value, nn.ModuleList):
                    setattr(model, module_name, nn.ModuleList())
        # Handle simple module attributes (e.g., "linear", "norm")
        elif module_name not in modules_to_keep:
            # Replace with None
            setattr(model, module_name, None)
Contributor

can we reuse logic in _build_module_for_profile? or extract common logic as helper function with detailed comments of what it's doing?

Comment on lines +319 to +326
# This is used in the train loop to determine whether to pass in the input_ids and labels
has_first_stage = False
has_last_stage = False
for stage in stages:
    if stage.is_first:
        has_first_stage = True
    if stage.is_last:
        has_last_stage = True
Contributor

has_first_stage = any(stage.is_first for stage in stages)
has_last_stage = any(stage.is_last for stage in stages)

)

if pp_schedule_csv:
    assert schedule_class in [
Contributor

why do we assert this given L352?

extra_layers = num_effective_layers % num_stages

# Feasibility check: Ensure at least 1 layer in each PP stage
if layers_per_stage == 0:
Contributor

this check becomes unnecessary as we already checked in L434.

)

# Balance check: Ensure weights don't exceed minimum layers per stage
if input_weight > layers_per_stage:
Contributor

also raise an error when input_weight < 0 or output_weight < 0

if num_stages == 1:
    # Single stage gets everything
    layer_names = [f"layers.{i}" for i in range(num_layers)]
    return [["tok_embeddings"] + layer_names + ["norm", "output"]]
Contributor

the naming tok_embeddings, norm, output seems tied to specific model, can we generalize?

@McmillanTAC

Could you kindly clarify the expected steps for moving autopipe.py into the PyTorch repository?

You will need to clean up the file to make it more generalizable for torch.distributed.pipelining. The algorithm with the DP solver is fine (no model / torch dependency), but there are a few hardcoded assumptions, namely:

  • single stage schedule (1F1B)
  • fixed number of microbatches based on stages
  • COMM_OVERHEAD
  • no support for split backward schedules (zero bubble)

I don't think these all need to be solved at once, but if you can get a proof of concept upstreamed to PyTorch that is configurable (removing some of these hardcoded assumptions), then torchtitan can just perform the plumbing to import and use it.

So concretely, you can upstream the autopipe algorithm to pytorch/torch/distributed/pipelining (you can call it _partition.py) and also add a new test file for it to pytorch/test/distributed/pipelining to include unit tests.

Thank you very much for the detailed guidance! We'll proceed by refactoring autopipe.py to make it more general-purpose and aligned with the design of torch.distributed.pipelining. We'll also add a new module (tentatively named _partition.py) under pytorch/torch/distributed/pipelining along with a corresponding unit test file in pytorch/test/distributed/pipelining.
We appreciate your support and will aim to submit a clean, minimal proof-of-concept PR to PyTorch soon!


for module_name in flatten_module_names:
    # Create a copy of the base model for each module
    model_copy = copy.deepcopy(base_model)
Contributor

I saw quite a few deepcopys inside for loop to copy the entire model and then prune to only have target layers, will this cause to OOM for larger models? Can we think about other ways to do this?

Author

I saw quite a few deepcopys inside for loop to copy the entire model and then prune to only have target layers, will this cause to OOM for larger models? Can we think about other ways to do this?

This is a good question. We need the complete forward and backward inputs and outputs at the layer level, but extracting a single layer from the model's attributes cannot provide its full computational graph. We therefore adopt TorchTitan's pipeline_module_split approach: by capturing each layer's complete computational graph before initializing the model parameters, we can accurately compute FLOPs for each layer. During this function the parameters remain uninitialized, which prevents the OOM issues that would otherwise occur with large models.
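The "uninitialized parameters" trick above relies on PyTorch's meta device: modules constructed under torch.device("meta") carry full structure and shape metadata but allocate no real parameter storage, so per-layer copies stay cheap. A minimal sketch (the helper name is illustrative):

```python
import torch
import torch.nn as nn


def build_meta_copy(ctor, *args, **kwargs) -> nn.Module:
    """Construct a module on the meta device: full structure, no storage."""
    with torch.device("meta"):
        return ctor(*args, **kwargs)


# A 4096x4096 Linear on meta holds shape/dtype metadata but no memory.
layer = build_meta_copy(nn.Linear, 4096, 4096)
```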

@wwwjn wwwjn assigned acisseJZhong and unassigned wwwjn Mar 13, 2026

Labels

CLA Signed This label is managed by the Meta Open Source bot. enhancement New feature or request

Projects

Status: Todo


7 participants