diff --git a/torchtitan/distributed/full_dtensor.py b/torchtitan/distributed/full_dtensor.py
index f8103093ce..66d713509e 100644
--- a/torchtitan/distributed/full_dtensor.py
+++ b/torchtitan/distributed/full_dtensor.py
@@ -7,6 +7,7 @@
 # This file contains utility functions to parallelize models with full dtensor.
 # We will eventually replace the existing functions with these or merge them.
+from collections.abc import Callable
 from typing import Any
 
 import torch
 
@@ -14,6 +15,8 @@
 from torch.distributed.device_mesh import DeviceMesh
 from torch.distributed.fsdp import DataParallelMeshDims
 from torch.distributed.tensor import distribute_module, DTensor
+
+from torch.distributed.tensor.experimental import local_map
 from torch.distributed.tensor.placement_types import Placement, Replicate, Shard
 
 from torchtitan.distributed.parallel_dims import ParallelDims
@@ -52,11 +55,11 @@ def validate_config(parallel_dims: ParallelDims, model_config: Any) -> None:
     layer = getattr(model_config, "layer", None)
     attn_config = getattr(layer, "attention", None) if layer else None
     attn_backend = getattr(attn_config, "attn_backend", "sdpa")
-    if attn_backend in ("flex", "varlen"):
+    if attn_backend == "varlen":
         raise NotImplementedError(
             f"full_dtensor is not supported with {attn_backend} attention. "
-            "Flex/varlen attention does not support DTensor dispatch. "
-            "Use sdpa attention or disable full_dtensor."
+            "Varlen attention does not support DTensor dispatch. "
+            "Use sdpa or flex attention, or disable full_dtensor."
         )
 
 
@@ -107,6 +110,73 @@ def _remove_sdpa_math_backend(model: nn.Module) -> None:
     ]
 
 
+def _wrap_inner_attention_with_local_map(model: nn.Module) -> None:
+    """Wrap inner attention modules' forward with local_map.
+
+    In full DTensor mode, attention inputs (q, k, v) are DTensors. Flex
+    attention mask functions (e.g. document_mask) operate on plain tensors
+    and cannot handle DTensor indices. Wrapping forward with local_map
+    converts DTensor inputs to local tensors before forward runs, and wraps
+    outputs back to DTensors afterward.
+
+    Placements and device mesh are inferred from the input DTensors at
+    runtime, following the sixlib pattern of wrapping self.forward at setup
+    time. The local_map wrapper is lazily created on first call since FSDP
+    may change placements after distribute_module.
+    """
+    from torchtitan.models.common.attention import (
+        FlexAttentionWrapper,
+        ScaledDotProductAttentionWrapper,
+        VarlenAttentionWrapper,
+    )
+
+    inner_attn_types = (
+        FlexAttentionWrapper,
+        ScaledDotProductAttentionWrapper,
+        VarlenAttentionWrapper,
+    )
+
+    for name, module in model.named_modules():
+        if isinstance(module, inner_attn_types):
+            original_forward = module.forward
+
+            def make_wrapped_forward(orig_fn: Callable) -> Callable:
+                # Per-wrapper cache of local_map-wrapped functions, keyed by
+                # placements. Created here (not in the loop body) so each
+                # closure binds its own dict; a loop-level dict would be
+                # late-bound and shared across all wrapped modules, letting
+                # one module's cached orig_fn serve another's calls.
+                cached_fn: dict[tuple, Callable] = {}
+
+                def wrapped_forward(
+                    q: torch.Tensor,
+                    k: torch.Tensor,
+                    v: torch.Tensor,
+                    **kwargs,
+                ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
+                    if not isinstance(q, DTensor):
+                        return orig_fn(q, k, v, **kwargs)
+
+                    assert isinstance(k, DTensor) and isinstance(v, DTensor)
+
+                    # Infer placements from inputs; NOTE(review): assumes a
+                    # single-tensor return — confirm for wrappers that may
+                    # return (out, lse) tuples.
+                    cache_key = (q.placements, k.placements, v.placements)
+                    if cache_key not in cached_fn:
+                        cached_fn[cache_key] = local_map(
+                            orig_fn,
+                            out_placements=(q.placements,),
+                            in_placements=(
+                                q.placements,
+                                k.placements,
+                                v.placements,
+                            ),
+                            device_mesh=q.device_mesh,
+                        )
+                    return cached_fn[cache_key](q, k, v, **kwargs)
+
+                return wrapped_forward
+
+            module.forward = make_wrapped_forward(original_forward)
+            logger.debug(f"Wrapped {name} forward with local_map")
+
+
 def _find_tied_parameters(
     model: nn.Module,
 ) -> list[list[tuple[nn.Module, str]]]:
@@ -197,6 +267,11 @@ def distribute_model(
         f"after distribute_module"
     )
 
+    # Wrap inner attention modules' forward with local_map so that
+    # DTensor inputs are converted to local tensors before attention
+    # computation (e.g. flex attention mask functions expect plain tensors).
+    _wrap_inner_attention_with_local_map(model)
+
     logger.info(
         f"Distributed model parameters as DTensors on SPMD mesh "
         f"with dims {spmd_mesh.mesh_dim_names}"