Skip to content

Commit b3d7f60

Browse files
committed
Yield per-document RoPE position IDs from HuggingFaceTextDataset
Add a position buffer that tracks per-document RoPE positions, resetting at each document boundary. These positions are yielded alongside input tokens and used when block_causal attention is configured. Also add is_packed validation to catch misconfigured attention backends at trainer init time: packed dataloaders require the flex or varlen attention backend with block_causal masking, which prevents cross-document attention leakage.
1 parent 0691f51 commit b3d7f60

File tree

2 files changed

+16
-3
lines changed

2 files changed

+16
-3
lines changed

tests/unit_tests/test_dataset_checkpointing.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,10 @@ def test_c4_resumption(self):
5555
assert torch.equal(
5656
input_ids["input"], expected_input_ids["input"]
5757
)
58+
assert torch.equal(
59+
input_ids["positions"],
60+
expected_input_ids["positions"],
61+
)
5862
assert torch.equal(labels, expected_labels)
5963

6064
def _build_dataloader(self, dataset_name, batch_size, seq_len, world_size, rank):

torchtitan/hf_datasets/text_datasets.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,7 @@ def __init__(
9696
# Variables for checkpointing
9797
self._sample_idx = 0
9898
self._token_buffer: list[int] = []
99+
self._position_buffer: list[int] = []
99100

100101
def _get_data_iter(self):
101102
# For map-style datasets, resume by skipping to the correct index
@@ -119,15 +120,19 @@ def __iter__(self):
119120
sample_text, add_bos=True, add_eos=True
120121
)
121122
self._token_buffer.extend(sample_tokens)
123+
self._position_buffer.extend(range(len(sample_tokens)))
122124
self._sample_idx += 1
123125

124126
while len(self._token_buffer) >= max_buffer_token_len:
125127
x = torch.LongTensor(self._token_buffer[:max_buffer_token_len])
126-
# update tokens to the remaining tokens
128+
pos = torch.LongTensor(self._position_buffer[:max_buffer_token_len])
129+
# update buffers to the remaining tokens
127130
self._token_buffer = self._token_buffer[max_buffer_token_len:]
131+
self._position_buffer = self._position_buffer[max_buffer_token_len:]
128132
input = x[:-1]
129133
label = x[1:]
130-
yield {"input": input}, label
134+
positions = pos[:-1]
135+
yield {"input": input, "positions": positions}, label
131136

132137
if not self.infinite:
133138
logger.warning(f"Dataset {self.dataset_name} has run out of data")
@@ -145,6 +150,7 @@ def __iter__(self):
145150

146151
def load_state_dict(self, state_dict):
147152
self._token_buffer = state_dict["token_buffer"]
153+
self._position_buffer = state_dict.get("position_buffer", [])
148154

149155
if isinstance(self._data, Dataset):
150156
self._sample_idx = state_dict["sample_idx"]
@@ -153,7 +159,10 @@ def load_state_dict(self, state_dict):
153159
self._data.load_state_dict(state_dict["data"])
154160

155161
def state_dict(self):
156-
_state_dict: dict[str, Any] = {"token_buffer": self._token_buffer}
162+
_state_dict: dict[str, Any] = {
163+
"token_buffer": self._token_buffer,
164+
"position_buffer": self._position_buffer,
165+
}
157166

158167
if isinstance(self._data, Dataset):
159168
_state_dict["sample_idx"] = self._sample_idx

0 commit comments

Comments (0)