Adopt Local Map Wrapper for Inner Attention #2557
Changes from all commits
First file — `@@ -0,0 +1,156 @@` (new file, compile utilities):

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torch.nn as nn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    CheckpointWrapper,
)

from torchtitan.config import CompileConfig
from torchtitan.models.common.moe import moe as moe_module
from torchtitan.tools.logging import logger


def apply_compile_dense(model: nn.Module, compile_config: CompileConfig) -> None:
    """
    Apply torch.compile to each TransformerBlock, which makes compilation efficient
    due to repeated structure. Alternatively one can compile the whole model
    (after applying DP).

    This is for dense (non-MoE) models. It compiles each TransformerBlock as a whole.
    """
    # Skip replaying forward side effects (e.g. RoPE cache updates) during
    # the AC recompute in backward. Eager AC replays the forward python
    # side-effects in backward, but torch.compile has no easy way to reapply
    # python mutations in the backward. Setting this flag accepts this eager
    # and compile divergence by skipping reapplication of side effects.
    torch._dynamo.config.skip_fwd_side_effects_in_bwd_under_checkpoint = (
        True  # pyrefly: ignore [bad-assignment]
    )

    # pyrefly: ignore [missing-attribute]
    for layer_id, transformer_block in model.layers.named_children():
        transformer_block = torch.compile(
            transformer_block, backend=compile_config.backend, fullgraph=True
        )
        # pyrefly: ignore [missing-attribute]
        model.layers.register_module(layer_id, transformer_block)

    logger.info("Compiling each TransformerBlock with torch.compile")


def apply_compile_sparse(
    model: nn.Module, compile_config: CompileConfig, ep_enabled: bool
) -> None:
    """
    Apply torch.compile to each TransformerBlock, which makes compilation efficient
    due to repeated structure. Alternatively one can compile the whole model
    (after applying DP).

    This is for MoE (sparse) models. It compiles sub-modules individually to avoid
    graph breaks from FSDP(GroupedExperts).
    """
    # Needed for torch.compile to avoid graph breaking on dynamic shapes in
    # token-choice MoE, but it is experimental.
    torch._dynamo.config.capture_scalar_outputs = True
    # Skip replaying forward side effects (e.g. RoPE cache updates) during
    # the AC recompute in backward. Eager AC replays the forward python
    # side-effects in backward, but torch.compile has no easy way to reapply
    # python mutations in the backward. Setting this flag accepts this eager
    # and compile divergence by skipping reapplication of side effects.
    torch._dynamo.config.skip_fwd_side_effects_in_bwd_under_checkpoint = (
        True  # pyrefly: ignore [bad-assignment]
    )

    # pyrefly: ignore [missing-attribute]
    for layer_id, transformer_block in model.layers.named_children():
        if transformer_block.moe_enabled:
            # If it is a MoE layer, FSDP(GroupedExperts) will cause a graph break,
            # so we must weave compile wrappers around those FSDP hooks to
            # prevent AC from falling back the whole graph to eager.
            # TODO: Fix Compile(AC(graph break))

            if isinstance(transformer_block, CheckpointWrapper):
                # TODO: Make CheckpointWrapper a transparent wrapper.
                # Unwrap so that .named_children() works.
                block = transformer_block._checkpoint_wrapped_module
            else:
                block = transformer_block

            for attr_name, submod in block.named_children():
                assert getattr(block, attr_name) == getattr(
                    transformer_block, attr_name
                )

                if isinstance(submod, moe_module.MoE):
                    # Avoid graph breaking on the GroupedExperts' FSDP hooks
                    # by wrapping each submod's forward instead of their __call__.
                    moe = submod
                    for attr_name, submod in moe.named_children():
                        if attr_name == "experts":
                            # NOTE: We don't compile token dispatch and token
                            # combine due to an issue on B200:
                            # https://github.com/pytorch/torchtitan/issues/1940
                            continue
                        setattr(
                            moe,
                            attr_name,
                            torch.compile(
                                submod, backend=compile_config.backend, fullgraph=True
                            ),
                        )
                else:
                    setattr(
                        block,
                        attr_name,
                        torch.compile(
                            submod, backend=compile_config.backend, fullgraph=True
                        ),
                    )

        else:
            # If it's not a MoE layer, there is no FSDP(GroupedExperts),
            # so we can compile the whole block.
            transformer_block = torch.compile(
                transformer_block,
                backend=compile_config.backend,
                fullgraph=True,
            )

        # pyrefly: ignore [missing-attribute]
        model.layers.register_module(layer_id, transformer_block)

    # Patch some globals only once (apply_compile_sparse is called multiple
    # times for PP setup).
    already_patched = (
        "_run_experts_grouped_mm_dynamic"
        in moe_module._run_experts_grouped_mm.__qualname__
    )
    if not already_patched:
        moe_module._run_experts_grouped_mm = torch.compile(
            moe_module._run_experts_grouped_mm,
            backend=compile_config.backend,
            fullgraph=True,
        )

        if ep_enabled:
            compiled_fn = moe_module._run_experts_grouped_mm

            # Keep function logic in sync with `already_patched` above.
            def _run_experts_grouped_mm_dynamic(
                w1: torch.Tensor,
                w2: torch.Tensor,
                w3: torch.Tensor,
                x: torch.Tensor,
                num_tokens_per_expert: torch.Tensor,
            ) -> torch.Tensor:
                # dynamic number of tokens in expert parallel
                torch._dynamo.mark_dynamic(x, 0)
                return compiled_fn(w1, w2, w3, x, num_tokens_per_expert)

            moe_module._run_experts_grouped_mm = _run_experts_grouped_mm_dynamic

    # NOTE: We don't compile the for-loop code path due to an issue with
    # unbacked symints: https://github.com/pytorch/pytorch/issues/166460

    logger.info("Compiling each TransformerBlock with torch.compile")
```
Second file — attention wrappers. `@@ -11,6 +11,8 @@` adds the DTensor imports:

```python
import torch
import torch.nn.functional as F
from torch.distributed.tensor import DTensor, Shard
from torch.distributed.tensor.experimental import local_map
from torch.nn.attention import sdpa_kernel, SDPBackend
from torch.nn.attention.flex_attention import (
    _mask_mod_signature,
```
`@@ -34,6 +36,7 @@` — `LocalMapAttention` is added to `__all__`:

```python
__all__ = [
    "FlexAttentionWrapper",
    "GQAttention",
    "LocalMapAttention",
    "ScaledDotProductAttentionWrapper",
    "VarlenAttentionWrapper",
    "VarlenMetadata",
```
`@@ -61,16 +64,96 @@` — the new `LocalMapAttention` base class:

```python
AttentionMasksType = dict[str, BlockMask] | BlockMask | VarlenMetadata


class LocalMapAttention(Module):
    """Base class for inner attention wrappers with DTensor support.

    When q, k, v are DTensors (e.g., from TP with ``use_local_output=False``),
    overrides ``__call__`` to wrap ``nn.Module.__call__`` with ``local_map``.
    This converts TP DTensors to local **before** any ``forward_pre_hook``
    (e.g., CP's ``sdpa_input_fn``) fires, and wraps outputs back to TP
    DTensors **after** all ``forward_hook``s complete.

    Placements and device mesh are inferred from the input DTensors.
    """

    @dataclass(kw_only=True, slots=True)
    class Config(Module.Config):
        pass

    def __init__(self) -> None:
        super().__init__()
        self._local_map_fn: Callable | None = None

    def __call__(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        **kwargs,
    ) -> torch.Tensor:
        if isinstance(q, DTensor):
            assert isinstance(k, DTensor) and isinstance(
                v, DTensor
            ), "q, k, v should all be DTensors"
            # All placements must be Shard. We set out_placements and
            # in_grad_placements equal to in_placements below. This is only
            # valid for attention, as qkv are sharded on the n_heads dim.
            # CP is handled independently by _ContextParallel hooks inside
            # nn.Module.__call__.
            assert q.placements == k.placements == v.placements, (
                f"q, k, v must have the same placements, "
                f"but got q={q.placements}, k={k.placements}, v={v.placements}"
            )
            # qkv are (bs, n_heads, seqlen, head_dim) and must be sharded
            # on the n_heads dim (dim 1).
            # TODO: after the full DTensor rewrite, the DP mesh will also be
            # present; update this check to allow Shard(0) for DP and
            # Shard(1) for TP.
            for i, p in enumerate(q.placements):
                assert p == Shard(1), (
                    f"LocalMapAttention requires Shard(1) placements "
                    f"(n_heads dim), but got {p} at position {i}"
                )
            # return_lse=True (e.g. gpt_oss attention sinks) produces
            # 2 outputs instead of 1, requiring different out_placements.
            return_lse = kwargs.get("return_lse", False)
            out_placements = (
                (q.placements, q.placements) if return_lse else (q.placements,)
            )
            if self._local_map_fn is None:
                self._local_map_fn = local_map(
                    super().__call__,
                    in_placements=(q.placements, k.placements, v.placements),
                    out_placements=out_placements,
                    in_grad_placements=(q.placements, k.placements, v.placements),
                    device_mesh=q.device_mesh,
                )
            # pyrefly: ignore [bad-argument-count]
            return self._local_map_fn(q, k, v, **kwargs)
        return super().__call__(q, k, v, **kwargs)

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        **kwargs,
    ) -> torch.Tensor:
        raise NotImplementedError
```

Review threads on this hunk:

- On the `Shard(1)` assert — **Contributor:** after full dtensor rewrite, DP mesh will be included, where qkv will be Shard(0) on DP mesh -- please leave a TODO
- On `if self._local_map_fn is None:` — **Contributor:** IIUC the purpose of having this check is for the 2nd, 3rd, ... call to this attention to reuse the … **Author:** We already check placements at L105 for each call? **Contributor:** you didn't check Shard(0) / Shard(1) / Shard(2) / ... **Author:** added! thanks for the catch
- On `in_placements` — **Contributor:** is it better if we assert they all have the same placements **Author:** I thought L103 does it? we checked q/k/v are Shard. **Contributor:** same comment
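The reason `LocalMapAttention` overrides `__call__` rather than `forward` is the hook ordering described in its docstring: inputs must be unwrapped before any `forward_pre_hook` fires, and outputs re-wrapped after all hooks complete. A torch-free sketch of that ordering, where `Module`, `Boxed`, and `LocalMapLike` are hypothetical stand-ins for `nn.Module`'s hook machinery and DTensor, not the real classes:

```python
class Module:
    """Minimal stand-in for nn.Module's __call__/hook machinery."""

    def __init__(self):
        self._pre_hooks = []

    def register_forward_pre_hook(self, fn):
        self._pre_hooks.append(fn)

    def __call__(self, x):
        # pre-hooks run inside the base __call__, before forward
        for hook in self._pre_hooks:
            x = hook(self, x)
        return self.forward(x)


class Boxed:
    """Stand-in for a DTensor-like wrapper type."""

    def __init__(self, value):
        self.value = value


class LocalMapLike(Module):
    def __call__(self, x):
        if isinstance(x, Boxed):
            # Unwrap BEFORE the base __call__ runs its pre-hooks, and
            # re-wrap AFTER forward completes, mirroring what wrapping
            # super().__call__ in local_map achieves for DTensors.
            return Boxed(super().__call__(x.value))
        return super().__call__(x)

    def forward(self, x):
        return x + 1


m = LocalMapLike()
seen = []
m.register_forward_pre_hook(lambda mod, x: (seen.append(type(x).__name__), x)[1])
out = m(Boxed(41))
# the pre-hook observed a plain int, not a Boxed value; the output is re-wrapped
```

Had the unwrapping been done inside `forward` instead, the pre-hook would have seen the still-wrapped value.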
`VarlenAttentionWrapper` now subclasses `LocalMapAttention`:

```python
class VarlenAttentionWrapper(LocalMapAttention):
    _compiled_varlen_attn: ClassVar[Callable] = torch.compile(
        varlen_attn, mode="max-autotune-no-cudagraphs"
    )

    # pyrefly: ignore [bad-param-name-override, bad-override]
    def forward(
        self,
        xq: torch.Tensor,
        xk: torch.Tensor,
        xv: torch.Tensor,
        *,
        attention_masks: VarlenMetadata,
        scale: float | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
```

`@@ -118,7 +201,7 @@` (context): `).to(xq.dtype)`

`FlexAttentionWrapper` now subclasses `LocalMapAttention` (was `Module`):

```python
class FlexAttentionWrapper(LocalMapAttention):
    """Wrapper around `flex_attention` to make it torch.compile and CP compatible.

    This wrapper serves two purposes:
```

`@@ -146,6 +229,7 @@` — a pyrefly suppression is added before `forward`:

```python
        options=inductor_configs,
    )

    # pyrefly: ignore [bad-override]
    def forward(
        self,
        q: torch.Tensor,
```

`@@ -199,7 +283,7 @@` (context): `FlexAttentionWrapper.forward = orig`

`ScaledDotProductAttentionWrapper` now subclasses `LocalMapAttention` (was `Module`):

```python
class ScaledDotProductAttentionWrapper(LocalMapAttention):
    """Wrapper around `F.scaled_dot_product_attention` to make it CP compatible.

    This wrapper is needed because `F.scaled_dot_product_attention` is not
```

`@@ -222,6 +306,7 @@` — the same suppression before `forward`:

```python
            SDPBackend.MATH,
        ]

    # pyrefly: ignore [bad-override]
    def forward(
        self,
        q: torch.Tensor,
```

`@@ -565,7 +650,7 @@` — the varlen call site now passes `attention_masks` by keyword:

```diff
             case "varlen":
                 assert isinstance(attention_masks, VarlenMetadata), attention_masks
                 output = self.inner_attention(
-                    xq, xk, xv, attention_masks, **scale_kwargs
+                    xq, xk, xv, attention_masks=attention_masks, **scale_kwargs
                 )
             case "sdpa":
                 assert attention_masks is None
```
Discussion on `skip_fwd_side_effects_in_bwd_under_checkpoint`:

- **Reviewer:** Where do we have these side effects exactly? Is it not a concern? Does this change numerics between eager and compile (`aot_eager` backend) when AC is enabled?
- **Author:** yes, this will make eager and compile converge, as noted in the comment. from @anijain2305
- **Reviewer:** This is saying "does not need"; my question is "do we have" any side effect in eager. I mean, if we don't have any side effect, why do we need to toggle this field in the first place?
- **Author:** @anijain2305 might answer better here. I believe we do have side effects in eager (but I don't know which exactly), and we need to skip them when compile is turned on. If we don't have this flag, we would face a compile error.
- **Contributor:** If the Python mutation is safe to ignore (most of the time), then it is okay to set this flag, as in the backward we don't need to replay the mutation, we just need to replay the forward ops. It looks like the mutation is from `ScaledDotProductAttentionWrapper`, which inherits `LocalMapWrapper`, and `LocalMapWrapper.__call__` mutates `self._local_map_fn` — I think that's where the mutation is. It is safe to ignore as long as we replay the correct forward ops. #2621 works without setting this flag because that PR swaps the forward BEFORE forward is called. cc @xmfan, @anijain2305 can confirm.
- **Reviewer:** Claude explains to me that the mutation comes from `self._local_map_fn`. However, I am wondering, if #2621 is simpler, whether we should adopt that approach instead of inheriting from `LocalMapWrapper`.
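The mutation under discussion reduces to a small pattern: the wrapped function is built lazily on the first call and cached on `self`, so only the first forward performs a Python attribute mutation — exactly the kind of side effect the flag tells compiled AC not to replay. A hypothetical stand-alone sketch (`LazyWrapped` is illustrative, not the torchtitan class):

```python
class LazyWrapped:
    """Build the wrapped callable on first call, then reuse it."""

    def __init__(self, fn):
        self._fn = fn
        self._wrapped = None
        self.build_count = 0  # instrumented to show the mutation happens once

    def __call__(self, *args):
        if self._wrapped is None:
            # This attribute mutation is the first-call side effect:
            # analogous to LocalMapAttention caching self._local_map_fn.
            self.build_count += 1
            inner = self._fn

            def wrapped(*a):
                return inner(*a)

            self._wrapped = wrapped
        return self._wrapped(*args)


f = LazyWrapped(lambda a, b: a + b)
f(1, 2)
f(3, 4)
# build_count stays at 1: the wrapper is constructed once and reused
```

Skipping this mutation during the AC recompute is harmless because the cached callable already exists by the time backward runs; only the forward ops need replaying.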