Skip to content

Commit 3b1cde9

Browse files
committed
update comment
1 parent 48e1daf commit 3b1cde9

File tree

1 file changed

+13
-7
lines changed

1 file changed

+13
-7
lines changed

torchtitan/models/common/attention.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
import torch
1212
import torch.nn.functional as F
1313
from torch import nn
14-
from torch.distributed.tensor import DTensor
14+
from torch.distributed.tensor import DTensor, Shard
1515
from torch.distributed.tensor.experimental import local_map
1616
from torch.nn.attention import sdpa_kernel, SDPBackend
1717
from torch.nn.attention.flex_attention import (
@@ -99,17 +99,23 @@ def __call__(
9999
assert isinstance(k, DTensor) and isinstance(
100100
v, DTensor
101101
), "q, k, v should all be DTensors"
102+
# All placements must be Shard. We set
103+
# out_placements and in_grad_placements equal to
104+
# in_placements below. This is only valid for attention
105+
# as qkv are sharded on head dim. CP is handled
106+
# independently by _ContextParallel hooks inside
107+
# nn.Module.__call__.
108+
for tensor, name in ((q, "q"), (k, "k"), (v, "v")):
109+
for p in tensor.placements:
110+
assert isinstance(p, Shard), (
111+
f"LocalMapModule requires Shard placements, "
112+
f"but {name} has placement {p}"
113+
)
102114
if self._local_map_fn is None:
103115
self._local_map_fn = local_map(
104116
super().__call__,
105117
in_placements=(q.placements, k.placements, v.placements),
106118
out_placements=(q.placements,),
107-
# For TP (Shard on heads dim), in_grad_placements always
108-
# matches in_placements since each rank owns a distinct
109-
# shard of heads and grads stay on the same shard.
110-
# CP grad placements are not a concern here because local_map
111-
# only operates on the TP mesh. CP is handled independently
112-
# by _ContextParallel hooks inside nn.Module.__call__.
113119
in_grad_placements=(q.placements, k.placements, v.placements),
114120
device_mesh=q.device_mesh,
115121
)

0 commit comments

Comments (0)