Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions tests/integration_tests/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,19 @@ def build_model_tests_list() -> list[OverrideDefinitions]:
"qwen3_fsdp+tp+ep+etp",
ngpu=4,
),
OverrideDefinitions(
[
[
"--module qwen3 --config qwen3_debugmodel",
"--parallelism.data_parallel_shard_degree 2",
"--parallelism.tensor_parallel_degree 2",
"--parallelism.context_parallel_degree 2",
],
],
"Qwen3 FSDP+TP+CP",
"qwen3_fsdp+tp+cp",
ngpu=8,
),
# Integration Test Cases for Llama 4
OverrideDefinitions(
[
Expand Down
6 changes: 3 additions & 3 deletions tests/unit_tests/test_compile_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@
import torch

from torchtitan.config import CompileConfig
from torchtitan.distributed.compile import apply_compile_sparse
from torchtitan.models.common.linear import Linear
from torchtitan.models.llama4.parallelize import apply_compile
from torchtitan.protocols.module import Module, ModuleDict


Expand Down Expand Up @@ -51,8 +51,8 @@ def test_patched_once(self):
unused_model2 = TinyModel(num_layers=2, dim=128)
compile_config = CompileConfig(backend="eager")

apply_compile(unused_model1, compile_config, ep_enabled=True)
apply_compile(unused_model2, compile_config, ep_enabled=True)
apply_compile_sparse(unused_model1, compile_config, ep_enabled=True)
apply_compile_sparse(unused_model2, compile_config, ep_enabled=True)

from torchtitan.models.common.moe import moe as moe_module

Expand Down
156 changes: 156 additions & 0 deletions torchtitan/distributed/compile.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torch.nn as nn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
CheckpointWrapper,
)

from torchtitan.config import CompileConfig
from torchtitan.models.common.moe import moe as moe_module
from torchtitan.tools.logging import logger


def apply_compile_dense(model: nn.Module, compile_config: CompileConfig) -> None:
    """
    Compile each TransformerBlock of a dense (non-MoE) model individually.

    Per-block compilation keeps compile time low because the repeated block
    structure lets torch.compile reuse work across blocks. Alternatively one
    can compile the whole model (after applying DP).

    Args:
        model: Model whose ``model.layers`` children are TransformerBlocks.
        compile_config: Supplies the ``backend`` passed to ``torch.compile``.
    """
    # Eager AC replays forward python side effects (e.g. RoPE cache updates)
    # during the backward recompute, but torch.compile has no easy way to
    # reapply python mutations in backward. Setting this flag accepts that
    # eager/compile divergence by skipping reapplication of side effects.
    torch._dynamo.config.skip_fwd_side_effects_in_bwd_under_checkpoint = (
        True  # pyrefly: ignore [bad-assignment]
    )

    # pyrefly: ignore [missing-attribute]
    for name, block in model.layers.named_children():
        compiled_block = torch.compile(
            block, backend=compile_config.backend, fullgraph=True
        )
        # Re-register under the same name so the compiled wrapper replaces
        # the original block in-place.
        # pyrefly: ignore [missing-attribute]
        model.layers.register_module(name, compiled_block)

    logger.info("Compiling each TransformerBlock with torch.compile")


def apply_compile_sparse(
    model: nn.Module, compile_config: CompileConfig, ep_enabled: bool
) -> None:
    """
    Compile each TransformerBlock of a MoE (sparse) model.

    Per-block compilation keeps compile time low because of the repeated
    structure. Alternatively one can compile the whole model (after applying
    DP). For MoE layers, sub-modules are compiled individually to avoid graph
    breaks from FSDP(GroupedExperts).

    Args:
        model: Model whose ``model.layers`` children are TransformerBlocks.
        compile_config: Supplies the ``backend`` passed to ``torch.compile``.
        ep_enabled: When True, marks the token dim of the grouped-mm input as
            dynamic, since expert parallel routes a varying number of tokens.
    """
    # Needed for torch.compile to avoid graph breaking on dynamic shapes in
    # token-choice MoE, but it is experimental.
    torch._dynamo.config.capture_scalar_outputs = True
    # Skip replaying forward side effects (e.g. RoPE cache updates) during
    # the AC recompute in backward. Eager AC replays the forward python
    # side-effects in backward, but torch.compile has no easy way to reapply
    # python mutations in the backward. Setting this flag accepts this eager
    # and compile divergence by skipping reapplication of side effects.
    torch._dynamo.config.skip_fwd_side_effects_in_bwd_under_checkpoint = (
        True  # pyrefly: ignore [bad-assignment]
    )

    def _compiled(mod):
        # Single place that fixes the backend / fullgraph compile options.
        return torch.compile(mod, backend=compile_config.backend, fullgraph=True)

    # pyrefly: ignore [missing-attribute]
    for layer_id, transformer_block in model.layers.named_children():
        if not transformer_block.moe_enabled:
            # No FSDP(GroupedExperts) in a dense layer, so the whole block
            # can be compiled as one graph.
            transformer_block = _compiled(transformer_block)
        else:
            # MoE layer: FSDP(GroupedExperts) causes a graph break, so weave
            # compile wrappers around those FSDP hooks instead of compiling
            # the whole block, to prevent AC from falling back the whole
            # graph to eager.
            # TODO: Fix Compile(AC(graph break))

            if isinstance(transformer_block, CheckpointWrapper):
                # TODO: Make CheckpointWrapper a transparent wrapper
                # unwrap so that .named_children() works
                block = transformer_block._checkpoint_wrapped_module
            else:
                block = transformer_block

            for child_name, child in block.named_children():
                # Attribute access must resolve identically through the
                # (possibly wrapped) transformer_block.
                assert getattr(block, child_name) == getattr(
                    transformer_block, child_name
                )

                if isinstance(child, moe_module.MoE):
                    # avoid graph breaking on the GroupedExperts' FSDP hooks
                    # by wrapping each submod's forward instead of their __call__
                    for moe_child_name, moe_child in child.named_children():
                        if moe_child_name == "experts":
                            # NOTE: We don't compile token dispatch and token combine due to an issue on B200:
                            # https://github.com/pytorch/torchtitan/issues/1940
                            continue
                        setattr(child, moe_child_name, _compiled(moe_child))
                else:
                    setattr(block, child_name, _compiled(child))

        # pyrefly: ignore [missing-attribute]
        model.layers.register_module(layer_id, transformer_block)

    # Patch some globals only once (apply_compile_sparse is called multiple times for PP setup)
    already_patched = (
        "_run_experts_grouped_mm_dynamic"
        in moe_module._run_experts_grouped_mm.__qualname__
    )
    if not already_patched:
        moe_module._run_experts_grouped_mm = torch.compile(
            moe_module._run_experts_grouped_mm,
            backend=compile_config.backend,
            fullgraph=True,
        )

        if ep_enabled:
            compiled_fn = moe_module._run_experts_grouped_mm

            # keep function logic in sync with `already_patched` above
            def _run_experts_grouped_mm_dynamic(
                w1: torch.Tensor,
                w2: torch.Tensor,
                w3: torch.Tensor,
                x: torch.Tensor,
                num_tokens_per_expert: torch.Tensor,
            ) -> torch.Tensor:
                # dynamic number of tokens in expert parallel
                torch._dynamo.mark_dynamic(x, 0)
                return compiled_fn(w1, w2, w3, x, num_tokens_per_expert)

            moe_module._run_experts_grouped_mm = _run_experts_grouped_mm_dynamic

    # NOTE: We don't compile for loop code path due to an issue with unbacked symints:
    # https://github.com/pytorch/pytorch/issues/166460

    logger.info("Compiling each TransformerBlock with torch.compile")
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,10 @@
)
from torchtitan.distributed import ParallelDims
from torchtitan.distributed.activation_checkpoint import apply_ac
from torchtitan.distributed.compile import apply_compile_dense
from torchtitan.distributed.fsdp import get_fsdp_reshard_after_forward_policy
from torchtitan.distributed.tensor_parallel import maybe_enable_async_tp, NoParallel
from torchtitan.models.llama3.parallelize import (
apply_compile,
apply_replicate,
disable_fsdp_gradient_division,
)
Expand Down Expand Up @@ -96,7 +96,7 @@ def parallelize_hf_transformers(

# turn on per-TransformerBlock compile after AC wrapping and before FSDP
if model_compile_enabled:
apply_compile(model, compile_config)
apply_compile_dense(model, compile_config)

if parallel_dims.fsdp_enabled:
# apply FSDP or HSDP, potentially with Context Parallel
Expand Down
6 changes: 3 additions & 3 deletions torchtitan/experiments/vlm/infra/parallelize.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,10 @@
)
from torchtitan.distributed import ParallelDims
from torchtitan.distributed.activation_checkpoint import apply_ac
from torchtitan.distributed.compile import apply_compile_dense
from torchtitan.distributed.fsdp import get_fsdp_reshard_after_forward_policy
from torchtitan.models.llama3.parallelize import (
_op_sac_save_list,
apply_compile,
apply_replicate,
disable_fsdp_gradient_division,
)
Expand Down Expand Up @@ -80,8 +80,8 @@ def parallelize_vlm(

# turn on per-TransformerBlock compile after AC wrapping and before FSDP
if compile_config.enable:
apply_compile(model, compile_config)
apply_compile(model.encoder, compile_config)
apply_compile_dense(model, compile_config)
apply_compile_dense(model.encoder, compile_config)

if parallel_dims.fsdp_enabled:
# apply FSDP or HSDP, potentially with Context Parallel
Expand Down
93 changes: 89 additions & 4 deletions torchtitan/models/common/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@

import torch
import torch.nn.functional as F
from torch.distributed.tensor import DTensor, Shard
from torch.distributed.tensor.experimental import local_map
from torch.nn.attention import sdpa_kernel, SDPBackend
from torch.nn.attention.flex_attention import (
_mask_mod_signature,
Expand All @@ -34,6 +36,7 @@
__all__ = [
"FlexAttentionWrapper",
"GQAttention",
"LocalMapAttention",
"ScaledDotProductAttentionWrapper",
"VarlenAttentionWrapper",
"VarlenMetadata",
Expand Down Expand Up @@ -61,16 +64,96 @@ class VarlenMetadata(NamedTuple):
AttentionMasksType = dict[str, BlockMask] | BlockMask | VarlenMetadata


class LocalMapAttention(Module):
    """Base class for inner attention wrappers with DTensor support.

    When q, k, v are DTensors (e.g., from TP with ``use_local_output=False``),
    overrides ``__call__`` to wrap ``nn.Module.__call__`` with ``local_map``.
    This converts TP DTensors to local **before** any ``forward_pre_hook``
    (e.g., CP's ``sdpa_input_fn``) fires, and wraps outputs back to TP
    DTensors **after** all ``forward_hook``s complete.

    Placements and device mesh are inferred from the input DTensors.
    """

    @dataclass(kw_only=True, slots=True)
    class Config(Module.Config):
        pass

    def __init__(self) -> None:
        super().__init__()
        # Lazily-built local_map wrapper around nn.Module.__call__; created on
        # the first DTensor call and reused afterwards (placements are
        # re-validated on every call below).
        self._local_map_fn: Callable | None = None

    def __call__(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        **kwargs,
    ) -> torch.Tensor:
        # Plain-tensor path: defer straight to the normal module protocol.
        if not isinstance(q, DTensor):
            return super().__call__(q, k, v, **kwargs)

        assert isinstance(k, DTensor) and isinstance(
            v, DTensor
        ), "q, k, v should all be DTensors"
        # All placements must be Shard. We set
        # out_placements and in_grad_placements equal to
        # in_placements below. This is only valid for attention
        # as qkv are sharded on the n_heads dim. CP is handled
        # independently by _ContextParallel hooks inside
        # nn.Module.__call__.
        assert q.placements == k.placements == v.placements, (
            f"q, k, v must have the same placements, "
            f"but got q={q.placements}, k={k.placements}, v={v.placements}"
        )
        # qkv are (bs, n_heads, seqlen, head_dim) and must be sharded
        # on the n_heads dim (dim 1)
        # TODO: after full DTensor rewrite, the DP mesh will also be
        # present, update this check to allow Shard(0) for DP and Shard(1) for TP.
        for i, p in enumerate(q.placements):
            assert p == Shard(1), (
                f"LocalMapAttention requires Shard(1) placements "
                f"(n_heads dim), but got {p} at position {i}"
            )

        # return_lse=True (e.g. gpt_oss attention sinks) produces
        # 2 outputs instead of 1, requiring different out_placements.
        return_lse = kwargs.get("return_lse", False)
        if return_lse:
            out_placements = (q.placements, q.placements)
        else:
            out_placements = (q.placements,)

        if self._local_map_fn is None:
            in_placements = (q.placements, k.placements, v.placements)
            self._local_map_fn = local_map(
                super().__call__,
                in_placements=in_placements,
                out_placements=out_placements,
                in_grad_placements=in_placements,
                device_mesh=q.device_mesh,
            )
        # pyrefly: ignore [bad-argument-count]
        return self._local_map_fn(q, k, v, **kwargs)

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        **kwargs,
    ) -> torch.Tensor:
        # Subclasses implement the actual attention computation.
        raise NotImplementedError


class VarlenAttentionWrapper(LocalMapAttention):
_compiled_varlen_attn: ClassVar[Callable] = torch.compile(
varlen_attn, mode="max-autotune-no-cudagraphs"
)

# pyrefly: ignore [bad-param-name-override, bad-override]
def forward(
self,
xq: torch.Tensor,
xk: torch.Tensor,
xv: torch.Tensor,
*,
attention_masks: VarlenMetadata,
scale: float | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
Expand Down Expand Up @@ -118,7 +201,7 @@ def forward(
).to(xq.dtype)


class FlexAttentionWrapper(Module):
class FlexAttentionWrapper(LocalMapAttention):
"""Wrapper around `flex_attention` to make it torch.compile and CP compatible.

This wrapper serves two purposes:
Expand Down Expand Up @@ -146,6 +229,7 @@ class FlexAttentionWrapper(Module):
options=inductor_configs,
)

# pyrefly: ignore [bad-override]
def forward(
self,
q: torch.Tensor,
Expand Down Expand Up @@ -199,7 +283,7 @@ def annotate_flex_attention_for_regional_inductor() -> Generator[None, None, Non
FlexAttentionWrapper.forward = orig


class ScaledDotProductAttentionWrapper(Module):
class ScaledDotProductAttentionWrapper(LocalMapAttention):
"""Wrapper around `F.scaled_dot_product_attention` to make it CP compatible.

This wrapper is needed because `F.scaled_dot_product_attention` is not
Expand All @@ -222,6 +306,7 @@ def __init__(self) -> None:
SDPBackend.MATH,
]

# pyrefly: ignore [bad-override]
def forward(
self,
q: torch.Tensor,
Expand Down Expand Up @@ -565,7 +650,7 @@ def forward(
case "varlen":
assert isinstance(attention_masks, VarlenMetadata), attention_masks
output = self.inner_attention(
xq, xk, xv, attention_masks, **scale_kwargs
xq, xk, xv, attention_masks=attention_masks, **scale_kwargs
)
case "sdpa":
assert attention_masks is None
Expand Down
Loading
Loading