Yield per-document RoPE position ids from dataset#2560
joecummings merged 8 commits into pytorch:main
Conversation
cc @tianyu-l @francesco-bertolotti : I did the fix that was discussed in #2559, but the "longer term fix" is also pretty simple. I might suggest we just do that in this PR, unless you have objections, because that would technically change the behavior of the attention mask construction. Could be a follow-up.
@joecummings
So you are suggesting putting it in dataloading. But then for more complicated, model-specific mask generation (e.g. https://github.com/pytorch/torchtitan/blob/main/torchtitan/models/llama4/model.py#L209), there still needs to be this post_dataloading_processing https://github.com/pytorch/torchtitan/blob/main/torchtitan/trainer.py#L608, right?
i think this is expected for rope |
After realizing that the default testing is also incorrect because it relies on the sdpa backend even with packing, opting to just get this fix in first and consider a proper refactor after.
Force-pushed 0b0f7d7 to b3d7f60
Moved this to a different PR since I just want to get this fix in first before tackling the larger problem with packing in torchtitan. Let's have the discussion there when ready.
Add a position buffer that tracks per-document RoPE positions, resetting at each document boundary. These positions are yielded alongside input tokens and used when block_causal attention is configured. Also add is_packed validation to catch misconfigured attention backends at trainer init time: packed dataloaders require flex or varlen with block_causal to prevent cross-document attention leakage.
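The position-buffer behavior described above can be sketched in plain Python (this is an illustrative sketch, not the torchtitan implementation; `packed_positions` and `EOS_ID` are hypothetical names, and documents are assumed to be delimited by an EOS token id):

```python
# Sketch: build per-document RoPE position ids for a packed sequence,
# resetting to 0 at each document boundary (delimited by a hypothetical
# EOS token id). Not the actual torchtitan code.
EOS_ID = 0  # hypothetical EOS token id

def packed_positions(tokens, eos_id=EOS_ID):
    positions, pos = [], 0
    for tok in tokens:
        positions.append(pos)
        # reset so the next document starts again at position 0
        pos = 0 if tok == eos_id else pos + 1
    return positions

# two documents (lengths 3 and 4) packed into one training sequence
tokens = [5, 6, 0, 7, 8, 9, 0]
print(packed_positions(tokens))  # [0, 1, 2, 0, 1, 2, 3]
```

With block_causal attention these positions line up with the per-document attention blocks, so each document sees RoPE positions as if it started the sequence.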
…te kwarg with CP+PP Validator.post_dataloading_process was missing the guard that Trainer.post_dataloading_process has to remove `positions` from extra_inputs when attn_mask_type != "block_causal". When CP is enabled, prepare_context_parallel_input adds its own sharded `positions` to extra_kwargs, so leaving the original in extra_inputs caused pp_schedule.eval() to receive `positions` twice.
The comments blamed DTensor+FSDP for the positions guard, but the actual issue is an out-of-bounds RoPE cache index: per-document position IDs from packed datasets can exceed max_seq_len (e.g. 6545 vs cache size 2048). The guard is also semantically correct — causal attention treats the packed sequence as one document, so sequential positions via the None path are what we want.
Documents longer than seq_len produce position IDs that exceed the RoPE cache size, causing an index-out-of-bounds error in torch.gather during apply_rotary_emb. Wrap positions with modulo seq_len in the dataloader, which effectively chunks long documents for RoPE purposes while preserving all tokens for training. Also update comments to clarify: per-document positions are dropped for causal attention (whole sequence is one document), and kept for block_causal to match inference frameworks (e.g. vLLM) that reset positions to 0 per request.
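The modulo wrapping can be illustrated with a small sketch (hypothetical helper name; `seq_len` stands in for the model's max_seq_len / RoPE cache size):

```python
# Sketch: wrap per-document positions modulo seq_len so a document
# longer than the RoPE cache never produces an out-of-bounds index
# in the gather over the freqs_cis cache. Illustrative only.
def wrap_positions(positions, seq_len):
    return [p % seq_len for p in positions]

# a single 6545-token document against a 2048-entry RoPE cache would
# otherwise index up to 6544; modulo keeps indices in [0, 2048)
positions = list(range(6545))
wrapped = wrap_positions(positions, 2048)
print(max(wrapped))   # 2047
print(wrapped[6544])  # 400  (6544 % 2048)
```

All tokens are still trained on; only the RoPE phase restarts every `seq_len` tokens within an over-long document.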
…uffer When TP uses use_local_output=False (DeepSeek V3, Qwen3, GPT-OSS), freqs_cis becomes a DTensor(Replicate) but positions remains a plain tensor. torch.gather requires both operands to be the same type, causing a runtime error. Fix by wrapping positions via DTensor.from_local() at the apply_rotary_emb public API boundary. Also add a logger.warning when loading a checkpoint that is missing the position_buffer key in the dataset state dict, to help users debug incorrect RoPE positions when resuming from older checkpoints.
torchtitan/models/common/rope.py (Outdated)

```python
        freqs_cis: (max_seqlen, head_dim // 2) complex
        positions: optional position indices
    """
    positions = _maybe_wrap_positions(positions, freqs_cis)
```
Public API boundary seemed like the proper place to do this wrapping, but lmk if you'd prefer this somewhere else @tianyu-l
torchtitan/models/common/rope.py (Outdated)

```python
positions = DTensor.from_local(
    positions,
    freqs_cis.device_mesh,
    freqs_cis.placements,
```
Usually positions should have the same placements as x, rather than freqs_cis. We are not wrapping tensors on CP dimension, but if we do, x and positions will be sharded on sequence dim on CP, whereas freqs_cis will be Replicate on CP.
I will change to borrow placements from x
sorry, let me clarify:
x has more dimensions than positions. The placement should match on the dimensions they share (batch, sequence). But if x is sharded on extra dimensions (e.g. in TP x would be sharded on head_dim namely with placement Shard(3)) then the corresponding placement on positions should be Replicate.
I see - in that case, I'll leave the function name more specific since it is truly tied to both x's state and whether or not we have positions
```python
    return torch.cat((-x2, x1), dim=-1)


def _maybe_wrap_positions(
```
maybe

Suggested change:

```diff
-def _maybe_wrap_positions(
+def _maybe_to_dtensor(
```
Positions are per-token like the input activations, so they should share placements with x (xq/xk) rather than freqs_cis. This is forward-compatible with CP where positions would be Shard(seq) like x, while freqs_cis remains Replicate. Also rename to _maybe_to_dtensor(tensor, like) for clarity.
sg, please leave a TODO: in full DTensor rewrite, we should make positions a DTensor in/right after dataloading, together with inputs and labels.
cc @fegin
Positions are per-token like the input activations, so they should share placements with x (xq/xk) rather than freqs_cis. This is forward-compatible with CP where positions would be Shard(seq) like x, while freqs_cis remains Replicate. Since positions (bsz, seqlen) has fewer dims than x (bsz, seqlen, n_heads, head_dim), Shard placements beyond positions' rank (e.g. Shard(2) for TP on heads) are demoted to Replicate.
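The demotion rule above can be sketched without the real DTensor API (placements are modeled as plain strings here purely for illustration; `positions_placements` is a hypothetical helper, not torchtitan code):

```python
# Illustrative sketch only: positions (bsz, seqlen) borrows x's
# placements, but any Shard over a dim that positions doesn't have
# (e.g. Shard(3) on head_dim under TP) is demoted to Replicate.
# Placements are modeled as strings, not real DTensor placements.
def positions_placements(x_placements, positions_ndim=2):
    out = []
    for p in x_placements:
        if p.startswith("Shard("):
            dim = int(p[len("Shard("):-1])
            # keep shards on dims positions shares with x; demote the rest
            out.append(p if dim < positions_ndim else "Replicate")
        else:
            out.append(p)
    return out

# TP shards x (bsz, seqlen, n_heads, head_dim) on head_dim -> Shard(3)
print(positions_placements(["Shard(3)"]))              # ['Replicate']
# CP would shard on the sequence dim -> Shard(1), which positions keeps
print(positions_placements(["Shard(1)", "Shard(3)"]))  # ['Shard(1)', 'Replicate']
```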
Fixes #2559
`HuggingFaceTextDataset` now tracks a `_position_buffer` alongside the existing `_token_buffer`. Each document's tokens get positions [0, 1, ..., doc_len-1], resetting at every document boundary. Positions are yielded as `{"input": input, "positions": positions}` and flow through the trainer's `extra_inputs` into `Decoder.forward(positions=...)` automatically. Checkpoint `state_dict`/`load_state_dict` updated to persist the position buffer (BC via `.get()`).
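The backward-compatible checkpoint handling can be sketched as follows (class and key names are hypothetical stand-ins, not the actual torchtitan dataset code):

```python
# Sketch: persist the position buffer in the dataset state dict, and
# fall back to an empty buffer (with a warning) when resuming from an
# older checkpoint that predates the key, instead of raising KeyError.
# Names are hypothetical.
import logging

logger = logging.getLogger(__name__)

class PackedDatasetState:
    def __init__(self):
        self._token_buffer = []
        self._position_buffer = []

    def state_dict(self):
        return {
            "token_buffer": self._token_buffer,
            "position_buffer": self._position_buffer,
        }

    def load_state_dict(self, sd):
        self._token_buffer = sd["token_buffer"]
        if "position_buffer" not in sd:
            logger.warning(
                "Checkpoint has no position_buffer; RoPE positions may be "
                "incorrect for the first resumed batch."
            )
        # BC: older checkpoints have no position_buffer key
        self._position_buffer = sd.get("position_buffer", [])

ds = PackedDatasetState()
ds.load_state_dict({"token_buffer": [1, 2, 3]})  # old-format checkpoint
print(ds._position_buffer)  # []
```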
Longer-term consideration
Right now there are two concerns for packed datasets: attention masks and position IDs. Attention masks are computed in `post_dataloading_process` and, in this PR, position IDs are built in the dataset. Constructing masks purely from the EOS token id is fragile, especially with post-training multi-turn sequences where models can co-opt that token for end of sequence versus end of document.
The right long-term approach for torchtitan is that datasets yield `seq_lens` metadata alongside tokens (rather than `position_ids` directly), and both positions and attention masks are derived from that single source of truth in post-processing. This would retire the EOS-based `get_document_mask_mod` path entirely and co-locate both computations in one place. Doesn't change how `Decoder` works.
Resources: https://github.com/NVIDIA/NeMo/blob/v2.7.0/nemo/collections/llm/gpt/data/core.py, https://github.com/pytorch/torchtune/blob/d0f63bb33d00b8bd3905a010b71d8c6324c2e980/torchtune/datasets/_packed.py#L108-L143,
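The single-source-of-truth idea can be sketched like this (hypothetical helper names; the mask predicate mimics the shape of a FlexAttention-style block-causal `mask_mod`):

```python
# Sketch: the dataset yields seq_lens for the packed sample, and
# post-processing derives both per-document positions and document ids.
# A block_causal mask predicate can then compare document ids so tokens
# only attend within their own document. Illustrative names only.
def derive_from_seq_lens(seq_lens):
    positions, doc_ids = [], []
    for doc_id, n in enumerate(seq_lens):
        positions.extend(range(n))       # reset positions per document
        doc_ids.extend([doc_id] * n)     # tag each token with its document
    return positions, doc_ids

positions, doc_ids = derive_from_seq_lens([3, 2, 3])
print(positions)  # [0, 1, 2, 0, 1, 0, 1, 2]
print(doc_ids)    # [0, 0, 0, 1, 1, 2, 2, 2]

# block-causal predicate: same document AND causal
def mask_mod(q_idx, kv_idx):
    return doc_ids[q_idx] == doc_ids[kv_idx] and q_idx >= kv_idx

print(mask_mod(4, 3))  # True  (both tokens in document 1)
print(mask_mod(5, 4))  # False (document 2 cannot attend to document 1)
```

Both artifacts come from one `seq_lens` list, so positions and masks can never disagree about where documents start and end.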
Or could possibly switch everything to `varlen`, but that would require ensuring FA4 works, plus the additional lift of integrating with context parallel.

Test plan
Unit tests pass
Also for fun, comparison between WITH position ids and WITHOUT. Definitely different in the loss, but not by a ton:

Next steps
First is #2610:
- Switch llama3 and qwen3 model configs from sdpa/causal to flex/block_causal
- Regenerate expected loss files (`tests/assets/losses/llama3_cuda.txt`, `llama3_rocm.txt`) with new defaults
- Remove now-redundant config variants (`debugmodel_flex_attn`, `8B_flex`, `debugmodel_flex`)

Later stuff like:
- `prepare_context_parallel_input` overwrites positions with sequential `torch.arange`
- `get_attention_masks`