Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 78 additions & 3 deletions torchtitan/distributed/full_dtensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,16 @@
# This file contains utility functions to parallelize models with full dtensor.
# We will eventually replace the existing functions with these or merge them.

from collections.abc import Callable
from typing import Any

import torch
import torch.nn as nn
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.fsdp import DataParallelMeshDims
from torch.distributed.tensor import distribute_module, DTensor

from torch.distributed.tensor.experimental import local_map
from torch.distributed.tensor.placement_types import Placement, Replicate, Shard

from torchtitan.distributed.parallel_dims import ParallelDims
Expand Down Expand Up @@ -52,11 +55,11 @@ def validate_config(parallel_dims: ParallelDims, model_config: Any) -> None:
layer = getattr(model_config, "layer", None)
attn_config = getattr(layer, "attention", None) if layer else None
attn_backend = getattr(attn_config, "attn_backend", "sdpa")
if attn_backend in ("flex", "varlen"):
if attn_backend == "varlen":
raise NotImplementedError(
f"full_dtensor is not supported with {attn_backend} attention. "
"Flex/varlen attention does not support DTensor dispatch. "
"Use sdpa attention or disable full_dtensor."
"Varlen attention does not support DTensor dispatch. "
"Use sdpa or flex attention, or disable full_dtensor."
)


Expand Down Expand Up @@ -107,6 +110,73 @@ def _remove_sdpa_math_backend(model: nn.Module) -> None:
]


def _wrap_inner_attention_with_local_map(model: nn.Module) -> None:
    """Wrap inner attention modules' forward with local_map.

    In full DTensor mode, attention inputs (q, k, v) are DTensors. Flex
    attention mask functions (e.g. document_mask) operate on plain tensors
    and cannot handle DTensor indices. Wrapping forward with local_map
    converts DTensor inputs to local tensors before forward runs, and wraps
    outputs back to DTensors afterward.

    Placements and device mesh are inferred from the input DTensors at
    runtime, following the sixlib pattern of wrapping self.forward at setup
    time. The local_map wrapper is lazily created on first call since FSDP
    may change placements after distribute_module.

    Args:
        model: The model whose inner attention modules will have their
            ``forward`` replaced in place with a local_map-aware wrapper.
    """
    from torchtitan.models.common.attention import (
        FlexAttentionWrapper,
        ScaledDotProductAttentionWrapper,
        VarlenAttentionWrapper,
    )

    inner_attn_types = (
        FlexAttentionWrapper,
        ScaledDotProductAttentionWrapper,
        VarlenAttentionWrapper,
    )

    def make_wrapped_forward(orig_fn: Callable) -> Callable:
        # Cache of local_map-wrapped functions, keyed by (placements, mesh).
        # NOTE: this dict MUST live inside the factory. Defining it in the
        # module loop and closing over it would late-bind the name, so every
        # wrapper would share the last module's cache — and since the key
        # does not identify orig_fn, a placements collision between two
        # modules would dispatch to the wrong module's forward.
        cached_fn: dict[tuple, Callable] = {}

        def wrapped_forward(
            q: torch.Tensor,
            k: torch.Tensor,
            v: torch.Tensor,
            **kwargs,
        ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
            # Plain-tensor fast path: nothing to unwrap, call through.
            if not isinstance(q, DTensor):
                return orig_fn(q, k, v, **kwargs)

            assert isinstance(k, DTensor) and isinstance(v, DTensor)

            # Infer placements from the input DTensors. Key on the mesh as
            # well so identical placements on a different mesh never reuse
            # a wrapper bound to the wrong device mesh.
            cache_key = (q.placements, k.placements, v.placements, q.device_mesh)
            if cache_key not in cached_fn:
                cached_fn[cache_key] = local_map(
                    orig_fn,
                    out_placements=(q.placements,),
                    in_placements=(
                        q.placements,
                        k.placements,
                        v.placements,
                    ),
                    device_mesh=q.device_mesh,
                )
            return cached_fn[cache_key](q, k, v, **kwargs)

        return wrapped_forward

    for name, module in model.named_modules():
        if isinstance(module, inner_attn_types):
            module.forward = make_wrapped_forward(module.forward)
            logger.debug(f"Wrapped {name} forward with local_map")

def _find_tied_parameters(
model: nn.Module,
) -> list[list[tuple[nn.Module, str]]]:
Expand Down Expand Up @@ -197,6 +267,11 @@ def distribute_model(
f"after distribute_module"
)

# Wrap inner attention modules' forward with local_map so that
# DTensor inputs are converted to local tensors before attention
# computation (e.g. flex attention mask functions expect plain tensors).
_wrap_inner_attention_with_local_map(model)

logger.info(
f"Distributed model parameters as DTensors on SPMD mesh "
f"with dims {spmd_mesh.mesh_dim_names}"
Expand Down
Loading