
Commit a8899e4

[Refactor] Rename datasets to prepare for multimodal datasets (#1916)
As titled, made the following change of names:

- torchtitan.datasets -> torchtitan.hf_datasets
- torchtitan.datasets.hf_datasets -> torchtitan.hf_datasets.text_datasets
- build_hf_datasets -> build_text_datasets
- build_hf_validation_datasets -> build_text_validation_datasets
1 parent 29624e3 commit a8899e4
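For downstream code, the change is purely an import-path and symbol rename. A minimal migration sketch for a hypothetical user module (the new paths are the ones introduced by this commit, as shown in the file diffs below):

# Before (hypothetical user module importing the old paths):
#   from torchtitan.datasets import DatasetConfig
#   from torchtitan.datasets.hf_datasets import build_hf_dataloader

# After this commit:
from torchtitan.hf_datasets import DatasetConfig
from torchtitan.hf_datasets.text_datasets import build_text_dataloader

# Call sites keep the same keyword arguments; only the function name changes,
# e.g. build_text_dataloader(tokenizer=..., dp_world_size=..., dp_rank=...).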

16 files changed (+36 −36 lines)


tests/unit_tests/test_dataset_checkpointing.py

Lines changed: 3 additions & 3 deletions
@@ -10,8 +10,8 @@
 from datasets import load_dataset
 from torchtitan.components.tokenizer import HuggingFaceTokenizer
 from torchtitan.config import ConfigManager
-from torchtitan.datasets import DatasetConfig
-from torchtitan.datasets.hf_datasets import build_hf_dataloader, DATASETS
+from torchtitan.hf_datasets import DatasetConfig
+from torchtitan.hf_datasets.text_datasets import build_text_dataloader, DATASETS


 class TestDatasetCheckpointing(unittest.TestCase):
@@ -72,7 +72,7 @@ def _build_dataloader(self, dataset_name, batch_size, seq_len, world_size, rank)
             ]
         )

-        return build_hf_dataloader(
+        return build_text_dataloader(
             tokenizer=tokenizer,
             dp_world_size=world_size,
             dp_rank=rank,

tests/unit_tests/test_train_spec.py

Lines changed: 3 additions & 3 deletions
@@ -15,8 +15,8 @@
 from torchtitan.components.optimizer import build_optimizers, OptimizersContainer
 from torchtitan.components.tokenizer import build_hf_tokenizer
 from torchtitan.config import Optimizer as OptimizerConfig
-from torchtitan.datasets.hf_datasets import build_hf_dataloader
 from torchtitan.distributed.parallel_dims import ParallelDims
+from torchtitan.hf_datasets.text_datasets import build_text_dataloader
 from torchtitan.models.llama3 import parallelize_llama
 from torchtitan.protocols import BaseModelArgs, ModelProtocol
 from torchtitan.protocols.train_spec import (
@@ -82,7 +82,7 @@ def test_register_train_spec(self):
             pipelining_fn=None,
             build_optimizers_fn=build_optimizers,
             build_lr_schedulers_fn=build_lr_schedulers,
-            build_dataloader_fn=build_hf_dataloader,
+            build_dataloader_fn=build_text_dataloader,
             build_tokenizer_fn=build_hf_tokenizer,
             build_loss_fn=build_cross_entropy_loss,
         )
@@ -103,7 +103,7 @@ def test_optim_hook(self):
             pipelining_fn=None,
             build_optimizers_fn=fake_build_optimizers_with_hook,
             build_lr_schedulers_fn=build_lr_schedulers,
-            build_dataloader_fn=build_hf_dataloader,
+            build_dataloader_fn=build_text_dataloader,
             build_tokenizer_fn=build_hf_tokenizer,
             build_loss_fn=build_cross_entropy_loss,
         )

torchtitan/components/validate.py

Lines changed: 2 additions & 2 deletions
@@ -14,8 +14,8 @@
 from torchtitan.components.metrics import MetricsProcessor
 from torchtitan.components.tokenizer import BaseTokenizer
 from torchtitan.config import JobConfig
-from torchtitan.datasets.hf_datasets import build_hf_validation_dataloader
 from torchtitan.distributed import ParallelDims, utils as dist_utils
+from torchtitan.hf_datasets.text_datasets import build_text_validation_dataloader
 from torchtitan.tools import utils
 from torchtitan.tools.logging import logger

@@ -62,7 +62,7 @@ def __init__(
         self.job_config = job_config
         self.parallel_dims = parallel_dims
         self.loss_fn = loss_fn
-        self.validation_dataloader = build_hf_validation_dataloader(
+        self.validation_dataloader = build_text_validation_dataloader(
             job_config=job_config,
             dp_world_size=dp_world_size,
             dp_rank=dp_rank,

torchtitan/experiments/flux/dataset/flux_dataset.py

Lines changed: 1 addition & 1 deletion
@@ -22,11 +22,11 @@

 from torchtitan.components.tokenizer import BaseTokenizer
 from torchtitan.config import JobConfig
-from torchtitan.datasets import DatasetConfig
 from torchtitan.experiments.flux.dataset.tokenizer import (
     build_flux_tokenizer,
     FluxTokenizer,
 )
+from torchtitan.hf_datasets import DatasetConfig
 from torchtitan.tools.logging import logger


torchtitan/experiments/flux/tests/unit_tests/test_flux_dataloader.py

Lines changed: 1 addition & 1 deletion
@@ -11,12 +11,12 @@
 from datasets import load_dataset

 from torchtitan.config import ConfigManager
-from torchtitan.datasets import DatasetConfig
 from torchtitan.experiments.flux.dataset.flux_dataset import (
     _cc12m_wds_data_processor,
     build_flux_dataloader,
     DATASETS,
 )
+from torchtitan.hf_datasets import DatasetConfig


 class TestFluxDataLoader(unittest.TestCase):

torchtitan/experiments/forge/example_train.py

Lines changed: 2 additions & 2 deletions
@@ -18,8 +18,8 @@
 from torchtitan.components.tokenizer import build_hf_tokenizer
 from torchtitan.components.validate import build_validator
 from torchtitan.config import ConfigManager, JobConfig
-from torchtitan.datasets.hf_datasets import build_hf_dataloader
 from torchtitan.distributed import utils as dist_utils
+from torchtitan.hf_datasets.text_datasets import build_text_dataloader
 from torchtitan.tools import utils
 from torchtitan.tools.logging import init_logger, logger
 from torchtitan.tools.profiling import (
@@ -57,7 +57,7 @@ def __init__(self, job_config: JobConfig):
         self.tokenizer = build_hf_tokenizer(job_config)

         # build dataloader
-        self.dataloader = build_hf_dataloader(
+        self.dataloader = build_text_dataloader(
             dp_world_size=self.dp_degree,
             dp_rank=self.dp_rank,
             tokenizer=self.tokenizer,

torchtitan/experiments/simple_fsdp/deepseek_v3/__init__.py

Lines changed: 2 additions & 2 deletions
@@ -8,8 +8,8 @@
 from torchtitan.components.lr_scheduler import build_lr_schedulers
 from torchtitan.components.optimizer import build_optimizers_with_moe_load_balancing
 from torchtitan.components.tokenizer import build_hf_tokenizer
-from torchtitan.datasets.hf_datasets import build_hf_dataloader
 from torchtitan.distributed.pipeline_parallel import pipeline_llm
+from torchtitan.hf_datasets.text_datasets import build_text_dataloader
 from torchtitan.models.deepseek_v3 import deepseekv3_args
 from torchtitan.protocols.train_spec import TrainSpec

@@ -25,7 +25,7 @@ def get_train_spec() -> TrainSpec:
         pipelining_fn=pipeline_llm,
         build_optimizers_fn=build_optimizers_with_moe_load_balancing,
         build_lr_schedulers_fn=build_lr_schedulers,
-        build_dataloader_fn=build_hf_dataloader,
+        build_dataloader_fn=build_text_dataloader,
         build_tokenizer_fn=build_hf_tokenizer,
         build_loss_fn=build_cross_entropy_loss,
     )

torchtitan/experiments/simple_fsdp/llama3/__init__.py

Lines changed: 2 additions & 2 deletions
@@ -8,8 +8,8 @@
 from torchtitan.components.lr_scheduler import build_lr_schedulers
 from torchtitan.components.optimizer import build_optimizers
 from torchtitan.components.tokenizer import build_hf_tokenizer
-from torchtitan.datasets.hf_datasets import build_hf_dataloader
 from torchtitan.distributed.pipeline_parallel import pipeline_llm
+from torchtitan.hf_datasets.text_datasets import build_text_dataloader
 from torchtitan.models.llama3 import llama3_args
 from torchtitan.protocols.train_spec import TrainSpec

@@ -25,7 +25,7 @@ def get_train_spec() -> TrainSpec:
         pipelining_fn=pipeline_llm,
         build_optimizers_fn=build_optimizers,
         build_lr_schedulers_fn=build_lr_schedulers,
-        build_dataloader_fn=build_hf_dataloader,
+        build_dataloader_fn=build_text_dataloader,
         build_tokenizer_fn=build_hf_tokenizer,
         build_loss_fn=build_cross_entropy_loss,
     )

torchtitan/experiments/vlm/datasets/mm_datasets.py

Lines changed: 4 additions & 4 deletions
@@ -22,7 +22,7 @@
 from torchtitan.components.dataloader import ParallelAwareDataloader
 from torchtitan.components.tokenizer import BaseTokenizer, HuggingFaceTokenizer
 from torchtitan.config import JobConfig
-from torchtitan.datasets import DatasetConfig
+from torchtitan.hf_datasets import DatasetConfig
 from torchtitan.tools.logging import logger

 from ..model.args import SpecialTokens
@@ -226,8 +226,8 @@ def _validate_mm_dataset(
     return path, config.loader, config.sample_processor


-class MultiModalDataset(IterableDataset, Stateful):
-    """MultiModal Dataset with support for sample packing."""
+class HuggingFaceMultiModalDataset(IterableDataset, Stateful):
+    """HuggingFace MultiModal Dataset with support for sample packing."""

     def __init__(
         self,
@@ -403,7 +403,7 @@ def build_mm_dataloader(
     packing_buffer_size = job_config.data.packing_buffer_size
     special_tokens = SpecialTokens.from_tokenizer(tokenizer)

-    dataset = MultiModalDataset(
+    dataset = HuggingFaceMultiModalDataset(
         dataset_name=job_config.training.dataset,
         dataset_path=dataset_path,
         tokenizer=tokenizer,
