
Yield per-document RoPE position ids from dataset #2560

Merged: joecummings merged 8 commits into pytorch:main from joecummings:fix-pos-id, Mar 19, 2026

Conversation

@joecummings (Member) commented Mar 12, 2026

Fixes #2559

HuggingFaceTextDataset now tracks a _position_buffer alongside the existing _token_buffer. Each document's tokens get positions [0, 1, ..., doc_len-1], resetting at every document boundary. Positions are yielded as {"input": input, "positions": positions} and flow through the trainer's extra_inputs into Decoder.forward(positions=...) automatically.

Checkpoint state_dict/load_state_dict updated to persist the position buffer (BC via .get()).
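
For concreteness, the buffering logic can be sketched in plain Python. This is a hedged illustration, not the actual HuggingFaceTextDataset code; `pack_with_positions` and its signature are invented for this sketch:

```python
# Illustrative sketch (not the actual HuggingFaceTextDataset implementation):
# a greedy packer that keeps a token buffer and a parallel position buffer,
# resetting positions to 0 at every document boundary.

def pack_with_positions(docs, seq_len):
    """Yield fixed-length packed samples with per-document positions."""
    token_buffer, position_buffer = [], []
    for doc in docs:
        token_buffer.extend(doc)
        position_buffer.extend(range(len(doc)))  # positions reset per doc
        while len(token_buffer) >= seq_len:
            yield {
                "input": token_buffer[:seq_len],
                "positions": position_buffer[:seq_len],
            }
            token_buffer = token_buffer[seq_len:]
            position_buffer = position_buffer[seq_len:]
```

For example, packing documents of lengths 3 and 2 into seq_len=4 yields positions [0, 1, 2, 0], with the leftover token (position 1) remaining in the buffer; that leftover is why the position buffer has to be persisted in state_dict alongside the token buffer.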

Longer-term consideration

Right now there are two considerations for packed datasets: attention masks and position IDs. Attention masks are computed in post_dataloading_process and, in this PR, position IDs are built in the dataset. Constructing masks purely based on the EOS token id is fragile, especially with post-training multi-turn sequences where models could co-opt that token for end of sequence versus end of document.

The right long-term approach for torchtitan is for datasets to yield seq_lens metadata alongside tokens (rather than position_ids directly), with both positions and attention masks derived from that single source of truth in post-processing. This would retire the EOS-based get_document_mask_mod path entirely and co-locate both computations in one place.

# In dataloader
def _iter_greedy_packed(self):
    for sample in self._get_data_iter():
        input_ids = self._tokenize(sample)
        self._pack_buffer_input.extend(input_ids)
        self._pack_seq_lens.append(len(input_ids))  # just track the length

# In post dataloading process
if "seq_lens" in extra_inputs:
    seq_lens = extra_inputs.pop("seq_lens")
    extra_inputs["positions"] = positions_from_seq_lens(seq_lens)
    extra_kwargs["attention_masks"] = mask_from_seq_lens(seq_lens)

This doesn't change how Decoder works.
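
One possible shape for the two helpers named in the sketch. Note that positions_from_seq_lens and mask_from_seq_lens are proposed names from this PR, not existing torchtitan APIs; this pure-Python version is only illustrative, and a real implementation would build tensors or a FlexAttention mask_mod:

```python
# Illustrative pure-Python versions of the proposed helpers: both positions
# and the block-causal mask are derived from the same seq_lens metadata.

def positions_from_seq_lens(seq_lens):
    """Per-document positions: [0..n-1] for each document, concatenated."""
    return [p for n in seq_lens for p in range(n)]

def mask_from_seq_lens(seq_lens):
    """Block-causal mask: token i may attend to token j iff j <= i and
    both tokens belong to the same document."""
    doc_ids = [d for d, n in enumerate(seq_lens) for _ in range(n)]
    total = len(doc_ids)
    return [
        [doc_ids[i] == doc_ids[j] and j <= i for j in range(total)]
        for i in range(total)
    ]
```

With seq_lens = [2, 3], positions come out as [0, 1, 0, 1, 2], and the first token of the second document can attend only to itself, which is exactly the cross-document leakage the EOS-based path tries to prevent.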

Resources: https://github.com/NVIDIA/NeMo/blob/v2.7.0/nemo/collections/llm/gpt/data/core.py, https://github.com/pytorch/torchtune/blob/d0f63bb33d00b8bd3905a010b71d8c6324c2e980/torchtune/datasets/_packed.py#L108-L143

Alternatively, we could switch everything to varlen, but that would require ensuring FA4 is working and take on the additional lift of integrating with context parallel.

Test plan

Unit tests pass

Also for fun, a comparison between WITH position ids and WITHOUT. The loss is definitely different, but not by a ton:

[Screenshot: loss curves with vs. without position ids]

Next steps

First is #2610:

  • Switch llama3 and qwen3 model configs from sdpa/causal to flex/block_causal
  • Regenerate expected loss files (tests/assets/losses/llama3_cuda.txt, llama3_rocm.txt) with the new defaults
  • Remove now-redundant config variants (debugmodel_flex_attn, 8B_flex, debugmodel_flex)

Later stuff like:

  • Validate loss convergence on C4 with block_causal + per-document positions vs baseline
  • Propagate per-document positions through context parallel (currently prepare_context_parallel_input overwrites positions with a sequential torch.arange)
  • Consider deriving attention masks from position IDs rather than EOS scanning, removing the tokenizer dependency from get_attention_masks

@meta-cla meta-cla bot added the CLA Signed This label is managed by the Meta Open Source bot. label Mar 12, 2026
@joecummings joecummings changed the title Yield per-document RoPE position IDs from HuggingFaceTextDataset Yield per-document RoPE position ids from dataset Mar 12, 2026
@joecummings (Member Author):

cc @tianyu-l @francesco-bertolotti : I did the fix that was discussed in #2559, but the "longer term fix" is also pretty simple. I might suggest we just do that in this PR, unless you have objections b/c that would technically be changing the behavior of the attention mask construction. Could be a follow up.

@joecummings joecummings marked this pull request as ready for review March 12, 2026 21:06
@tianyu-l (Contributor):

@joecummings
The long term fix sounds reasonable. It can also replace varlen metadata creation https://github.com/pytorch/torchtitan/blob/main/torchtitan/models/common/attention.py#L322

co-locate both computations in one place

So you are suggesting putting it in dataloading. But then for more complicated, model-specific mask generation (e.g. https://github.com/pytorch/torchtitan/blob/main/torchtitan/models/llama4/model.py#L209), there still needs to be this post_dataloading_process step https://github.com/pytorch/torchtitan/blob/main/torchtitan/trainer.py#L608, right?

@rakkit (Contributor) commented Mar 13, 2026

Also for fun, comparison between WITH position ids and WITHOUT. Definitely different in the loss, but not by a ton:

I think this is expected for RoPE.

@joecummings (Member Author):

After realizing that the default testing is also incorrect b/c it relies on the sdpa backend even with packing, opting to just get this fix in first and consider a proper refactor after.

@felipemello1 left a comment:

This looks reasonable to me, I just wonder why 'flex' and not 'varlen'. I don't know the backends and their implications well enough. Any thoughts?

@joecummings joecummings force-pushed the fix-pos-id branch 4 times, most recently from 0b0f7d7 to b3d7f60 Compare March 16, 2026 20:19
@joecummings (Member Author):

This looks reasonable to me, I just wonder why 'flex' and not 'varlen'. I don't know the backends and their implications well enough. Any thoughts?

Moved this to a different PR since I just want to get this fix in first before tackling the larger problem with packing in torchtitan. Let's have the discussion there when ready

@felipemello1 left a comment:

I would just take a look at the red CI and also see if this changes the integration tests.

Add a position buffer that tracks per-document RoPE positions,
resetting at each document boundary. These positions are yielded
alongside input tokens and used when block_causal attention is
configured.

Also add is_packed validation to catch misconfigured attention
backends at trainer init time: packed dataloaders require flex or
varlen with block_causal to prevent cross-document attention leakage.
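
A hedged sketch of what that init-time validation could look like. The function and argument names here are illustrative, not the actual torchtitan config surface:

```python
# Illustrative sketch of the is_packed validation described in the commit
# message: fail fast at trainer init when a packed dataloader is paired
# with an attention setup that cannot honor document boundaries.

def validate_attention_config(is_packed, attn_backend, attn_mask_type):
    if is_packed and not (
        attn_backend in ("flex", "varlen") and attn_mask_type == "block_causal"
    ):
        raise ValueError(
            "Packed dataloaders require flex or varlen attention with "
            "block_causal masking to prevent cross-document attention leakage; "
            f"got backend={attn_backend!r}, mask_type={attn_mask_type!r}"
        )
```

Failing at init rather than mid-training surfaces the misconfiguration before any compute is spent on a silently-leaking attention pattern.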
…te kwarg with CP+PP

Validator.post_dataloading_process was missing the guard that
Trainer.post_dataloading_process has to remove `positions` from
extra_inputs when attn_mask_type != "block_causal". When CP is enabled,
prepare_context_parallel_input adds its own sharded `positions` to
extra_kwargs, so leaving the original in extra_inputs caused
pp_schedule.eval() to receive `positions` twice.
The comments blamed DTensor+FSDP for the positions guard, but the
actual issue is an out-of-bounds RoPE cache index: per-document
position IDs from packed datasets can exceed max_seq_len (e.g. 6545
vs cache size 2048). The guard is also semantically correct — causal
attention treats the packed sequence as one document, so sequential
positions via the None path are what we want.
Documents longer than seq_len produce position IDs that exceed the RoPE
cache size, causing an index-out-of-bounds error in torch.gather during
apply_rotary_emb. Wrap positions with modulo seq_len in the dataloader,
which effectively chunks long documents for RoPE purposes while
preserving all tokens for training.

Also update comments to clarify: per-document positions are dropped for
causal attention (whole sequence is one document), and kept for
block_causal to match inference frameworks (e.g. vLLM) that reset
positions to 0 per request.
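
The modulo wrapping amounts to a single operation; a minimal sketch (`wrap_positions` is an illustrative name, not the actual dataloader code):

```python
# Wrap per-document positions so they never index past the RoPE cache of
# size seq_len; long documents are effectively chunked for RoPE purposes
# while all tokens are preserved for training.

def wrap_positions(positions, seq_len):
    return [p % seq_len for p in positions]
```

For the case in the commit message, a position of 6545 with a 2048-entry cache wraps to 401 (6545 = 3 * 2048 + 401), avoiding the out-of-bounds torch.gather in apply_rotary_emb.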
…uffer

When TP uses use_local_output=False (DeepSeek V3, Qwen3, GPT-OSS),
freqs_cis becomes a DTensor(Replicate) but positions remains a plain
tensor. torch.gather requires both operands to be the same type, causing
a runtime error. Fix by wrapping positions via DTensor.from_local() at
the apply_rotary_emb public API boundary.

Also add a logger.warning when loading a checkpoint that is missing the
position_buffer key in the dataset state dict, to help users debug
incorrect RoPE positions when resuming from older checkpoints.
freqs_cis: (max_seqlen, head_dim // 2) complex
positions: optional position indices
"""
positions = _maybe_wrap_positions(positions, freqs_cis)
@joecummings (Member Author):

Public API boundary seemed like the proper place to do this wrapping, but lmk if you'd prefer this somewhere else @tianyu-l

@joecummings joecummings requested a review from tianyu-l March 18, 2026 15:45
positions = DTensor.from_local(
positions,
freqs_cis.device_mesh,
freqs_cis.placements,
Contributor:

Usually positions should have the same placements as x, rather than freqs_cis. We are not wrapping tensors on CP dimension, but if we do, x and positions will be sharded on sequence dim on CP, whereas freqs_cis will be Replicate on CP.

Member Author:

I will change to borrow placements from x

Contributor:

sorry, let me clarify:
x has more dimensions than positions. The placement should match on the dimensions they share (batch, sequence). But if x is sharded on extra dimensions (e.g. in TP x would be sharded on head_dim namely with placement Shard(3)) then the corresponding placement on positions should be Replicate.

Member Author:

I see - in that case, I'll leave the function name more specific since it is truly tied to both x's state and whether or not we have positions

return torch.cat((-x2, x1), dim=-1)


def _maybe_wrap_positions(
Contributor:

maybe

Suggested change:
- def _maybe_wrap_positions(
+ def _maybe_to_dtensor(

Positions are per-token like the input activations, so they should share
placements with x (xq/xk) rather than freqs_cis. This is forward-compatible
with CP where positions would be Shard(seq) like x, while freqs_cis remains
Replicate. Also rename to _maybe_to_dtensor(tensor, like) for clarity.
return torch.cat((-x2, x1), dim=-1)


def _maybe_wrap_positions(
Contributor:

Sounds good; please leave a TODO: in the full DTensor rewrite, we should make positions a DTensor in/right after dataloading, together with inputs and labels.

cc @fegin

Positions are per-token like the input activations, so they should share
placements with x (xq/xk) rather than freqs_cis. This is forward-compatible
with CP where positions would be Shard(seq) like x, while freqs_cis remains
Replicate.

Since positions (bsz, seqlen) has fewer dims than x (bsz, seqlen, n_heads,
head_dim), Shard placements beyond positions' rank (e.g. Shard(2) for TP
on heads) are demoted to Replicate.
@joecummings joecummings merged commit 422b057 into pytorch:main Mar 19, 2026
19 of 27 checks passed

Labels

ciflow/8gpu, CLA Signed (managed by the Meta Open Source bot)

Projects

None yet

Development

Successfully merging this pull request may close these issues:

RoPE positions are never set

4 participants