99from typing import Literal
1010
1111import torch
12- from torch .distributed .tensor import DTensor
12+ from torch .distributed .tensor import DTensor , Replicate , Shard
1313
1414from torchtitan .protocols .module import Module
1515
@@ -290,31 +290,38 @@ def _rotate_half(x: torch.Tensor) -> torch.Tensor:
290290 return torch .cat ((- x2 , x1 ), dim = - 1 )
291291
292292
293- def _maybe_to_dtensor (
294- tensor : torch .Tensor | None ,
295- like : torch .Tensor ,
293+ def _maybe_wrap_positions (
294+ positions : torch .Tensor | None ,
295+ x : torch .Tensor ,
296296) -> torch .Tensor | None :
297- """Convert tensor to a DTensor matching like's mesh and placements.
297+ """Wrap positions as a DTensor deriving mesh and placements from x (xq/xk) .
298298
299299 When TP uses use_local_output=False (DeepSeek V3, Qwen3, GPT-OSS),
300- like (xq/xk) is a DTensor but tensor (positions) is a plain tensor.
301- The downstream torch.gather requires both operands to be the same type.
302- Positions should share placements with the input activations rather than
303- freqs_cis, because both are per-token and would be sharded the same way
304- under CP.
300+ x is a DTensor but positions is a plain tensor. The downstream
301+ torch.gather requires both operands to be the same type.
302+
303+ Positions (bsz, seqlen) has fewer dimensions than x (bsz, seqlen,
304+ n_heads, head_dim), so we only preserve Shard placements for shared
305+ dimensions. Shard dims beyond positions' rank (e.g. Shard(2) for TP
306+ on heads) become Replicate.
305307 """
306308 if (
307- tensor is not None
308- and isinstance (like , DTensor )
309- and not isinstance (tensor , DTensor )
309+ positions is not None
310+ and isinstance (x , DTensor )
311+ and not isinstance (positions , DTensor )
310312 ):
311- tensor = DTensor .from_local (
312- tensor ,
313- like .device_mesh ,
314- like .placements ,
313+ ndim = positions .ndim
314+ placements = tuple (
315+ p if not isinstance (p , Shard ) or p .dim < ndim else Replicate ()
316+ for p in x .placements
317+ )
318+ positions = DTensor .from_local (
319+ positions ,
320+ x .device_mesh ,
321+ placements ,
315322 run_check = False ,
316323 )
317- return tensor
324+ return positions
318325
319326
320327# TODO: consolidate apply_rotary_emb_complex and apply_rotary_emb_single_complex
@@ -332,7 +339,7 @@ def apply_rotary_emb_complex(
332339 freqs_cis: (max_seqlen, head_dim // 2) complex
333340 positions: optional position indices
334341 """
335- positions = _maybe_to_dtensor (positions , xq )
342+ positions = _maybe_wrap_positions (positions , xq )
336343 xq_ = torch .view_as_complex (xq .float ().reshape (* xq .shape [:- 1 ], - 1 , 2 ))
337344 xk_ = torch .view_as_complex (xk .float ().reshape (* xk .shape [:- 1 ], - 1 , 2 ))
338345 freqs_cis = _reshape_for_broadcast_complex (freqs_cis , xq_ , positions )
@@ -353,7 +360,7 @@ def apply_rotary_emb_single_complex(
353360 freqs_cis: (max_seqlen, head_dim // 2) complex
354361 positions: optional position indices
355362 """
356- positions = _maybe_to_dtensor (positions , x )
363+ positions = _maybe_wrap_positions (positions , x )
357364 dtype = x .dtype
358365 x = torch .view_as_complex (x .float ().view (* x .shape [:- 1 ], - 1 , 2 ))
359366 freqs_cis = _reshape_for_broadcast_complex (freqs_cis , x , positions )
@@ -375,7 +382,7 @@ def apply_rotary_emb_cos_sin(
375382 rope_cache: (max_seqlen, head_dim * 2) with cos and sin concatenated
376383 positions: optional position indices
377384 """
378- positions = _maybe_to_dtensor (positions , xq )
385+ positions = _maybe_wrap_positions (positions , xq )
379386 head_dim = xq .shape [- 1 ]
380387 rope_cache = _reshape_for_broadcast_cos_sin (rope_cache , xq , positions )
381388 cos = rope_cache [..., :head_dim ].to (device = xq .device )
0 commit comments