|
7 | 7 | # This file contains utility functions to parallelize models with full dtensor. |
8 | 8 | # We will eventually replace the existing functions with these or merge them. |
9 | 9 |
|
| 10 | +from collections.abc import Callable |
10 | 11 | from typing import Any |
11 | 12 |
|
12 | 13 | import torch |
13 | 14 | import torch.nn as nn |
14 | 15 | from torch.distributed.device_mesh import DeviceMesh |
15 | 16 | from torch.distributed.fsdp import DataParallelMeshDims |
16 | 17 | from torch.distributed.tensor import distribute_module, DTensor |
| 18 | + |
| 19 | +from torch.distributed.tensor.experimental import local_map |
17 | 20 | from torch.distributed.tensor.placement_types import Placement, Replicate, Shard |
18 | 21 |
|
19 | 22 | from torchtitan.distributed.parallel_dims import ParallelDims |
@@ -52,11 +55,11 @@ def validate_config(parallel_dims: ParallelDims, model_config: Any) -> None: |
52 | 55 | layer = getattr(model_config, "layer", None) |
53 | 56 | attn_config = getattr(layer, "attention", None) if layer else None |
54 | 57 | attn_backend = getattr(attn_config, "attn_backend", "sdpa") |
55 | | - if attn_backend in ("flex", "varlen"): |
| 58 | + if attn_backend == "varlen": |
56 | 59 | raise NotImplementedError( |
57 | 60 | f"full_dtensor is not supported with {attn_backend} attention. " |
58 | | - "Flex/varlen attention does not support DTensor dispatch. " |
59 | | - "Use sdpa attention or disable full_dtensor." |
| 61 | + "Varlen attention does not support DTensor dispatch. " |
| 62 | + "Use sdpa or flex attention, or disable full_dtensor." |
60 | 63 | ) |
61 | 64 |
|
62 | 65 |
|
@@ -107,6 +110,73 @@ def _remove_sdpa_math_backend(model: nn.Module) -> None: |
107 | 110 | ] |
108 | 111 |
|
109 | 112 |
|
def _wrap_inner_attention_with_local_map(model: nn.Module) -> None:
    """Wrap inner attention modules' forward with local_map.

    In full DTensor mode, attention inputs (q, k, v) are DTensors. Flex
    attention mask functions (e.g. document_mask) operate on plain tensors
    and cannot handle DTensor indices. Wrapping forward with local_map
    converts DTensor inputs to local tensors before forward runs, and wraps
    outputs back to DTensors afterward.

    Placements and device mesh are inferred from the input DTensors at
    runtime, following the sixlib pattern of wrapping self.forward at setup
    time. The local_map wrapper is lazily created on first call since FSDP
    may change placements after distribute_module.

    Args:
        model: Model whose inner attention modules are wrapped in place.
    """
    from torchtitan.models.common.attention import (
        FlexAttentionWrapper,
        ScaledDotProductAttentionWrapper,
        VarlenAttentionWrapper,
    )

    inner_attn_types = (
        FlexAttentionWrapper,
        ScaledDotProductAttentionWrapper,
        VarlenAttentionWrapper,
    )

    def make_wrapped_forward(orig_fn: Callable) -> Callable:
        # Per-module cache of local_map-wrapped functions, keyed by the
        # (q, k, v) placements so the wrapper is not rebuilt every call.
        # NOTE: this dict must live inside the factory so each wrapped
        # module gets its own cache. A dict created in the enclosing loop
        # would be shared by ALL wrapped modules via late-binding closure;
        # because the key does not include orig_fn, one module could then
        # invoke another module's cached forward when placements match.
        cached_fn: dict[tuple, Callable] = {}

        def wrapped_forward(
            q: torch.Tensor,
            k: torch.Tensor,
            v: torch.Tensor,
            **kwargs,
        ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
            # Plain-tensor fast path: nothing to unwrap.
            if not isinstance(q, DTensor):
                return orig_fn(q, k, v, **kwargs)

            assert isinstance(k, DTensor) and isinstance(v, DTensor)

            # Infer placements from the input DTensors at call time —
            # FSDP may have changed them after distribute_module.
            cache_key = (q.placements, k.placements, v.placements)
            if cache_key not in cached_fn:
                cached_fn[cache_key] = local_map(
                    orig_fn,
                    out_placements=(q.placements,),
                    in_placements=(
                        q.placements,
                        k.placements,
                        v.placements,
                    ),
                    device_mesh=q.device_mesh,
                )
            return cached_fn[cache_key](q, k, v, **kwargs)

        return wrapped_forward

    for name, module in model.named_modules():
        if isinstance(module, inner_attn_types):
            module.forward = make_wrapped_forward(module.forward)
            logger.debug(f"Wrapped {name} forward with local_map")
| 178 | + |
| 179 | + |
110 | 180 | def _find_tied_parameters( |
111 | 181 | model: nn.Module, |
112 | 182 | ) -> list[list[tuple[nn.Module, str]]]: |
@@ -197,6 +267,11 @@ def distribute_model( |
197 | 267 | f"after distribute_module" |
198 | 268 | ) |
199 | 269 |
|
| 270 | + # Wrap inner attention modules' forward with local_map so that |
| 271 | + # DTensor inputs are converted to local tensors before attention |
| 272 | + # computation (e.g. flex attention mask functions expect plain tensors). |
| 273 | + _wrap_inner_attention_with_local_map(model) |
| 274 | + |
200 | 275 | logger.info( |
201 | 276 | f"Distributed model parameters as DTensors on SPMD mesh " |
202 | 277 | f"with dims {spmd_mesh.mesh_dim_names}" |
|
0 commit comments