Commit 7bc29c9
fix: wrap per-document RoPE positions at seq_len to prevent OOB gather
Documents longer than seq_len produce position IDs that exceed the RoPE cache size, causing an index-out-of-bounds error in torch.gather during apply_rotary_emb. Wrap positions with modulo seq_len in the dataloader, which effectively chunks long documents for RoPE purposes while preserving all tokens for training.

Also update comments to clarify: per-document positions are dropped for causal attention (the whole sequence is one document) and kept for block_causal to match inference frameworks (e.g. vLLM) that reset positions to 0 per request.
1 parent 782a2a6 commit 7bc29c9
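The failure mode described above can be sketched without torch: a precomputed RoPE cache holds one entry per position in [0, seq_len), so any unwrapped per-document position at or beyond seq_len indexes past the end. A minimal Python sketch; the list-based cache stand-in is illustrative, not the actual torchtitan code, where the cache is a tensor indexed via torch.gather.

```python
# Stand-in for a precomputed RoPE cache: one entry per position in
# [0, seq_len). In the real code this is a tensor gathered by
# position ID inside apply_rotary_emb.
seq_len = 8
rope_cache = [f"freqs[{p}]" for p in range(seq_len)]

doc_len = 10  # a single document longer than seq_len

unwrapped = list(range(doc_len))                 # 0..9: positions 8 and 9 are OOB
wrapped = [i % seq_len for i in range(doc_len)]  # 0..7, 0, 1: always in range

try:
    _ = [rope_cache[p] for p in unwrapped]
    overflowed = False
except IndexError:
    overflowed = True

print(overflowed)                         # True: unwrapped positions overflow
print(all(p < seq_len for p in wrapped))  # True: wrapped positions fit
```

The modulo wrap trades positional fidelity past seq_len for safety: tokens beyond the cache boundary restart at position 0, which is the same chunking behavior an inference server applies when a request is split.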

File tree

3 files changed: +13 -15 lines changed

torchtitan/components/validate.py
Lines changed: 2 additions & 8 deletions

@@ -185,14 +185,8 @@ def post_dataloading_process(
         # extra_kwargs are.
         extra_kwargs: dict[str, Any] = {}

-        # TODO: deduplicate with Trainer.post_dataloading_process which has
-        # the same logic; extract a shared function to prevent further drift.
-        # Per-document position IDs are only needed for block_causal
-        # attention, where each packed document gets its own RoPE reset.
-        # For causal attention the whole sequence is one document, so
-        # sequential positions (the positions=None path) are correct.
-        # Passing them through would also OOB the RoPE cache, since
-        # individual document lengths can exceed max_seq_len.
+        # For causal attention the whole packed sequence is one document,
+        # so sequential RoPE positions (positions=None) are correct.
         model_config = getattr(model_parts[0], "config", None)
         layer = getattr(model_config, "layer", None)
         attn_config = getattr(layer, "attention", None) if layer else None

torchtitan/hf_datasets/text_datasets.py
Lines changed: 9 additions & 1 deletion

@@ -120,7 +120,15 @@ def __iter__(self):
                 sample_text, add_bos=True, add_eos=True
             )
             self._token_buffer.extend(sample_tokens)
-            self._position_buffer.extend(range(len(sample_tokens)))
+            # Per-document positions reset at document boundaries,
+            # matching inference frameworks (e.g. vLLM) that start
+            # positions at 0 per request. Positions wrap at seq_len
+            # to stay within the RoPE cache, effectively chunking
+            # long documents into seq_len-sized segments.
+            # TODO: make overflow policy configurable (chunk / truncate / drop).
+            self._position_buffer.extend(
+                i % self.seq_len for i in range(len(sample_tokens))
+            )
             self._sample_idx += 1

             while len(self._token_buffer) >= max_buffer_token_len:
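The buffer logic above can be summarized as a standalone helper. This is a hypothetical sketch (the function name and list-of-lengths interface are not from the repo); it shows the two properties the diff establishes: positions reset to 0 at each document boundary, and wrap at seq_len so they never index past the RoPE cache.

```python
def positions_for_documents(doc_lengths: list[int], seq_len: int) -> list[int]:
    """Yield one position ID per token across packed documents.

    Each document restarts at position 0; the modulo chunks any
    document longer than seq_len into seq_len-sized segments so all
    positions stay inside the RoPE cache.
    """
    positions: list[int] = []
    for n in doc_lengths:
        positions.extend(i % seq_len for i in range(n))
    return positions

# A 5-token document followed by a 7-token document, with seq_len=4:
print(positions_for_documents([5, 7], seq_len=4))
# -> [0, 1, 2, 3, 0, 0, 1, 2, 3, 0, 1, 2]
```

Note the second 0 in the output marks both the wrap inside the first document (token 5) and the start of the second document; the downstream attention mask, not the positions alone, distinguishes document boundaries.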

torchtitan/trainer.py
Lines changed: 2 additions & 6 deletions

@@ -596,12 +596,8 @@ def post_dataloading_process(
         # extra_kwargs are.
         extra_kwargs: dict[str, Any] = {}

-        # Per-document position IDs are only needed for block_causal
-        # attention, where each packed document gets its own RoPE reset.
-        # For causal attention the whole sequence is one document, so
-        # sequential positions (the positions=None path) are correct.
-        # Passing them through would also OOB the RoPE cache, since
-        # individual document lengths can exceed max_seq_len.
+        # For causal attention the whole packed sequence is one document,
+        # so sequential RoPE positions (positions=None) are correct.
         layer = getattr(self.model_config, "layer", None)
         attn_config = getattr(layer, "attention", None) if layer else None
         attn_mask_type = getattr(attn_config, "attn_mask_type", "causal")
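The getattr chain in this hunk decides whether per-document positions are forwarded at all. A hypothetical sketch of that gating logic; the config container classes and the helper name are illustrative (not torchtitan's actual classes), but the defensive attribute walk mirrors the diff: any missing config level falls back to "causal", which takes the positions=None path.

```python
class AttnConfig:
    def __init__(self, attn_mask_type: str):
        self.attn_mask_type = attn_mask_type

class Layer:
    def __init__(self, attention: AttnConfig):
        self.attention = attention

class ModelConfig:
    def __init__(self, layer=None):
        self.layer = layer

def should_pass_positions(model_config) -> bool:
    # Mirror the diff: walk config defensively so a missing layer or
    # attention config defaults to plain causal attention.
    layer = getattr(model_config, "layer", None)
    attn_config = getattr(layer, "attention", None) if layer else None
    attn_mask_type = getattr(attn_config, "attn_mask_type", "causal")
    # Only block_causal needs per-document RoPE resets.
    return attn_mask_type == "block_causal"

print(should_pass_positions(ModelConfig(Layer(AttnConfig("block_causal")))))  # True
print(should_pass_positions(ModelConfig()))                                   # False
```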
