Skip to content

Commit d398702

Browse files
committed
fix: derive positions DTensor placements from x instead of freqs_cis
Positions are per-token like the input activations, so they should share placements with x (xq/xk) rather than freqs_cis. This is forward-compatible with CP where positions would be Shard(seq) like x, while freqs_cis remains Replicate. Since positions (bsz, seqlen) has fewer dims than x (bsz, seqlen, n_heads, head_dim), Shard placements beyond positions' rank (e.g. Shard(2) for TP on heads) are demoted to Replicate.
1 parent 3e23b47 commit d398702

File tree

1 file changed

+28
-21
lines changed

1 file changed

+28
-21
lines changed

torchtitan/models/common/rope.py

Lines changed: 28 additions & 21 deletions
Original file line number · Diff line number · Diff line change
@@ -9,7 +9,7 @@
99
from typing import Literal
1010

1111
import torch
12-
from torch.distributed.tensor import DTensor
12+
from torch.distributed.tensor import DTensor, Replicate, Shard
1313

1414
from torchtitan.protocols.module import Module
1515

@@ -290,31 +290,38 @@ def _rotate_half(x: torch.Tensor) -> torch.Tensor:
290290
return torch.cat((-x2, x1), dim=-1)
291291

292292

293-
def _maybe_to_dtensor(
294-
tensor: torch.Tensor | None,
295-
like: torch.Tensor,
293+
def _maybe_wrap_positions(
294+
positions: torch.Tensor | None,
295+
x: torch.Tensor,
296296
) -> torch.Tensor | None:
297-
"""Convert tensor to a DTensor matching like's mesh and placements.
297+
"""Wrap positions as a DTensor deriving mesh and placements from x (xq/xk).
298298
299299
When TP uses use_local_output=False (DeepSeek V3, Qwen3, GPT-OSS),
300-
like (xq/xk) is a DTensor but tensor (positions) is a plain tensor.
301-
The downstream torch.gather requires both operands to be the same type.
302-
Positions should share placements with the input activations rather than
303-
freqs_cis, because both are per-token and would be sharded the same way
304-
under CP.
300+
x is a DTensor but positions is a plain tensor. The downstream
301+
torch.gather requires both operands to be the same type.
302+
303+
Positions (bsz, seqlen) has fewer dimensions than x (bsz, seqlen,
304+
n_heads, head_dim), so we only preserve Shard placements for shared
305+
dimensions. Shard dims beyond positions' rank (e.g. Shard(2) for TP
306+
on heads) become Replicate.
305307
"""
306308
if (
307-
tensor is not None
308-
and isinstance(like, DTensor)
309-
and not isinstance(tensor, DTensor)
309+
positions is not None
310+
and isinstance(x, DTensor)
311+
and not isinstance(positions, DTensor)
310312
):
311-
tensor = DTensor.from_local(
312-
tensor,
313-
like.device_mesh,
314-
like.placements,
313+
ndim = positions.ndim
314+
placements = tuple(
315+
p if not isinstance(p, Shard) or p.dim < ndim else Replicate()
316+
for p in x.placements
317+
)
318+
positions = DTensor.from_local(
319+
positions,
320+
x.device_mesh,
321+
placements,
315322
run_check=False,
316323
)
317-
return tensor
324+
return positions
318325

319326

320327
# TODO: consolidate apply_rotary_emb_complex and apply_rotary_emb_single_complex
@@ -332,7 +339,7 @@ def apply_rotary_emb_complex(
332339
freqs_cis: (max_seqlen, head_dim // 2) complex
333340
positions: optional position indices
334341
"""
335-
positions = _maybe_to_dtensor(positions, xq)
342+
positions = _maybe_wrap_positions(positions, xq)
336343
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
337344
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
338345
freqs_cis = _reshape_for_broadcast_complex(freqs_cis, xq_, positions)
@@ -353,7 +360,7 @@ def apply_rotary_emb_single_complex(
353360
freqs_cis: (max_seqlen, head_dim // 2) complex
354361
positions: optional position indices
355362
"""
356-
positions = _maybe_to_dtensor(positions, x)
363+
positions = _maybe_wrap_positions(positions, x)
357364
dtype = x.dtype
358365
x = torch.view_as_complex(x.float().view(*x.shape[:-1], -1, 2))
359366
freqs_cis = _reshape_for_broadcast_complex(freqs_cis, x, positions)
@@ -375,7 +382,7 @@ def apply_rotary_emb_cos_sin(
375382
rope_cache: (max_seqlen, head_dim * 2) with cos and sin concatenated
376383
positions: optional position indices
377384
"""
378-
positions = _maybe_to_dtensor(positions, xq)
385+
positions = _maybe_wrap_positions(positions, xq)
379386
head_dim = xq.shape[-1]
380387
rope_cache = _reshape_for_broadcast_cos_sin(rope_cache, xq, positions)
381388
cos = rope_cache[..., :head_dim].to(device=xq.device)

0 commit comments

Comments (0)