Skip to content

Commit 0b0f7d7

Browse files
committed
Yield per-document RoPE position IDs from HuggingFaceTextDataset
Add a position buffer that tracks per-document RoPE positions, resetting at each document boundary. These positions are yielded alongside input tokens and used when block_causal attention is configured. Also add is_packed validation to catch misconfigured attention backends at trainer init time: packed dataloaders require flex or varlen with block_causal to prevent cross-document attention leakage.
1 parent 0691f51 commit 0b0f7d7

File tree

4 files changed: +47 additions, −7 deletions

tests/unit_tests/test_dataset_checkpointing.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,10 @@ def test_c4_resumption(self):
5555
assert torch.equal(
5656
input_ids["input"], expected_input_ids["input"]
5757
)
58+
assert torch.equal(
59+
input_ids["positions"],
60+
expected_input_ids["positions"],
61+
)
5862
assert torch.equal(labels, expected_labels)
5963

6064
def _build_dataloader(self, dataset_name, batch_size, seq_len, world_size, rank):

torchtitan/components/dataloader.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,11 @@ class Config(Configurable.Config):
4747
dataset: str = ""
4848
dataset_path: str | None = None
4949

50+
@property
51+
def is_packed(self) -> bool:
52+
"""Whether the underlying dataset packs multiple documents per sequence."""
53+
return getattr(self.dataset, "is_packed", False)
54+
5055
@abstractmethod
5156
def __iter__(self) -> Iterator[tuple[dict[str, torch.Tensor], torch.Tensor]]:
5257
...

torchtitan/hf_datasets/text_datasets.py

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,8 @@ def _validate_dataset(
6868

6969

7070
class HuggingFaceTextDataset(IterableDataset, Stateful):
71+
is_packed: bool = True
72+
7173
def __init__(
7274
self,
7375
dataset_name: str,
@@ -96,6 +98,7 @@ def __init__(
9698
# Variables for checkpointing
9799
self._sample_idx = 0
98100
self._token_buffer: list[int] = []
101+
self._position_buffer: list[int] = []
99102

100103
def _get_data_iter(self):
101104
# For map-style datasets, resume by skipping to the correct index
@@ -119,15 +122,19 @@ def __iter__(self):
119122
sample_text, add_bos=True, add_eos=True
120123
)
121124
self._token_buffer.extend(sample_tokens)
125+
self._position_buffer.extend(range(len(sample_tokens)))
122126
self._sample_idx += 1
123127

124128
while len(self._token_buffer) >= max_buffer_token_len:
125129
x = torch.LongTensor(self._token_buffer[:max_buffer_token_len])
126-
# update tokens to the remaining tokens
130+
pos = torch.LongTensor(self._position_buffer[:max_buffer_token_len])
131+
# update buffers to the remaining tokens
127132
self._token_buffer = self._token_buffer[max_buffer_token_len:]
133+
self._position_buffer = self._position_buffer[max_buffer_token_len:]
128134
input = x[:-1]
129135
label = x[1:]
130-
yield {"input": input}, label
136+
positions = pos[:-1]
137+
yield {"input": input, "positions": positions}, label
131138

132139
if not self.infinite:
133140
logger.warning(f"Dataset {self.dataset_name} has run out of data")
@@ -145,6 +152,7 @@ def __iter__(self):
145152

146153
def load_state_dict(self, state_dict):
147154
self._token_buffer = state_dict["token_buffer"]
155+
self._position_buffer = state_dict.get("position_buffer", [])
148156

149157
if isinstance(self._data, Dataset):
150158
self._sample_idx = state_dict["sample_idx"]
@@ -153,7 +161,10 @@ def load_state_dict(self, state_dict):
153161
self._data.load_state_dict(state_dict["data"])
154162

155163
def state_dict(self):
156-
_state_dict: dict[str, Any] = {"token_buffer": self._token_buffer}
164+
_state_dict: dict[str, Any] = {
165+
"token_buffer": self._token_buffer,
166+
"position_buffer": self._position_buffer,
167+
}
157168

158169
if isinstance(self._data, Dataset):
159170
_state_dict["sample_idx"] = self._sample_idx
@@ -168,8 +179,10 @@ def state_dict(self):
168179
class HuggingFaceTextDataLoader(ParallelAwareDataloader):
169180
"""Configurable text dataloader that wraps HuggingFaceTextDataset.
170181
171-
This dataloader can be used for both training and validation by
172-
configuring the appropriate dataset, seq_len, batch_size, etc.
182+
This dataloader packs multiple documents into each sequence by
183+
concatenating tokenized documents into a continuous stream and
184+
slicing fixed-size chunks. Use with block_causal attention to
185+
prevent cross-document attention leakage.
173186
"""
174187

175188
@dataclass(kw_only=True, slots=True)

torchtitan/trainer.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -255,6 +255,20 @@ def __init__(self, config: Config):
255255
)
256256
self.model_config = model_config
257257

258+
# Validate that packed dataloaders use block_causal attention
259+
if self.dataloader.is_packed:
260+
attn_config = model_config.layer.attention
261+
if (
262+
attn_config.attn_backend == "sdpa"
263+
or attn_config.attn_mask_type != "block_causal"
264+
):
265+
raise ValueError(
266+
"Packed dataloader requires attn_backend='flex' or 'varlen' "
267+
"with attn_mask_type='block_causal' for document isolation. "
268+
f"Got attn_backend='{attn_config.attn_backend}', "
269+
f"attn_mask_type='{attn_config.attn_mask_type}'."
270+
)
271+
258272
logger.info(
259273
f"Building {model_spec.name} {model_spec.flavor} "
260274
f"with {json.dumps(model_config.to_dict(), indent=2, ensure_ascii=False)}"
@@ -597,8 +611,7 @@ def post_dataloading_process(
597611
extra_kwargs: dict[str, Any] = {}
598612

599613
# TODO: improve the logic on obtaining attention masks
600-
layer = getattr(self.model_config, "layer", None)
601-
attn_config = getattr(layer, "attention", None) if layer else None
614+
attn_config = self._get_attn_config()
602615
attn_backend = getattr(attn_config, "attn_backend", "sdpa")
603616
if attn_backend in ["flex", "varlen"]:
604617
assert (
@@ -851,6 +864,11 @@ def train(self):
851864

852865
logger.info("Training completed")
853866

867+
def _get_attn_config(self):
868+
"""Extract attention config from model config, or None if not available."""
869+
layer = getattr(self.model_config, "layer", None)
870+
return getattr(layer, "attention", None) if layer else None
871+
854872
def should_continue_training(self) -> bool:
855873
return self.step < self.config.training.steps
856874

Comments (0)

There are no comments on this commit.