Skip to content
4 changes: 4 additions & 0 deletions tests/unit_tests/test_dataset_checkpointing.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,10 @@ def test_c4_resumption(self):
assert torch.equal(
input_ids["input"], expected_input_ids["input"]
)
assert torch.equal(
input_ids["positions"],
expected_input_ids["positions"],
)
assert torch.equal(labels, expected_labels)

def _build_dataloader(self, dataset_name, batch_size, seq_len, world_size, rank):
Expand Down
11 changes: 11 additions & 0 deletions torchtitan/components/validate.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,17 @@ def post_dataloading_process(
# extra_kwargs are.
extra_kwargs: dict[str, Any] = {}

# TODO: deduplicate with Trainer.post_dataloading_process which has
# the same logic; extract a shared function to prevent further drift.
# For causal attention the whole packed sequence is one document,
# so sequential RoPE positions (positions=None) are correct.
model_config = getattr(model_parts[0], "config", None)
layer = getattr(model_config, "layer", None)
attn_config = getattr(layer, "attention", None) if layer else None
attn_mask_type = getattr(attn_config, "attn_mask_type", "causal")
if attn_mask_type != "block_causal":
extra_inputs.pop("positions", None)

try:
# pyrefly: ignore [not-callable]
extra_kwargs["attention_masks"] = cast(
Expand Down
31 changes: 28 additions & 3 deletions torchtitan/hf_datasets/text_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ def __init__(
# Variables for checkpointing
self._sample_idx = 0
self._token_buffer: list[int] = []
self._position_buffer: list[int] = []

def _get_data_iter(self):
# For map-style datasets, resume by skipping to the correct index
Expand All @@ -119,15 +120,27 @@ def __iter__(self):
sample_text, add_bos=True, add_eos=True
)
self._token_buffer.extend(sample_tokens)
# Per-document positions reset at document boundaries,
# matching inference frameworks (e.g. vLLM) that start
# positions at 0 per request. Positions wrap at seq_len
# to stay within the RoPE cache, effectively chunking
# long documents into seq_len-sized segments.
# TODO: make overflow policy configurable (chunk / truncate / drop).
self._position_buffer.extend(
i % self.seq_len for i in range(len(sample_tokens))
)
self._sample_idx += 1

while len(self._token_buffer) >= max_buffer_token_len:
x = torch.LongTensor(self._token_buffer[:max_buffer_token_len])
# update tokens to the remaining tokens
pos = torch.LongTensor(self._position_buffer[:max_buffer_token_len])
# update buffers to the remaining tokens
self._token_buffer = self._token_buffer[max_buffer_token_len:]
self._position_buffer = self._position_buffer[max_buffer_token_len:]
input = x[:-1]
label = x[1:]
yield {"input": input}, label
positions = pos[:-1]
yield {"input": input, "positions": positions}, label

if not self.infinite:
logger.warning(f"Dataset {self.dataset_name} has run out of data")
Expand All @@ -145,6 +158,15 @@ def __iter__(self):

def load_state_dict(self, state_dict):
self._token_buffer = state_dict["token_buffer"]
if "position_buffer" not in state_dict:
logger.warning(
"Checkpoint missing 'position_buffer' key in dataset state. "
"Falling back to empty position buffer. This is expected when "
"resuming from a checkpoint saved before position tracking was "
"added, but may cause incorrect RoPE positions with "
"block_causal attention (document packing)."
)
self._position_buffer = state_dict.get("position_buffer", [])

if isinstance(self._data, Dataset):
self._sample_idx = state_dict["sample_idx"]
Expand All @@ -153,7 +175,10 @@ def load_state_dict(self, state_dict):
self._data.load_state_dict(state_dict["data"])

def state_dict(self):
_state_dict: dict[str, Any] = {"token_buffer": self._token_buffer}
_state_dict: dict[str, Any] = {
"token_buffer": self._token_buffer,
"position_buffer": self._position_buffer,
}

if isinstance(self._data, Dataset):
_state_dict["sample_idx"] = self._sample_idx
Expand Down
33 changes: 33 additions & 0 deletions torchtitan/models/common/rope.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from typing import Literal

import torch
from torch.distributed.tensor import DTensor, Replicate

from torchtitan.protocols.module import Module

Expand Down Expand Up @@ -289,6 +290,35 @@ def _rotate_half(x: torch.Tensor) -> torch.Tensor:
return torch.cat((-x2, x1), dim=-1)


def _maybe_wrap_positions(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe

Suggested change
def _maybe_wrap_positions(
def _maybe_to_dtensor(

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sg, please leave a TODO: in full DTensor rewrite, we should make positions a DTensor in/right after dataloading, together with inputs and labels.

cc @fegin

positions: torch.Tensor | None,
freqs_cis: torch.Tensor,
) -> torch.Tensor | None:
"""Wrap positions as a DTensor if freqs_cis is a DTensor.

When TP uses use_local_output=False (DeepSeek V3, Qwen3, GPT-OSS),
freqs_cis is a DTensor (Replicate) but positions is a plain tensor.
The downstream torch.gather requires both operands to be the same type.
Since positions (int64 indices) has no gradient, grad_placements is
not needed.
"""
if (
positions is not None
and isinstance(freqs_cis, DTensor)
and not isinstance(positions, DTensor)
):
assert all(
isinstance(p, Replicate) for p in freqs_cis.placements
), f"Expected Replicate placements on freqs_cis, got {freqs_cis.placements}"
positions = DTensor.from_local(
positions,
freqs_cis.device_mesh,
freqs_cis.placements,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Usually positions should have the same placements as x, rather than freqs_cis. We are not wrapping tensors on CP dimension, but if we do, x and positions will be sharded on sequence dim on CP, whereas freqs_cis will be Replicate on CP.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I will change to borrow placements from x

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sorry, let me clarify:
x has more dimensions than positions. The placement should match on the dimensions they share (batch, sequence). But if x is sharded on extra dimensions (e.g. in TP x would be sharded on head_dim namely with placement Shard(3)) then the corresponding placement on positions should be Replicate.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see - in that case, I'll leave the function name more specific since it is truly tied to both x's state and whether or not we have positions

run_check=False,
)
return positions


# TODO: consolidate apply_rotary_emb_complex and apply_rotary_emb_single_complex
def apply_rotary_emb_complex(
xq: torch.Tensor,
Expand All @@ -304,6 +334,7 @@ def apply_rotary_emb_complex(
freqs_cis: (max_seqlen, head_dim // 2) complex
positions: optional position indices
"""
positions = _maybe_wrap_positions(positions, freqs_cis)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Public API boundary seemed like the proper place to do this wrapping, but lmk if you'd prefer this somewhere else @tianyu-l

xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
freqs_cis = _reshape_for_broadcast_complex(freqs_cis, xq_, positions)
Expand All @@ -324,6 +355,7 @@ def apply_rotary_emb_single_complex(
freqs_cis: (max_seqlen, head_dim // 2) complex
positions: optional position indices
"""
positions = _maybe_wrap_positions(positions, freqs_cis)
dtype = x.dtype
x = torch.view_as_complex(x.float().view(*x.shape[:-1], -1, 2))
freqs_cis = _reshape_for_broadcast_complex(freqs_cis, x, positions)
Expand All @@ -345,6 +377,7 @@ def apply_rotary_emb_cos_sin(
rope_cache: (max_seqlen, head_dim * 2) with cos and sin concatenated
positions: optional position indices
"""
positions = _maybe_wrap_positions(positions, rope_cache)
head_dim = xq.shape[-1]
rope_cache = _reshape_for_broadcast_cos_sin(rope_cache, xq, positions)
cos = rope_cache[..., :head_dim].to(device=xq.device)
Expand Down
7 changes: 6 additions & 1 deletion torchtitan/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -591,9 +591,14 @@ def post_dataloading_process(
# extra_kwargs are.
extra_kwargs: dict[str, Any] = {}

# TODO: improve the logic on obtaining attention masks
# For causal attention the whole packed sequence is one document,
# so sequential RoPE positions (positions=None) are correct.
layer = getattr(self.model_config, "layer", None)
attn_config = getattr(layer, "attention", None) if layer else None
attn_mask_type = getattr(attn_config, "attn_mask_type", "causal")
if attn_mask_type != "block_causal":
extra_inputs.pop("positions", None)

attn_backend = getattr(attn_config, "attn_backend", "sdpa")
if attn_backend in ["flex", "varlen"]:
assert (
Expand Down
Loading