Skip to content

Commit 4e6888b

Browse files
committed
fix: wrap per-document RoPE positions at seq_len to prevent OOB gather
Documents longer than seq_len produce position IDs that exceed the RoPE cache size, causing an index-out-of-bounds error in torch.gather during apply_rotary_emb. Wrap positions with modulo seq_len in the dataloader, which effectively chunks long documents for RoPE purposes while preserving all tokens for training. Also update comments to clarify: per-document positions are dropped for causal attention (whole sequence is one document), and kept for block_causal to match inference frameworks (e.g. vLLM) that reset positions to 0 per request.
1 parent d2433d5 commit 4e6888b

File tree

3 files changed, +13 −13 lines changed

torchtitan/components/validate.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -187,12 +187,8 @@ def post_dataloading_process(
187187

188188
# TODO: deduplicate with Trainer.post_dataloading_process which has
189189
# the same logic; extract a shared function to prevent further drift.
190-
# Per-document position IDs are only needed for block_causal
191-
# attention, where each packed document gets its own RoPE reset.
192-
# For causal attention the whole sequence is one document, so
193-
# sequential positions (the positions=None path) are correct.
194-
# Passing them through would also OOB the RoPE cache, since
195-
# individual document lengths can exceed max_seq_len.
190+
# For causal attention the whole packed sequence is one document,
191+
# so sequential RoPE positions (positions=None) are correct.
196192
model_config = getattr(model_parts[0], "config", None)
197193
layer = getattr(model_config, "layer", None)
198194
attn_config = getattr(layer, "attention", None) if layer else None

torchtitan/hf_datasets/text_datasets.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,15 @@ def __iter__(self):
120120
sample_text, add_bos=True, add_eos=True
121121
)
122122
self._token_buffer.extend(sample_tokens)
123-
self._position_buffer.extend(range(len(sample_tokens)))
123+
# Per-document positions reset at document boundaries,
124+
# matching inference frameworks (e.g. vLLM) that start
125+
# positions at 0 per request. Positions wrap at seq_len
126+
# to stay within the RoPE cache, effectively chunking
127+
# long documents into seq_len-sized segments.
128+
# TODO: make overflow policy configurable (chunk / truncate / drop).
129+
self._position_buffer.extend(
130+
i % self.seq_len for i in range(len(sample_tokens))
131+
)
124132
self._sample_idx += 1
125133

126134
while len(self._token_buffer) >= max_buffer_token_len:

torchtitan/trainer.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -591,12 +591,8 @@ def post_dataloading_process(
591591
# extra_kwargs are.
592592
extra_kwargs: dict[str, Any] = {}
593593

594-
# Per-document position IDs are only needed for block_causal
595-
# attention, where each packed document gets its own RoPE reset.
596-
# For causal attention the whole sequence is one document, so
597-
# sequential positions (the positions=None path) are correct.
598-
# Passing them through would also OOB the RoPE cache, since
599-
# individual document lengths can exceed max_seq_len.
594+
# For causal attention the whole packed sequence is one document,
595+
# so sequential RoPE positions (positions=None) are correct.
600596
layer = getattr(self.model_config, "layer", None)
601597
attn_config = getattr(layer, "attention", None) if layer else None
602598
attn_mask_type = getattr(attn_config, "attn_mask_type", "causal")

0 commit comments

Comments (0)