|
11 | 11 | import torch |
12 | 12 | import torch.nn.functional as F |
13 | 13 | from torch import nn |
14 | | -from torch.distributed.tensor import DTensor |
| 14 | +from torch.distributed.tensor import DTensor, Shard |
15 | 15 | from torch.distributed.tensor.experimental import local_map |
16 | 16 | from torch.nn.attention import sdpa_kernel, SDPBackend |
17 | 17 | from torch.nn.attention.flex_attention import ( |
@@ -99,17 +99,23 @@ def __call__( |
99 | 99 | assert isinstance(k, DTensor) and isinstance( |
100 | 100 | v, DTensor |
101 | 101 | ), "q, k, v should all be DTensors" |
| 102 | + # All placements must be Shard. We set |
| 103 | + # out_placements and in_grad_placements equal to |
| 104 | + # in_placements below. This is only valid for attention |
| 105 | + # because qkv are sharded on the head dim. CP is handled |
| 106 | + # independently by _ContextParallel hooks inside |
| 107 | + # nn.Module.__call__. |
| 108 | + for tensor, name in ((q, "q"), (k, "k"), (v, "v")): |
| 109 | + for p in tensor.placements: |
| 110 | + assert isinstance(p, Shard), ( |
| 111 | + f"LocalMapModule requires Shard placements, " |
| 112 | + f"but {name} has placement {p}" |
| 113 | + ) |
102 | 114 | if self._local_map_fn is None: |
103 | 115 | self._local_map_fn = local_map( |
104 | 116 | super().__call__, |
105 | 117 | in_placements=(q.placements, k.placements, v.placements), |
106 | 118 | out_placements=(q.placements,), |
107 | | - # For TP (Shard on heads dim), in_grad_placements always |
108 | | - # matches in_placements since each rank owns a distinct |
109 | | - # shard of heads and grads stay on the same shard. |
110 | | - # CP grad placements are not a concern here because local_map |
111 | | - # only operates on the TP mesh. CP is handled independently |
112 | | - # by _ContextParallel hooks inside nn.Module.__call__. |
113 | 119 | in_grad_placements=(q.placements, k.placements, v.placements), |
114 | 120 | device_mesh=q.device_mesh, |
115 | 121 | ) |
|
0 commit comments