3 changes: 3 additions & 0 deletions torchtitan/config/job_config.py
@@ -182,6 +182,9 @@ class Training:
loaded from this path instead of downloaded.
"""

synthetic_data: bool = False
"""Use synthetic (random token) data instead of a real dataset, e.g. for quick benchmarking"""

local_batch_size: int = 8
"""Local batch size (i.e., per-device batch size)"""

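For reference, a minimal sketch of how this flag might be toggled in a job config, assuming the [training] table layout used by torchtitan's TOML files (the surrounding keys are shown only for illustration):

```toml
# assumed config layout -- Training dataclass fields map onto the [training] table
[training]
synthetic_data = true     # skip the HF dataset and feed random token ids
local_batch_size = 8
seq_len = 2048
```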
67 changes: 64 additions & 3 deletions torchtitan/datasets/hf_datasets.py
@@ -7,20 +7,53 @@
from dataclasses import dataclass

from functools import partial
from typing import Any, Callable
from typing import Any, Callable, Iterator, Tuple, Dict

import torch

from datasets import Dataset, load_dataset
from datasets.distributed import split_dataset_by_node
from torch.distributed.checkpoint.stateful import Stateful
from torch.utils.data import IterableDataset
from torch.utils.data import IterableDataset, DataLoader

from torchtitan.components.dataloader import ParallelAwareDataloader
from torchtitan.components.tokenizer import BaseTokenizer
from torchtitan.config import JobConfig
from torchtitan.tools.logging import logger

class SyntheticIterableDataset(IterableDataset):
    """Infinite stream of random token ids, for benchmarking without a real dataset."""

    def __init__(
        self,
        batch_size: int,
        seq_len: int,
        vocab_size: int,
        device: torch.device,
        dtype: torch.dtype = torch.long,
    ):
        self.batch_size = batch_size  # batching itself is handled by the dataloader
        self.seq_len = seq_len
        self.vocab_size = vocab_size
        self.device = device
        self.dtype = dtype

    def __iter__(self) -> Iterator[Tuple[Dict[str, torch.Tensor], torch.Tensor]]:
        while True:
            inputs = torch.randint(
Contributor:
This is fake data, not "synthetic" data.

Author:
Oh, how about calling it random data? The goal is to remove the dataset dependency for quick performance benchmarking.

Author:
Defaulting to c4 has two problems:

  1. Although it is small, it is still a dependency the user has to download before testing, and in rapid debugging and reruns it is possible to hit the HF request limit. An unstable network also gets in the way of smooth development. I had to make local changes like this so I can develop without worrying about the dataset. I guess many users have had a similar experience.
  2. With larger models and bigger batch sizes, runs will quickly loop back over the data, and for the same reason as 1, it may be rate-limited or time consuming for the user to download a very large dataset.

Contributor:
A random dataset is usually very useful when debugging the CPU overhead brought by data loading, though I'm not sure if we already have such a use case. Multimodal may benefit from a random dataset.

Contributor:
@alfuyao1986

> Although it is small, it is still a dependency the user has to download before testing, and in rapid debugging and reruns it is possible to hit the HF request limit. An unstable network also gets in the way of smooth development.

We have c4_test stored in the repo:
https://github.com/pytorch/torchtitan/tree/main/tests/assets/c4_test

> With larger models and bigger batch sizes, runs will quickly loop back over the data, and it may be rate-limited or time consuming for the user to download a very large dataset.

What would be the advantage of using random / fake data versus looping back on c4_test?

@fegin

> Multimodal may benefit from a random dataset.

As we don't have multimodal training, the main thing I'd like to understand is the benefit of adding random data on top of the existing c4_test.

Contributor:
A random dataset generally skips the overhead of data loading, like actually reading from disk. This is not related to whether the dataset is large or small. But as mentioned above, this may become more useful once dataloader overhead turns out to be a big thing. As for development efficiency, I didn't encounter such an issue, so I should not be the one to answer.

This is solely my opinion.

Author:
Oh, given c4_test is already pre-stored in the repo, it should be fine for most cases. I am actually completely fine with using the pre-stored c4_test dataset. Just two more considerations to bring up for discussion:

  1. A random dataset can usually stress the whole stack better, numerically and computationally, than a small repeated dataset, but it is debatable whether this additional stress is practically realistic and necessary.
  2. Other frameworks (MaxText, Megatron-LM) do provide "synthetic/mock" data options for fast benchmarking, so from an ease-of-comparison point of view it may be better to have a matching option.

Contributor:
Got it, thanks for the context!

From my perspective, the value of this dataset is somewhat limited, given we already have c4_test, which doesn't involve randomness and has therefore become a standard way to do numerical testing even when parallelism / world size changes.

That said, if people have a strong opinion about adding this dataset, I'm OK too. If that's the case, I would suggest making a new builder function & file instead of piggybacking on the existing build_hf_dataloader. I understand that would make it harder to switch to this new dataset from config, but that's not a good reason to reuse it.

Contributor:
Yes, we definitely shouldn't use build_hf_dataloader for a random dataset. There is actually another benefit of a random dataset (when it has a deterministic option): debugging checkpoint issues. Given that the dataloader is controlled by another package, having a random dataset with a deterministic option will make debugging checkpoint inconsistency easier; at least we can rule out a dataset/dataloader problem.

                low=0,
                high=self.vocab_size,
                size=(self.seq_len,),
                device=self.device,
                dtype=self.dtype,
            )
            labels = torch.randint(
                low=0,
                high=self.vocab_size,
                size=(self.seq_len,),
                device=self.device,
                dtype=self.dtype,
            )
            yield {"input": inputs}, labels


def _load_c4_dataset(dataset_path: str, split: str):
"""Load C4 dataset with default configuration."""
Expand Down Expand Up @@ -185,7 +218,35 @@ def build_hf_dataloader(
    dataset_path = job_config.training.dataset_path
    batch_size = job_config.training.local_batch_size
    seq_len = job_config.training.seq_len

    # If synthetic_data is enabled, return a dataloader over random tokens:
    if getattr(job_config.training, "synthetic_data", False):
        # Assumes each dp_rank maps to a local CUDA device on this node.
        device = (
            torch.device(f"cuda:{dp_rank}")
            if torch.cuda.is_available()
            else torch.device("cpu")
        )
        # Use tokenizer.vocab_size if available; otherwise fall back to the model config.
        if tokenizer is not None and hasattr(tokenizer, "vocab_size"):
            vocab_sz = tokenizer.vocab_size
        else:
            vocab_sz = getattr(job_config.model, "vocab_size", 30000)

        synthetic_ds = SyntheticIterableDataset(
            seq_len=seq_len,
            batch_size=batch_size,
            vocab_size=vocab_sz,
            device=device,
            dtype=torch.long,
        )
        logger.warning(
            f"Using SYNTHETIC data → batch_size={batch_size}, "
            f"seq_len={seq_len}, vocab_size={vocab_sz}"
        )
        return ParallelAwareDataloader(
            dataset=synthetic_ds,
            dp_rank=dp_rank,
            dp_world_size=dp_world_size,
            batch_size=batch_size,
        )

    # Otherwise, build the normal HuggingFaceDataset pipeline:
    dataset_name = job_config.training.dataset
    dataset_path = job_config.training.dataset_path
    hf_ds = HuggingFaceDataset(
        dataset_name=dataset_name,
        dataset_path=dataset_path,
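Following the review suggestions to keep this out of build_hf_dataloader and to make the random data deterministic (useful for ruling out dataloader issues when debugging checkpoints), here is a rough sketch of what a separate, seedable builder could look like. The module name, the seed handling, and the builder signature are assumptions for illustration, not part of this PR:

```python
# hypothetical torchtitan/datasets/random_dataset.py -- illustrative sketch only
from typing import Dict, Iterator, Tuple

import torch
from torch.utils.data import IterableDataset

from torchtitan.components.dataloader import ParallelAwareDataloader
from torchtitan.components.tokenizer import BaseTokenizer
from torchtitan.config import JobConfig


class RandomTokenDataset(IterableDataset):
    """Infinite stream of random token ids; deterministic per (seed, dp_rank)."""

    def __init__(self, seq_len: int, vocab_size: int, seed: int, dp_rank: int):
        self.seq_len = seq_len
        self.vocab_size = vocab_size
        # Derive a per-rank seed so ranks see different but reproducible data.
        self.seed = seed + dp_rank

    def __iter__(self) -> Iterator[Tuple[Dict[str, torch.Tensor], torch.Tensor]]:
        # A fresh generator per iterator keeps restarts reproducible.
        gen = torch.Generator().manual_seed(self.seed)
        while True:
            inputs = torch.randint(0, self.vocab_size, (self.seq_len,), generator=gen)
            labels = torch.randint(0, self.vocab_size, (self.seq_len,), generator=gen)
            yield {"input": inputs}, labels


def build_random_dataloader(
    dp_world_size: int,
    dp_rank: int,
    tokenizer: BaseTokenizer,
    job_config: JobConfig,
) -> ParallelAwareDataloader:
    """Builder mirroring build_hf_dataloader's signature (assumed), but random."""
    vocab_size = getattr(tokenizer, "vocab_size", 30000) if tokenizer else 30000
    ds = RandomTokenDataset(
        seq_len=job_config.training.seq_len,
        vocab_size=vocab_size,
        seed=42,  # could come from job_config; hard-coded here for the sketch
        dp_rank=dp_rank,
    )
    return ParallelAwareDataloader(
        dataset=ds,
        dp_rank=dp_rank,
        dp_world_size=dp_world_size,
        batch_size=job_config.training.local_batch_size,
    )
```

Keeping the random path in its own builder leaves build_hf_dataloader untouched, and the per-rank seed makes two runs with the same world size produce identical batches, which is what makes it usable for checkpoint-consistency debugging.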