Skip to content

Commit 26678f6

Browse files
committed
Yield per-document RoPE position IDs from HuggingFaceTextDataset
Add a position buffer that tracks per-document RoPE positions, resetting at each document boundary. These positions are yielded alongside input tokens and used when block_causal attention is configured. Also add is_packed validation to catch misconfigured attention backends at trainer init time: packed dataloaders require flex or varlen with block_causal to prevent cross-document attention leakage.
1 parent 0691f51 commit 26678f6

File tree

3 files changed

+24
-4
lines changed

3 files changed

+24
-4
lines changed

tests/unit_tests/test_dataset_checkpointing.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,10 @@ def test_c4_resumption(self):
5555
assert torch.equal(
5656
input_ids["input"], expected_input_ids["input"]
5757
)
58+
assert torch.equal(
59+
input_ids["positions"],
60+
expected_input_ids["positions"],
61+
)
5862
assert torch.equal(labels, expected_labels)
5963

6064
def _build_dataloader(self, dataset_name, batch_size, seq_len, world_size, rank):

torchtitan/hf_datasets/text_datasets.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,7 @@ def __init__(
9696
# Variables for checkpointing
9797
self._sample_idx = 0
9898
self._token_buffer: list[int] = []
99+
self._position_buffer: list[int] = []
99100

100101
def _get_data_iter(self):
101102
# For map-style datasets, resume by skipping to the correct index
@@ -119,15 +120,19 @@ def __iter__(self):
119120
sample_text, add_bos=True, add_eos=True
120121
)
121122
self._token_buffer.extend(sample_tokens)
123+
self._position_buffer.extend(range(len(sample_tokens)))
122124
self._sample_idx += 1
123125

124126
while len(self._token_buffer) >= max_buffer_token_len:
125127
x = torch.LongTensor(self._token_buffer[:max_buffer_token_len])
126-
# update tokens to the remaining tokens
128+
pos = torch.LongTensor(self._position_buffer[:max_buffer_token_len])
129+
# update buffers to the remaining tokens
127130
self._token_buffer = self._token_buffer[max_buffer_token_len:]
131+
self._position_buffer = self._position_buffer[max_buffer_token_len:]
128132
input = x[:-1]
129133
label = x[1:]
130-
yield {"input": input}, label
134+
positions = pos[:-1]
135+
yield {"input": input, "positions": positions}, label
131136

132137
if not self.infinite:
133138
logger.warning(f"Dataset {self.dataset_name} has run out of data")
@@ -145,6 +150,7 @@ def __iter__(self):
145150

146151
def load_state_dict(self, state_dict):
147152
self._token_buffer = state_dict["token_buffer"]
153+
self._position_buffer = state_dict.get("position_buffer", [])
148154

149155
if isinstance(self._data, Dataset):
150156
self._sample_idx = state_dict["sample_idx"]
@@ -153,7 +159,10 @@ def load_state_dict(self, state_dict):
153159
self._data.load_state_dict(state_dict["data"])
154160

155161
def state_dict(self):
156-
_state_dict: dict[str, Any] = {"token_buffer": self._token_buffer}
162+
_state_dict: dict[str, Any] = {
163+
"token_buffer": self._token_buffer,
164+
"position_buffer": self._position_buffer,
165+
}
157166

158167
if isinstance(self._data, Dataset):
159168
_state_dict["sample_idx"] = self._sample_idx

torchtitan/trainer.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -596,9 +596,16 @@ def post_dataloading_process(
596596
# extra_kwargs are.
597597
extra_kwargs: dict[str, Any] = {}
598598

599-
# TODO: improve the logic on obtaining attention masks
599+
# TODO: remove this guard once RoPE handles DTensor+positions.
600+
# The positions!=None path in RoPE uses torch.gather which fails
601+
# with DTensor+FSDP. For now, only pass positions through when
602+
# using flex/varlen + block_causal (where it's needed and works).
600603
layer = getattr(self.model_config, "layer", None)
601604
attn_config = getattr(layer, "attention", None) if layer else None
605+
attn_mask_type = getattr(attn_config, "attn_mask_type", "causal")
606+
if attn_mask_type != "block_causal":
607+
extra_inputs.pop("positions", None)
608+
602609
attn_backend = getattr(attn_config, "attn_backend", "sdpa")
603610
if attn_backend in ["flex", "varlen"]:
604611
assert (

0 commit comments

Comments
 (0)