Skip to content

Commit 4dfaa04

Browse files
committed
[DONT LAND] Full Dtensor fully_shard + Local Map + FlexAttention
ghstack-source-id: 4220989 Pull-Request: #2621
1 parent 912208f commit 4dfaa04

File tree

1 file changed

+78
-3
lines changed

1 file changed

+78
-3
lines changed

torchtitan/distributed/full_dtensor.py

Lines changed: 78 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,16 @@
77
# This file contains utility functions to parallelize models with full dtensor.
88
# We will eventually replace the existing functions with these or merge them.
99

10+
from collections.abc import Callable
1011
from typing import Any
1112

1213
import torch
1314
import torch.nn as nn
1415
from torch.distributed.device_mesh import DeviceMesh
1516
from torch.distributed.fsdp import DataParallelMeshDims
1617
from torch.distributed.tensor import distribute_module, DTensor
18+
19+
from torch.distributed.tensor.experimental import local_map
1720
from torch.distributed.tensor.placement_types import Placement, Replicate, Shard
1821

1922
from torchtitan.distributed.parallel_dims import ParallelDims
@@ -52,11 +55,11 @@ def validate_config(parallel_dims: ParallelDims, model_config: Any) -> None:
5255
layer = getattr(model_config, "layer", None)
5356
attn_config = getattr(layer, "attention", None) if layer else None
5457
attn_backend = getattr(attn_config, "attn_backend", "sdpa")
55-
if attn_backend in ("flex", "varlen"):
58+
if attn_backend == "varlen":
5659
raise NotImplementedError(
5760
f"full_dtensor is not supported with {attn_backend} attention. "
58-
"Flex/varlen attention does not support DTensor dispatch. "
59-
"Use sdpa attention or disable full_dtensor."
61+
"Varlen attention does not support DTensor dispatch. "
62+
"Use sdpa or flex attention, or disable full_dtensor."
6063
)
6164

6265

@@ -107,6 +110,73 @@ def _remove_sdpa_math_backend(model: nn.Module) -> None:
107110
]
108111

109112

113+
def _wrap_inner_attention_with_local_map(model: nn.Module) -> None:
    """Wrap inner attention modules' forward with local_map.

    In full DTensor mode, attention inputs (q, k, v) are DTensors. Flex
    attention mask functions (e.g. document_mask) operate on plain tensors
    and cannot handle DTensor indices. Wrapping forward with local_map
    converts DTensor inputs to local tensors before forward runs, and wraps
    outputs back to DTensors afterward.

    Placements and device mesh are inferred from the input DTensors at
    runtime, following the sixlib pattern of wrapping self.forward at setup
    time. The local_map wrapper is lazily created on first call since FSDP
    may change placements after distribute_module.

    Args:
        model: Root module whose inner attention submodules (flex / sdpa /
            varlen wrappers) get their ``forward`` replaced in place.

    Returns:
        None. Mutates ``module.forward`` on every matching submodule.
    """
    # Local import: attention wrapper types live in the models package and
    # importing them at module scope would create a circular dependency.
    from torchtitan.models.common.attention import (
        FlexAttentionWrapper,
        ScaledDotProductAttentionWrapper,
        VarlenAttentionWrapper,
    )

    inner_attn_types = (
        FlexAttentionWrapper,
        ScaledDotProductAttentionWrapper,
        VarlenAttentionWrapper,
    )

    def make_wrapped_forward(orig_fn: Callable) -> Callable:
        # Per-module cache of local_map-wrapped functions, keyed by the
        # (q, k, v) placements so we don't rebuild the wrapper every call.
        # NOTE: the cache MUST live inside this factory. If it were created
        # in the loop below, `wrapped_forward` would capture it as a
        # late-bound free variable and every module would share the last
        # dict created — modules with identical placements would then
        # dispatch to whichever module's orig_fn populated the entry first.
        cached_fn: dict[tuple, Callable] = {}

        def wrapped_forward(
            q: torch.Tensor,
            k: torch.Tensor,
            v: torch.Tensor,
            **kwargs,
        ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
            # Plain-tensor path (e.g. model run outside full DTensor mode):
            # nothing to unwrap, call through directly.
            if not isinstance(q, DTensor):
                return orig_fn(q, k, v, **kwargs)

            assert isinstance(k, DTensor) and isinstance(v, DTensor)

            # Infer placements from the input DTensors; lazily build the
            # local_map wrapper for this placement combination.
            cache_key = (q.placements, k.placements, v.placements)
            if cache_key not in cached_fn:
                cached_fn[cache_key] = local_map(
                    orig_fn,
                    out_placements=(q.placements,),
                    in_placements=(
                        q.placements,
                        k.placements,
                        v.placements,
                    ),
                    device_mesh=q.device_mesh,
                )
            return cached_fn[cache_key](q, k, v, **kwargs)

        return wrapped_forward

    for name, module in model.named_modules():
        if isinstance(module, inner_attn_types):
            module.forward = make_wrapped_forward(module.forward)
            logger.debug(f"Wrapped {name} forward with local_map")
110180
def _find_tied_parameters(
111181
model: nn.Module,
112182
) -> list[list[tuple[nn.Module, str]]]:
@@ -197,6 +267,11 @@ def distribute_model(
197267
f"after distribute_module"
198268
)
199269

270+
# Wrap inner attention modules' forward with local_map so that
271+
# DTensor inputs are converted to local tensors before attention
272+
# computation (e.g. flex attention mask functions expect plain tensors).
273+
_wrap_inner_attention_with_local_map(model)
274+
200275
logger.info(
201276
f"Distributed model parameters as DTensors on SPMD mesh "
202277
f"with dims {spmd_mesh.mesh_dim_names}"

0 commit comments

Comments
 (0)