Commit d2433d5

fix: correct misleading TODO comments about positions guard
The comments blamed DTensor+FSDP for the positions guard, but the actual issue is an out-of-bounds RoPE cache index: per-document position IDs from packed datasets can exceed max_seq_len (e.g. 6545 vs cache size 2048). The guard is also semantically correct — causal attention treats the packed sequence as one document, so sequential positions via the None path are what we want.
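The out-of-bounds failure mode described above can be illustrated in a few lines. This is a hypothetical sketch, not torchtitan code: the numbers (6545, 2048) come from the commit message, and the variable names are made up for illustration. A long document split across packed sequences keeps its original per-document position IDs, so a chunk can carry IDs past the end of a RoPE cache sized to max_seq_len, while sequential positions (the positions=None path) always stay in range.

```python
max_seq_len = 2048   # RoPE cache covers positions 0..2047
seq_len = 2048       # length of one packed training sequence

# Hypothetical: a 6545-token document split across packed sequences keeps
# its original per-document positions, so one chunk may carry IDs 4096..6143.
doc_positions = list(range(4096, 4096 + seq_len))

# Sequential positions (positions=None path) restart at 0 per sequence.
seq_positions = list(range(seq_len))

assert max(doc_positions) >= max_seq_len   # would index past the RoPE cache
assert max(seq_positions) < max_seq_len    # always in range
```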
1 parent: 1d27557

File tree

2 files changed: +12 −8 lines


torchtitan/components/validate.py

Lines changed: 6 additions & 4 deletions

@@ -187,10 +187,12 @@ def post_dataloading_process(
         # TODO: deduplicate with Trainer.post_dataloading_process which has
         # the same logic; extract a shared function to prevent further drift.
-        # TODO: remove this guard once RoPE handles DTensor+positions.
-        # The positions!=None path in RoPE uses torch.gather which fails
-        # with DTensor+FSDP. For now, only pass positions through when
-        # using flex/varlen + block_causal (where it's needed and works).
+        # Per-document position IDs are only needed for block_causal
+        # attention, where each packed document gets its own RoPE reset.
+        # For causal attention the whole sequence is one document, so
+        # sequential positions (the positions=None path) are correct.
+        # Passing them through would also OOB the RoPE cache, since
+        # individual document lengths can exceed max_seq_len.
         model_config = getattr(model_parts[0], "config", None)
         layer = getattr(model_config, "layer", None)
         attn_config = getattr(layer, "attention", None) if layer else None

torchtitan/trainer.py

Lines changed: 6 additions & 4 deletions

@@ -591,10 +591,12 @@ def post_dataloading_process(
         # extra_kwargs are.
         extra_kwargs: dict[str, Any] = {}

-        # TODO: remove this guard once RoPE handles DTensor+positions.
-        # The positions!=None path in RoPE uses torch.gather which fails
-        # with DTensor+FSDP. For now, only pass positions through when
-        # using flex/varlen + block_causal (where it's needed and works).
+        # Per-document position IDs are only needed for block_causal
+        # attention, where each packed document gets its own RoPE reset.
+        # For causal attention the whole sequence is one document, so
+        # sequential positions (the positions=None path) are correct.
+        # Passing them through would also OOB the RoPE cache, since
+        # individual document lengths can exceed max_seq_len.
         layer = getattr(self.model_config, "layer", None)
         attn_config = getattr(layer, "attention", None) if layer else None
         attn_mask_type = getattr(attn_config, "attn_mask_type", "causal")
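The guard that both hunks document can be sketched as a standalone function. This is a minimal illustration, not the actual torchtitan implementation: the `attn_mask_type` name and `"block_causal"` value come from the diff context above, while `maybe_positions` itself is a hypothetical helper invented for this sketch.

```python
def maybe_positions(attn_mask_type: str, positions):
    # Per-document position IDs matter only for block_causal attention,
    # where RoPE must reset at each packed-document boundary.
    if attn_mask_type == "block_causal":
        return positions
    # Causal attention treats the packed sequence as one document, so
    # returning None selects sequential RoPE positions (always in-cache).
    return None
```

For example, `maybe_positions("causal", [0, 1, 2])` returns `None`, while `maybe_positions("block_causal", [0, 1, 2])` passes the positions through unchanged.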
