Skip to content

Commit 13c1d03

Browse files
committed
fix: wrap positions as DTensor in RoPE and warn on missing position_buffer
When TP uses use_local_output=False (DeepSeek V3, Qwen3, GPT-OSS), freqs_cis becomes a DTensor(Replicate) but positions remains a plain tensor. torch.gather requires both operands to be the same type, causing a runtime error. Fix by wrapping positions via DTensor.from_local() at the apply_rotary_emb public API boundary. Also add a logger.warning when loading a checkpoint that is missing the position_buffer key in the dataset state dict, to help users debug incorrect RoPE positions when resuming from older checkpoints.
1 parent 4e6888b commit 13c1d03

File tree

2 files changed

+41
-0
lines changed

2 files changed

+41
-0
lines changed

torchtitan/hf_datasets/text_datasets.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,14 @@ def __iter__(self):
158158

159159
def load_state_dict(self, state_dict):
160160
self._token_buffer = state_dict["token_buffer"]
161+
if "position_buffer" not in state_dict:
162+
logger.warning(
163+
"Checkpoint missing 'position_buffer' key in dataset state. "
164+
"Falling back to empty position buffer. This is expected when "
165+
"resuming from a checkpoint saved before position tracking was "
166+
"added, but may cause incorrect RoPE positions with "
167+
"block_causal attention (document packing)."
168+
)
161169
self._position_buffer = state_dict.get("position_buffer", [])
162170

163171
if isinstance(self._data, Dataset):

torchtitan/models/common/rope.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from typing import Literal
1010

1111
import torch
12+
from torch.distributed.tensor import DTensor, Replicate
1213

1314
from torchtitan.protocols.module import Module
1415

@@ -289,6 +290,35 @@ def _rotate_half(x: torch.Tensor) -> torch.Tensor:
289290
return torch.cat((-x2, x1), dim=-1)
290291

291292

293+
def _maybe_wrap_positions(
294+
positions: torch.Tensor | None,
295+
freqs_cis: torch.Tensor,
296+
) -> torch.Tensor | None:
297+
"""Wrap positions as a DTensor if freqs_cis is a DTensor.
298+
299+
When TP uses use_local_output=False (DeepSeek V3, Qwen3, GPT-OSS),
300+
freqs_cis is a DTensor (Replicate) but positions is a plain tensor.
301+
The downstream torch.gather requires both operands to be the same type.
302+
Since positions (int64 indices) has no gradient, grad_placements is
303+
not needed.
304+
"""
305+
if (
306+
positions is not None
307+
and isinstance(freqs_cis, DTensor)
308+
and not isinstance(positions, DTensor)
309+
):
310+
assert all(
311+
isinstance(p, Replicate) for p in freqs_cis.placements
312+
), f"Expected Replicate placements on freqs_cis, got {freqs_cis.placements}"
313+
positions = DTensor.from_local(
314+
positions,
315+
freqs_cis.device_mesh,
316+
freqs_cis.placements,
317+
run_check=False,
318+
)
319+
return positions
320+
321+
292322
# TODO: consolidate apply_rotary_emb_complex and apply_rotary_emb_single_complex
293323
def apply_rotary_emb_complex(
294324
xq: torch.Tensor,
@@ -304,6 +334,7 @@ def apply_rotary_emb_complex(
304334
freqs_cis: (max_seqlen, head_dim // 2) complex
305335
positions: optional position indices
306336
"""
337+
positions = _maybe_wrap_positions(positions, freqs_cis)
307338
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
308339
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
309340
freqs_cis = _reshape_for_broadcast_complex(freqs_cis, xq_, positions)
@@ -324,6 +355,7 @@ def apply_rotary_emb_single_complex(
324355
freqs_cis: (max_seqlen, head_dim // 2) complex
325356
positions: optional position indices
326357
"""
358+
positions = _maybe_wrap_positions(positions, freqs_cis)
327359
dtype = x.dtype
328360
x = torch.view_as_complex(x.float().view(*x.shape[:-1], -1, 2))
329361
freqs_cis = _reshape_for_broadcast_complex(freqs_cis, x, positions)
@@ -345,6 +377,7 @@ def apply_rotary_emb_cos_sin(
345377
rope_cache: (max_seqlen, head_dim * 2) with cos and sin concatenated
346378
positions: optional position indices
347379
"""
380+
positions = _maybe_wrap_positions(positions, rope_cache)
348381
head_dim = xq.shape[-1]
349382
rope_cache = _reshape_for_broadcast_cos_sin(rope_cache, xq, positions)
350383
cos = rope_cache[..., :head_dim].to(device=xq.device)

0 commit comments

Comments
 (0)