Skip to content
4 changes: 4 additions & 0 deletions tests/unit_tests/test_dataset_checkpointing.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,10 @@ def test_c4_resumption(self):
assert torch.equal(
input_ids["input"], expected_input_ids["input"]
)
assert torch.equal(
input_ids["positions"],
expected_input_ids["positions"],
)
assert torch.equal(labels, expected_labels)

def _build_dataloader(self, dataset_name, batch_size, seq_len, world_size, rank):
Expand Down
11 changes: 11 additions & 0 deletions torchtitan/components/validate.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,17 @@ def post_dataloading_process(
# extra_kwargs are.
extra_kwargs: dict[str, Any] = {}

# TODO: deduplicate with Trainer.post_dataloading_process which has
# the same logic; extract a shared function to prevent further drift.
# For causal attention the whole packed sequence is one document,
# so sequential RoPE positions (positions=None) are correct.
model_config = getattr(model_parts[0], "config", None)
layer = getattr(model_config, "layer", None)
attn_config = getattr(layer, "attention", None) if layer else None
attn_mask_type = getattr(attn_config, "attn_mask_type", "causal")
if attn_mask_type != "block_causal":
extra_inputs.pop("positions", None)

try:
# pyrefly: ignore [not-callable]
extra_kwargs["attention_masks"] = cast(
Expand Down
31 changes: 28 additions & 3 deletions torchtitan/hf_datasets/text_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ def __init__(
# Variables for checkpointing
self._sample_idx = 0
self._token_buffer: list[int] = []
self._position_buffer: list[int] = []

def _get_data_iter(self):
# For map-style datasets, resume by skipping to the correct index
Expand All @@ -119,15 +120,27 @@ def __iter__(self):
sample_text, add_bos=True, add_eos=True
)
self._token_buffer.extend(sample_tokens)
# Per-document positions reset at document boundaries,
# matching inference frameworks (e.g. vLLM) that start
# positions at 0 per request. Positions wrap at seq_len
# to stay within the RoPE cache, effectively chunking
# long documents into seq_len-sized segments.
# TODO: make overflow policy configurable (chunk / truncate / drop).
self._position_buffer.extend(
i % self.seq_len for i in range(len(sample_tokens))
)
self._sample_idx += 1

while len(self._token_buffer) >= max_buffer_token_len:
x = torch.LongTensor(self._token_buffer[:max_buffer_token_len])
# update tokens to the remaining tokens
pos = torch.LongTensor(self._position_buffer[:max_buffer_token_len])
# update buffers to the remaining tokens
self._token_buffer = self._token_buffer[max_buffer_token_len:]
self._position_buffer = self._position_buffer[max_buffer_token_len:]
input = x[:-1]
label = x[1:]
yield {"input": input}, label
positions = pos[:-1]
yield {"input": input, "positions": positions}, label

if not self.infinite:
logger.warning(f"Dataset {self.dataset_name} has run out of data")
Expand All @@ -145,6 +158,15 @@ def __iter__(self):

def load_state_dict(self, state_dict):
self._token_buffer = state_dict["token_buffer"]
if "position_buffer" not in state_dict:
logger.warning(
"Checkpoint missing 'position_buffer' key in dataset state. "
"Falling back to empty position buffer. This is expected when "
"resuming from a checkpoint saved before position tracking was "
"added, but may cause incorrect RoPE positions with "
"block_causal attention (document packing)."
)
self._position_buffer = state_dict.get("position_buffer", [])

if isinstance(self._data, Dataset):
self._sample_idx = state_dict["sample_idx"]
Expand All @@ -153,7 +175,10 @@ def load_state_dict(self, state_dict):
self._data.load_state_dict(state_dict["data"])

def state_dict(self):
_state_dict: dict[str, Any] = {"token_buffer": self._token_buffer}
_state_dict: dict[str, Any] = {
"token_buffer": self._token_buffer,
"position_buffer": self._position_buffer,
}

if isinstance(self._data, Dataset):
_state_dict["sample_idx"] = self._sample_idx
Expand Down
41 changes: 41 additions & 0 deletions torchtitan/models/common/rope.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from typing import Literal

import torch
from torch.distributed.tensor import DTensor, Replicate, Shard

from torchtitan.protocols.module import Module

Expand Down Expand Up @@ -289,6 +290,43 @@ def _rotate_half(x: torch.Tensor) -> torch.Tensor:
return torch.cat((-x2, x1), dim=-1)


def _maybe_wrap_positions(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe

Suggested change
def _maybe_wrap_positions(
def _maybe_to_dtensor(

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sg, please leave a TODO: in full DTensor rewrite, we should make positions a DTensor in/right after dataloading, together with inputs and labels.

cc @fegin

positions: torch.Tensor | None,
x: torch.Tensor,
) -> torch.Tensor | None:
"""Wrap positions as a DTensor deriving mesh and placements from x (xq/xk).

TODO: In a full DTensor rewrite, positions should be made a DTensor
in/right after dataloading, together with inputs and labels.

When TP uses use_local_output=False (DeepSeek V3, Qwen3, GPT-OSS),
x is a DTensor but positions is a plain tensor. The downstream
torch.gather requires both operands to be the same type.

Positions (bsz, seqlen) has fewer dimensions than x (bsz, seqlen,
n_heads, head_dim), so we only preserve Shard placements for shared
dimensions. Shard dims beyond positions' rank (e.g. Shard(2) for TP
on heads) become Replicate.
"""
if (
positions is not None
and isinstance(x, DTensor)
and not isinstance(positions, DTensor)
):
ndim = positions.ndim
placements = tuple(
p if not isinstance(p, Shard) or p.dim < ndim else Replicate()
for p in x.placements
)
positions = DTensor.from_local(
positions,
x.device_mesh,
placements,
run_check=False,
)
return positions


# TODO: consolidate apply_rotary_emb_complex and apply_rotary_emb_single_complex
def apply_rotary_emb_complex(
xq: torch.Tensor,
Expand All @@ -304,6 +342,7 @@ def apply_rotary_emb_complex(
freqs_cis: (max_seqlen, head_dim // 2) complex
positions: optional position indices
"""
positions = _maybe_wrap_positions(positions, xq)
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
freqs_cis = _reshape_for_broadcast_complex(freqs_cis, xq_, positions)
Expand All @@ -324,6 +363,7 @@ def apply_rotary_emb_single_complex(
freqs_cis: (max_seqlen, head_dim // 2) complex
positions: optional position indices
"""
positions = _maybe_wrap_positions(positions, x)
dtype = x.dtype
x = torch.view_as_complex(x.float().view(*x.shape[:-1], -1, 2))
freqs_cis = _reshape_for_broadcast_complex(freqs_cis, x, positions)
Expand All @@ -345,6 +385,7 @@ def apply_rotary_emb_cos_sin(
rope_cache: (max_seqlen, head_dim * 2) with cos and sin concatenated
positions: optional position indices
"""
positions = _maybe_wrap_positions(positions, xq)
head_dim = xq.shape[-1]
rope_cache = _reshape_for_broadcast_cos_sin(rope_cache, xq, positions)
cos = rope_cache[..., :head_dim].to(device=xq.device)
Expand Down
7 changes: 6 additions & 1 deletion torchtitan/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -591,9 +591,14 @@ def post_dataloading_process(
# extra_kwargs are.
extra_kwargs: dict[str, Any] = {}

# TODO: improve the logic on obtaining attention masks
# For causal attention the whole packed sequence is one document,
# so sequential RoPE positions (positions=None) are correct.
layer = getattr(self.model_config, "layer", None)
attn_config = getattr(layer, "attention", None) if layer else None
attn_mask_type = getattr(attn_config, "attn_mask_type", "causal")
if attn_mask_type != "block_causal":
extra_inputs.pop("positions", None)

attn_backend = getattr(attn_config, "attn_backend", "sdpa")
if attn_backend in ["flex", "varlen"]:
assert (
Expand Down
Loading