diff --git a/torchtitan/config/job_config.py b/torchtitan/config/job_config.py
index a2247aa21..36b0f139d 100644
--- a/torchtitan/config/job_config.py
+++ b/torchtitan/config/job_config.py
@@ -182,6 +182,9 @@ class Training:
     loaded from this path instead of downloaded.
     """
 
+    synthetic_data: bool = False
+    """Use synthetic data"""
+
     local_batch_size: int = 8
     """Local batch size (i.e., per-device batch size)"""
 
diff --git a/torchtitan/datasets/hf_datasets.py b/torchtitan/datasets/hf_datasets.py
index 0e30f8fe5..811ae66c8 100644
--- a/torchtitan/datasets/hf_datasets.py
+++ b/torchtitan/datasets/hf_datasets.py
@@ -7,20 +7,53 @@
 from dataclasses import dataclass
 from functools import partial
-from typing import Any, Callable
+from typing import Any, Callable, Iterator, Tuple, Dict
 
 import torch
 
 from datasets import Dataset, load_dataset
 from datasets.distributed import split_dataset_by_node
 from torch.distributed.checkpoint.stateful import Stateful
-from torch.utils.data import IterableDataset
+from torch.utils.data import IterableDataset, DataLoader
 
 from torchtitan.components.dataloader import ParallelAwareDataloader
 from torchtitan.components.tokenizer import BaseTokenizer
 from torchtitan.config import JobConfig
 from torchtitan.tools.logging import logger
 
 
+class SyntheticIterableDataset(IterableDataset):
+    def __init__(
+        self,
+        batch_size: int,
+        seq_len: int,
+        vocab_size: int,
+        device: torch.device,
+        dtype: torch.dtype = torch.long,
+    ):
+        self.batch_size = batch_size
+        self.seq_len = seq_len
+        self.vocab_size = vocab_size
+        self.device = device
+        self.dtype = dtype
+
+    def __iter__(self) -> Iterator[Tuple[Dict[str, torch.Tensor], torch.Tensor]]:
+        while True:
+            inputs = torch.randint(
+                low=0,
+                high=self.vocab_size,
+                size=(self.seq_len,),
+                device=self.device,
+                dtype=self.dtype,
+            )
+            labels = torch.randint(
+                low=0,
+                high=self.vocab_size,
+                size=(self.seq_len,),
+                device=self.device,
+                dtype=self.dtype,
+            )
+            yield {"input": inputs}, labels
+
 def _load_c4_dataset(dataset_path: str, split: str):
     """Load C4 dataset with default configuration."""
@@ -185,7 +218,35 @@ def build_hf_dataloader(
     dataset_path = job_config.training.dataset_path
     batch_size = job_config.training.local_batch_size
     seq_len = job_config.training.seq_len
-
+    # If synthetic_data is True, return a synthetic dataloader:
+    device = torch.device(f"cuda:{dp_rank}") if torch.cuda.is_available() else torch.device("cpu")
+    if getattr(job_config.training, "synthetic_data", False):
+        # Use tokenizer.vocab_size if available; otherwise fallback
+        if tokenizer is not None and hasattr(tokenizer, "vocab_size"):
+            vocab_sz = tokenizer.vocab_size
+        else:
+            vocab_sz = getattr(job_config.model, "vocab_size", 30000)
+
+        synthetic_ds = SyntheticIterableDataset(
+            seq_len=seq_len,
+            batch_size=batch_size,
+            vocab_size=vocab_sz,
+            device=device,
+            dtype=torch.long,
+        )
+        logger.warning(
+            f"Using SYNTHETIC data → batch_size={batch_size}, seq_len={seq_len}, vocab_size={vocab_sz}"
+        )
+        return ParallelAwareDataloader(
+            dataset=synthetic_ds,
+            dp_rank=dp_rank,
+            dp_world_size=dp_world_size,
+            batch_size=batch_size,
+        )
+
+    # Otherwise, build the normal HuggingFaceDataset pipeline:
+    dataset_name = job_config.training.dataset
+    dataset_path = job_config.training.dataset_path
     hf_ds = HuggingFaceDataset(
         dataset_name=dataset_name,
         dataset_path=dataset_path,
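
As a quick sanity check of the new synthetic path, here is a minimal smoke test of `SyntheticIterableDataset` in isolation. It is a hypothetical sketch, not part of the patch: it assumes the diff above is applied and `torchtitan` is importable, uses illustrative `seq_len`/`vocab_size` values, and wraps the dataset in a plain `torch.utils.data.DataLoader` for shape inspection rather than the `ParallelAwareDataloader` used in `build_hf_dataloader`.

```python
# Hypothetical smoke test for SyntheticIterableDataset (not part of the diff).
# Assumes the patch above is applied so the class is importable from
# torchtitan.datasets.hf_datasets; seq_len/vocab_size values are illustrative.
import torch
from torch.utils.data import DataLoader

from torchtitan.datasets.hf_datasets import SyntheticIterableDataset

ds = SyntheticIterableDataset(
    batch_size=8,                # stored on the dataset; batching is done by the loader
    seq_len=2048,
    vocab_size=50257,            # illustrative; real runs take this from the tokenizer
    device=torch.device("cpu"),  # CPU keeps the check independent of GPU availability
)

# Each sample is ({"input": (seq_len,)}, (seq_len,)); default collation stacks
# 8 of them into (8, seq_len) tensors, matching the shapes the existing
# HuggingFaceDataset path feeds the trainer.
inputs, labels = next(iter(DataLoader(ds, batch_size=8)))
assert inputs["input"].shape == (8, 2048)
assert labels.shape == (8, 2048)
print("synthetic batch OK:", inputs["input"].shape, labels.shape)
```

In an actual run, the feature would presumably be enabled through the new `Training.synthetic_data` field, i.e. `synthetic_data = true` under `[training]` in the job TOML or the corresponding `--training.synthetic_data` override, following how the other `Training` dataclass fields are exposed.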