diff --git a/tests/unit_tests/test_dataset_checkpointing.py b/tests/unit_tests/test_dataset_checkpointing.py
index 07020a3554..737ea4e561 100644
--- a/tests/unit_tests/test_dataset_checkpointing.py
+++ b/tests/unit_tests/test_dataset_checkpointing.py
@@ -55,6 +55,10 @@ def test_c4_resumption(self):
                         assert torch.equal(
                             input_ids["input"], expected_input_ids["input"]
                         )
+                        assert torch.equal(
+                            input_ids["positions"],
+                            expected_input_ids["positions"],
+                        )
                         assert torch.equal(labels, expected_labels)
 
     def _build_dataloader(self, dataset_name, batch_size, seq_len, world_size, rank):
diff --git a/torchtitan/components/validate.py b/torchtitan/components/validate.py
index df64d83d24..47e5b55afb 100644
--- a/torchtitan/components/validate.py
+++ b/torchtitan/components/validate.py
@@ -185,6 +185,17 @@ def post_dataloading_process(
         # extra_kwargs are.
         extra_kwargs: dict[str, Any] = {}
 
+        # TODO: deduplicate with Trainer.post_dataloading_process which has
+        # the same logic; extract a shared function to prevent further drift.
+        # For causal attention the whole packed sequence is one document,
+        # so sequential RoPE positions (positions=None) are correct.
+        model_config = getattr(model_parts[0], "config", None)
+        layer = getattr(model_config, "layer", None)
+        attn_config = getattr(layer, "attention", None) if layer else None
+        attn_mask_type = getattr(attn_config, "attn_mask_type", "causal")
+        if attn_mask_type != "block_causal":
+            extra_inputs.pop("positions", None)
+
         try:
             # pyrefly: ignore [not-callable]
             extra_kwargs["attention_masks"] = cast(
diff --git a/torchtitan/hf_datasets/text_datasets.py b/torchtitan/hf_datasets/text_datasets.py
index d5545452f1..aec4f15129 100644
--- a/torchtitan/hf_datasets/text_datasets.py
+++ b/torchtitan/hf_datasets/text_datasets.py
@@ -96,6 +96,7 @@ def __init__(
         # Variables for checkpointing
         self._sample_idx = 0
         self._token_buffer: list[int] = []
+        self._position_buffer: list[int] = []
 
     def _get_data_iter(self):
         # For map-style datasets, resume by skipping to the correct index
@@ -119,15 +120,27 @@ def __iter__(self):
                     sample_text, add_bos=True, add_eos=True
                 )
                 self._token_buffer.extend(sample_tokens)
+                # Per-document positions reset at document boundaries,
+                # matching inference frameworks (e.g. vLLM) that start
+                # positions at 0 per request. Positions wrap at seq_len
+                # to stay within the RoPE cache, effectively chunking
+                # long documents into seq_len-sized segments.
+                # TODO: make overflow policy configurable (chunk / truncate / drop).
+                self._position_buffer.extend(
+                    i % self.seq_len for i in range(len(sample_tokens))
+                )
                 self._sample_idx += 1
 
                 while len(self._token_buffer) >= max_buffer_token_len:
                     x = torch.LongTensor(self._token_buffer[:max_buffer_token_len])
-                    # update tokens to the remaining tokens
+                    pos = torch.LongTensor(self._position_buffer[:max_buffer_token_len])
+                    # update buffers to the remaining tokens
                     self._token_buffer = self._token_buffer[max_buffer_token_len:]
+                    self._position_buffer = self._position_buffer[max_buffer_token_len:]
                     input = x[:-1]
                     label = x[1:]
-                    yield {"input": input}, label
+                    positions = pos[:-1]
+                    yield {"input": input, "positions": positions}, label
 
             if not self.infinite:
                 logger.warning(f"Dataset {self.dataset_name} has run out of data")
@@ -145,6 +158,22 @@ def __iter__(self):
 
     def load_state_dict(self, state_dict):
         self._token_buffer = state_dict["token_buffer"]
+        if "position_buffer" in state_dict:
+            self._position_buffer = state_dict["position_buffer"]
+        else:
+            logger.warning(
+                "Checkpoint missing 'position_buffer' key in dataset state. "
+                "Falling back to sequential positions for the buffered tokens. "
+                "This is expected when resuming from a checkpoint saved before "
+                "position tracking was added, but may cause incorrect RoPE "
+                "positions with block_causal attention (document packing)."
+            )
+            # The position buffer must stay the same length as the token
+            # buffer: __iter__ slices both in lockstep, so an empty fallback
+            # would yield short or misaligned "positions" tensors.
+            self._position_buffer = [
+                i % self.seq_len for i in range(len(self._token_buffer))
+            ]
 
         if isinstance(self._data, Dataset):
             self._sample_idx = state_dict["sample_idx"]
@@ -153,7 +182,10 @@ def load_state_dict(self, state_dict):
             self._data.load_state_dict(state_dict["data"])
 
     def state_dict(self):
-        _state_dict: dict[str, Any] = {"token_buffer": self._token_buffer}
+        _state_dict: dict[str, Any] = {
+            "token_buffer": self._token_buffer,
+            "position_buffer": self._position_buffer,
+        }
 
         if isinstance(self._data, Dataset):
             _state_dict["sample_idx"] = self._sample_idx
diff --git a/torchtitan/models/common/rope.py b/torchtitan/models/common/rope.py
index ed19839252..8933a7900a 100644
--- a/torchtitan/models/common/rope.py
+++ b/torchtitan/models/common/rope.py
@@ -9,6 +9,7 @@
 from typing import Literal
 
 import torch
+from torch.distributed.tensor import DTensor, Replicate, Shard
 
 from torchtitan.protocols.module import Module
 
@@ -289,6 +290,43 @@ def _rotate_half(x: torch.Tensor) -> torch.Tensor:
     return torch.cat((-x2, x1), dim=-1)
 
 
+def _maybe_wrap_positions(
+    positions: torch.Tensor | None,
+    x: torch.Tensor,
+) -> torch.Tensor | None:
+    """Wrap positions as a DTensor deriving mesh and placements from x (xq/xk).
+
+    TODO: In a full DTensor rewrite, positions should be made a DTensor
+    in/right after dataloading, together with inputs and labels.
+
+    When TP uses use_local_output=False (DeepSeek V3, Qwen3, GPT-OSS),
+    x is a DTensor but positions is a plain tensor. The downstream
+    torch.gather requires both operands to be the same type.
+
+    Positions (bsz, seqlen) has fewer dimensions than x (bsz, seqlen,
+    n_heads, head_dim), so we only preserve Shard placements for shared
+    dimensions. Shard dims beyond positions' rank (e.g. Shard(2) for TP
+    on heads) become Replicate.
+    """
+    if (
+        positions is not None
+        and isinstance(x, DTensor)
+        and not isinstance(positions, DTensor)
+    ):
+        ndim = positions.ndim
+        placements = tuple(
+            p if not isinstance(p, Shard) or p.dim < ndim else Replicate()
+            for p in x.placements
+        )
+        positions = DTensor.from_local(
+            positions,
+            x.device_mesh,
+            placements,
+            run_check=False,
+        )
+    return positions
+
+
 # TODO: consolidate apply_rotary_emb_complex and apply_rotary_emb_single_complex
 def apply_rotary_emb_complex(
     xq: torch.Tensor,
@@ -304,6 +342,7 @@ def apply_rotary_emb_complex(
         freqs_cis: (max_seqlen, head_dim // 2) complex
         positions: optional position indices
     """
+    positions = _maybe_wrap_positions(positions, xq)
     xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
     xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
     freqs_cis = _reshape_for_broadcast_complex(freqs_cis, xq_, positions)
@@ -324,6 +363,7 @@ def apply_rotary_emb_single_complex(
         freqs_cis: (max_seqlen, head_dim // 2) complex
         positions: optional position indices
     """
+    positions = _maybe_wrap_positions(positions, x)
     dtype = x.dtype
     x = torch.view_as_complex(x.float().view(*x.shape[:-1], -1, 2))
     freqs_cis = _reshape_for_broadcast_complex(freqs_cis, x, positions)
@@ -345,6 +385,7 @@ def apply_rotary_emb_cos_sin(
         rope_cache: (max_seqlen, head_dim * 2) with cos and sin concatenated
         positions: optional position indices
     """
+    positions = _maybe_wrap_positions(positions, xq)
     head_dim = xq.shape[-1]
     rope_cache = _reshape_for_broadcast_cos_sin(rope_cache, xq, positions)
     cos = rope_cache[..., :head_dim].to(device=xq.device)
diff --git a/torchtitan/trainer.py b/torchtitan/trainer.py
index 6d0c1b564d..fe15f0271e 100644
--- a/torchtitan/trainer.py
+++ b/torchtitan/trainer.py
@@ -591,9 +591,14 @@ def post_dataloading_process(
         # extra_kwargs are.
         extra_kwargs: dict[str, Any] = {}
 
-        # TODO: improve the logic on obtaining attention masks
+        # For causal attention the whole packed sequence is one document,
+        # so sequential RoPE positions (positions=None) are correct.
         layer = getattr(self.model_config, "layer", None)
         attn_config = getattr(layer, "attention", None) if layer else None
+        attn_mask_type = getattr(attn_config, "attn_mask_type", "causal")
+        if attn_mask_type != "block_causal":
+            extra_inputs.pop("positions", None)
+
         attn_backend = getattr(attn_config, "attn_backend", "sdpa")
         if attn_backend in ["flex", "varlen"]:
             assert (