Commit 9810191
[BC Breaking] Config System Refactor: TOML to Python Dataclass Registry (#2386)
**NOTE**: This PR is a large refactor of the codebase. https://github.com/pytorch/torchtitan/releases/tag/v0.2.2 is the last release cut right before this PR was merged.

# author's note

This refactor mainly tries to address two issues:
- bad encapsulation: previously the monolithic `JobConfig` was leaked everywhere
- difficulty iterating and experimenting on model architecture and training components

The main changes are:
- Strict encapsulation, even at the cost of a (hopefully temporary) bloated interface when calling subcomponents (e.g. validator). We should try to find the right abstraction for cross-component visibility.
- Each `Configurable` component owns its own `Config`, which builds the owner component. This achieves modularization via polymorphism and inheritance, both classic concepts in OOP.
  - This is partly inspired by repos like [AXLearn](https://github.com/apple/axlearn) (in particular, @ruomingp's [ML API Styles](https://github.com/apple/axlearn/blob/main/docs/ml_api_style.md)), GitHub issues (e.g. #1055), and offline discussions (with @Chillee, @ailzhang etc.).
  - Similar functionality can alternatively be achieved in other ways, e.g. `_target_` in [Hydra](https://hydra.cc/docs/advanced/instantiate_objects/overview/), but there are opinions against coupling with Hydra's other offerings. See #1415.
- The main entry point switches from TOML files to Python functions (a.k.a. `config_registry.py` in each model).
  - TOML has the constraint that everything needs to be registered explicitly before it can be used; e.g. our quantization components needed to be registered with string names. Python's language-level implicit registration is, we believe, more minimal, and should be fairly easy to extend/modify to support TOML/YAML when users build upon / fork torchtitan.
  - That said, Python config provides much more power, e.g. one can use arbitrary logic to create (the config of) a component, which is hard to express in TOML/YAML, creating extra difficulty when users want to migrate to their own favorite config system. The only thing we can do is stay conservative in the use of such power.
- We still use [tyro](https://github.com/brentyi/tyro) to convert config dataclasses to a CLI, still with the limitation that users need to construct customized config classes all the way from the root level (`Trainer.Config` now, `JobConfig` in the past).
  - If a CLI is not needed, a new trainer (or any high-level) config is not required.
  - To support "polymorphic construction" from the CLI without the hassle, check out [chz](https://github.com/openai/chz/blob/main/docs/04_command_line.md#polymorphic-construction).

This PR also
- updates the docs -- there might be remaining outdated docs; please raise issues or help fix them
- moves ft to experiments, continuing the effort in #2311

Remaining work
- [AutoParallel CI failure](https://github.com/pytorch/torchtitan/actions/runs/22165425254/job/64091572780?pr=2386) seems caused by the way RoPE is authored, and needs a change in autoparallel (cc @xmfan) - being fixed in meta-pytorch/autoparallel#321
- [CompilerToolkit CI failure](https://github.com/pytorch/torchtitan/actions/runs/22168015737/job/64099486707?pr=2386) `TypeError: forward() missing 1 required positional argument: 'fwd_rng_state_2'` - cc @yiming0416 please help take a look
- [SimpleFSDP CI failure](https://github.com/pytorch/torchtitan/actions/runs/22168015749/job/64099486149?pr=2386) is the same as #2312, around dynamic shapes for for-loop MoE experts computation (cc @pianpwk) - being fixed in #2399
- Fix integration scripts for MAST, Zoomer, etc.
- organize docs from `docs/` into subfolders, as we have more content to cover in general
- generate and store serialized configs (maybe not in the repo)
- continue the SAC refactor in #2357, but somehow keep the every-other-mm policy (cc @mori360)
- refactor RoPE in general, at least resolving the following TODOs in code (cc @shuhuayu)
  - having to set / no validation that rope dim == decoder dim // attention n_heads
  - consolidate `apply_rotary_emb_complex` and `apply_rotary_emb_single_complex`
- address #2417

Longer-term issues
- More careful design of what goes in config vs. runtime build kwargs. (thanks @ailzhang)
- ModelSpec is not serializable. There may be multiple solutions, but we can potentially consolidate `model.py` and `parallelize.py` by
  - sharing AC, compile, DP application across all Decoder models
  - putting the per-module TP/CP/EP sharding plan inside the model itself
- Right now `BaseModel.update_from_config` violates encapsulation by passing the Trainer config into the Model config. This could be avoided with Python logic, either at config construction time or in the trainer.
- Refactor `init_weights` into `Module.Config` instead of keeping it in `Module`.
  - The benefit is that param init becomes configurable; otherwise we are coupling module implementation with its weight init.
  - This may require a refactor of the current TransformerBlock and its config. E.g. `weight_init_std` may need to be put in config, with `__post_init__` determining its value. (See related complaints / discussions on `__post_init__` by [chz](https://github.com/openai/chz/blob/main/docs/21_post_init.md).)

Note to reviewers: Although I believe the changes in this PR come naturally as a bundle, you may (or may not) find the stack of 16 commits easier to review, as I tried to split the changes in some logical manner. I apologize for the giant PR.

# claude-generated summary

## Summary

This PR refactors torchtitan's configuration and training infrastructure in 15 incremental, backwards-incompatible commits.
The central change replaces TOML config files and a monolithic `JobConfig` parser with **typed Python dataclass configs**, a **`Configurable` base class** pattern, and a **`config_registry`** module per model. **270 files changed, 10,025 insertions, 11,418 deletions.**

---

## Motivation

The previous system used TOML files parsed by a custom `ConfigManager` that layered CLI overrides on top. While simple, this had several friction points:

1. **No type safety at config boundaries.** TOML values are strings/ints/floats parsed at runtime. A typo in a key name (e.g., `training.stpes`) silently becomes a default value.
2. **Flat namespace.** All config sections (`[model]`, `[training]`, `[optimizer]`, `[checkpoint]`, ...) lived in a single `JobConfig` class. Every component received the full `JobConfig` even when it only needed a few fields.
3. **Experiment extension was ad hoc.** Experiments that needed custom config fields (e.g., SimpleFSDP's `compile.graph_passes` or FaultTolerant's `fault_tolerance.*`) required a `custom_config_module` TOML key and a runtime `_merge_configs` call to graft new fields onto `JobConfig`.
4. **Model args were disconnected from model code.** A `ModelArgs` dataclass in `args.py` defined hyperparameters, but the `TrainSpec` that bundled model + parallelization + loss was registered separately, with no type-level link between them.

---

## What Changed

### 1. `Configurable` Base Class

A new `Configurable` base class (`torchtitan/config/configurable.py`) establishes a universal pattern:

```python
class Configurable:
    @dataclass(kw_only=True, slots=True)
    class Config:
        def build(self, **kwargs):
            return self._owner(config=self, **kwargs)

    def __init_subclass__(cls, **kwargs):
        # Auto-wires Config.build() -> cls(config=..., **kwargs)
        # Enforces @dataclass(kw_only=True, slots=True) on every Config
```

Every configurable component (Trainer, model, optimizer, tokenizer, dataloader, checkpoint manager, metrics, validators, quantization converters, ...) follows this pattern. Calling `config.build()` constructs the owning class.

### 2. `Trainer.Config` Replaces `JobConfig`

The monolithic `JobConfig` is replaced by `Trainer.Config`, a nested dataclass that aggregates typed sub-configs:

```python
class Trainer(Stateful, Configurable):
    @dataclass(kw_only=True, slots=True)
    class Config(Configurable.Config):
        model_spec: ModelSpec | None = None  # set by config_registry, suppressed from CLI
        job: JobConfig = ...
        training: TrainingConfig = ...
        parallelism: ParallelismConfig = ...
        optimizer: OptimizersContainer.Config = ...
        lr_scheduler: LRSchedulersContainer.Config = ...
        checkpoint: CheckpointManager.Config = ...
        dataloader: BaseDataLoader.Config = ...
        metrics: MetricsProcessor.Config = ...
        # ... etc.
```

Each sub-config is the `Config` class of the component that consumes it (e.g., `CheckpointManager.Config` is defined inside `CheckpointManager`). Components receive only their own config, not the entire training config.

### 3. `config_registry.py` Replaces TOML Files

Each model defines a `config_registry.py` with functions that return complete `Trainer.Config` instances:

```python
# torchtitan/models/llama3/config_registry.py

def llama3_debugmodel() -> Trainer.Config:
    return Trainer.Config(
        job=JobConfig(description="Llama 3 debug training", ...),
        model_spec=model_registry("debugmodel"),
        optimizer=OptimizersContainer.Config(lr=8e-4),
        training=TrainingConfig(local_batch_size=8, seq_len=2048, steps=10),
        dataloader=HuggingFaceTextDataLoader.Config(dataset="c4_test"),
        # ...
    )

def llama3_debugmodel_float8() -> Trainer.Config:
    config = llama3_debugmodel()
    config.model_converters = ModelConvertersContainer.Config(
        converters=[Float8LinearConverter.Config(enable_fsdp_float8_all_gather=True)]
    )
    return config
```

### 4. `TrainSpec` -> `ModelSpec`

`TrainSpec` is renamed to `ModelSpec` with a narrower scope: it holds only model-specific concerns (model config, parallelization function, loss function, state dict adapter). All training-level concerns (optimizer, LR scheduler, checkpointing, etc.) live in `Trainer.Config`.

### 5. Model Configs: Flat `ModelArgs` -> Nested Dataclass Hierarchy

Model hyperparameters move from a flat `ModelArgs` dataclass into a nested `Config` hierarchy that mirrors the module tree:

```python
# Before (main): flat args.py
@dataclass
class ModelArgs:
    dim: int = 4096
    n_layers: int = 32
    n_heads: int = 32
    # ... 20+ flat fields

# After (this PR): nested Config in model class
class Llama3Model(Decoder):
    @dataclass(kw_only=True, slots=True)
    class Config(Decoder.Config):
        layer: Llama3TransformerBlock.Config  # contains attention + FFN configs
        rope: RoPE.Config  # contains RoPE-specific params
```

### 6. `train.py` Split

The monolithic `train.py` (~800 lines) is split into:

- `train.py` (~60 lines): a thin entry point that calls `ConfigManager.parse_args()` and `config.build()`
- `trainer.py` (~850 lines): the `Trainer` class with the training-loop logic

### 7. Experiment Extension via Inheritance

Experiments extend the config system through dataclass subclassing instead of runtime config merging:

```python
# torchtitan/experiments/simple_fsdp/configs.py
@dataclass(kw_only=True, slots=True)
class SimpleFSDPConfig(Trainer.Config):
    compile: SimpleFSDPCompileConfig = field(default_factory=SimpleFSDPCompileConfig)
```

Their `config_registry.py` returns the subclassed config type, and `tyro` auto-generates CLI parsing for the extended fields.
---

## UX Comparison

### Launching Training

```bash
# Before (main)
CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh

# After (this PR)
MODULE=llama3 CONFIG=llama3_8b ./run_train.sh
```

### CLI Overrides

```bash
# Before (main)
CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" ./run_train.sh \
    --training.steps 100 --parallelism.tensor_parallel_degree 2

# After (this PR)
./run_train.sh --training.steps 100 --parallelism.tensor_parallel_degree 2
# (defaults to MODULE=llama3, CONFIG=llama3_debugmodel via run_train.sh)
```

CLI override syntax is unchanged (`--section.field value`), but `tyro` now provides typed `--help` output generated from the dataclass tree.

### Defining a New Model Config

```bash
# Before: create a new TOML file, copy-paste sections, edit values
cp train_configs/debug_model.toml train_configs/my_experiment.toml
vim train_configs/my_experiment.toml
```

```python
# After: write a Python function that mutates an existing config
def my_experiment() -> Trainer.Config:
    config = llama3_debugmodel()
    config.training.steps = 100
    config.optimizer.lr = 1e-4
    return config
```

### Adding Experiment-Specific Config Fields

```python
# Before (main): custom_config_module in TOML + runtime _merge_configs
# Requires: TOML key pointing to a Python module, dynamic dataclass creation

# After (this PR): dataclass inheritance
@dataclass(kw_only=True, slots=True)
class MyExperimentConfig(Trainer.Config):
    my_custom_field: str = "default"
```

### Float8 / Quantization Configuration

```python
# Before (main): TOML section
# [quantize.linear.float8]
# enable_fsdp_float8_all_gather = true
# precompute_float8_dynamic_scale_for_fsdp = true

# After (this PR): typed config object
model_converters=ModelConvertersContainer.Config(
    converters=[
        Float8LinearConverter.Config(
            enable_fsdp_float8_all_gather=True,
            precompute_float8_dynamic_scale_for_fsdp=True,
        ),
    ],
),
```

---

## Limitations and Trade-offs

### 1. Configs are no longer declarative text files

TOML files were readable by anyone without Python knowledge. The new `config_registry` functions are Python code, which requires understanding imports, function calls, and dataclass construction. For users who only need to tweak hyperparameters, the CLI override syntax (`--training.steps 100`) works the same, but understanding the full config requires reading Python.

### 2. Steeper learning curve for contributors

Adding a new model now requires understanding the `Configurable` protocol, the nested `Config` dataclass hierarchy, and the `config_registry` pattern. The old approach of copying a TOML file and editing values had a lower barrier to entry.

### 3. Config serialization is more complex

TOML files were trivially serializable and diffable. The new system supports `to_dict()` + JSON serialization, but configs containing callables (e.g., `ModelSpec.parallelize_fn`) cannot be fully round-tripped. The `model_spec` field is excluded from serialization and suppressed from CLI parsing.

### 4. tyro dependency

CLI parsing now depends on `tyro`, a third-party library. While `tyro` is well maintained and provides typed CLI generation from dataclasses, it is an additional dependency that must be kept compatible with the dataclass patterns used here.

### 5. `@dataclass(slots=True)` constraints

The `Configurable` base class enforces `@dataclass(kw_only=True, slots=True)` on all `Config` classes. While this provides memory efficiency and prevents accidental attribute assignment, `slots=True` prevents dynamic attribute addition and makes multiple inheritance with other slotted classes more constrained. Each `Config` subclass in a deep hierarchy must repeat the `@dataclass(kw_only=True, slots=True)` decorator.

### 6. Two-level indirection for model selection

The old system required one identifier: `--job.config_file path/to/file.toml`. The new system requires two: `--module llama3 --config llama3_8b`. While this separates model identity from the training recipe, it adds an extra argument.

---

## Numerics Verification

All model configs were verified for numerical equivalence against the main branch (commit `10d8a306`).

NOTE:
- only models that can fit on 8 GPUs are tested
- only a subset of parallelism combinations are tested

| Model | Status | Notes |
|-------|--------|-------|
| llama3 (debugmodel, 8B) | Bitwise match | |
| llama3 (debugmodel_flex_attn) | Bitwise match | |
| qwen3 (0.6B, 1.7B, 32B, MoE debugmodel) | Bitwise match | |
| deepseek_v3 (debugmodel, 16B) | Close (max diff 0.00014) | Pre-existing main branch bug: missing `eps` in final RMSNorm |
| llama4 debugmodel | Bitwise match | `_irope` variants don't work on main (FlexAttn `'dict' object has no attribute 'BLOCK_SIZE'`) but work after this PR |
| **gpt_oss** debugmodel | `--debug.deterministic` causes the loss to be NaN; otherwise first-step loss matches, with minor differences after (likely caused by flex?) | |
| flux | Bitwise match | |

---

## Migration Guide

| Old (main) | New (this PR) |
|---|---|
| `CONFIG_FILE="path/to/config.toml" ./run_train.sh` | `MODULE=llama3 CONFIG=llama3_8b ./run_train.sh` |
| `--job.config_file path.toml` | `--module llama3 --config llama3_8b` |
| `train_configs/*.toml` | `config_registry.py` functions |
| `TrainSpec` | `ModelSpec` |
| `ModelArgs` / `args.py` | Nested `Model.Config` dataclass |
| `custom_config_module` + `_merge_configs()` | Subclass `Trainer.Config` |
| `build_model_converters()` free function | `ModelConvertersContainer.Config.build()` |
| `build_metrics_processor()` free function | `MetricsProcessor.Config.build()` |
1 parent 238fa99 commit 9810191

File tree: 280 files changed, +10502 −12043 lines changed


.ci/docker/requirements-dev.txt

Lines changed: 0 additions & 1 deletion
```diff
@@ -3,5 +3,4 @@ pytest==7.3.2
 pytest-cov
 pre-commit
 pyrefly==0.45.1
-tomli-w >= 1.1.0
 transformers
```

.ci/docker/requirements.txt

Lines changed: 0 additions & 2 deletions
```diff
@@ -1,8 +1,6 @@
 torchdata >= 0.8.0
 datasets >= 3.6.0
-tomli >= 1.1.0 ; python_version < "3.11"
 tensorboard
-tabulate
 wandb
 fsspec
 tyro
```

.github/workflows/integration_test_8gpu_torchft.yaml

Lines changed: 5 additions & 3 deletions
```diff
@@ -6,11 +6,13 @@ on:
     tags:
       - ciflow/8gpu/*
     paths:
-      - 'torchtitan/components/ft.py'
+      - 'torchtitan/experiments/ft/**'
+      - 'torchtitan/components/checkpoint.py'
       - '.github/workflows/integration_test_8gpu_torchft.yaml'
   pull_request:
     paths:
-      - 'torchtitan/components/ft.py'
+      - 'torchtitan/experiments/ft/**'
+      - 'torchtitan/components/checkpoint.py'
       - '.github/workflows/integration_test_8gpu_torchft.yaml'
   schedule:
     # Runs every 6 hours
@@ -71,6 +73,6 @@ jobs:
         RUST_BACKTRACE=1 torchft_lighthouse --min_replicas 1 --quorum_tick_ms 100 --join_timeout_ms 10000 > /dev/null 2>&1 &
         echo "ft_integration_test"
         # Getting error - Cuda failure 217 'peer access is not supported between these two devices'
-        python -m tests.integration_tests.ft $RUNNER_TEMP/artifacts-to-be-uploaded --ngpu 8
+        python -m torchtitan.experiments.ft.tests.integration_tests $RUNNER_TEMP/artifacts-to-be-uploaded --ngpu 8
         # pkill -9 torchft_lighthouse
         rm -rf $RUNNER_TEMP/artifacts-to-be-uploaded/*/checkpoint
```

CONTRIBUTING.md

Lines changed: 2 additions & 2 deletions
```diff
@@ -51,7 +51,7 @@ Note: To accelerate contributions to and innovations around `torchtitan`, we are
 - After the model change, it should still load the original checkpoint correctly.
 - Document the reasons for the code change, similar to [composability.md](docs/composability.md).
 - Keep code modularized, especially for [train.py](train.py), so that it remains easy to copy-paste into a minimal code example. If necessary:
-  - Introduce new config options/category in [job_config.py](torchtitan/config/job_config.py).
+  - Introduce new config options/category in [configs.py](torchtitan/config/configs.py).
   - Create separate functions/files.

 ### Proof of Value
@@ -75,7 +75,7 @@ When appropriate, one should consider

 - Adding CPU/GPU unit/integration tests.
   - To add a unit test, put it in the [tests](tests/) folder and follow the existing test files.
-  - To add a GPU integration test, create a new `OverrideDefinitions` in [integration_tests.py](tests/integration_tests.py). It will override the default config to run on the Llama 3 [debug model](torchtitan/models/llama/train_configs/debug_model.toml).
+  - To add a GPU integration test, create a new `OverrideDefinitions` in [integration_tests](tests/integration_tests/). It will override the default config to run on the Llama 3 debug model (see [config_registry.py](torchtitan/models/llama3/config_registry.py)).
 - Updating [README](README.md) and writing a new note in the [docs](docs/) folder on installation and usage, similar to [float8.md](docs/float8.md).
 - Updating [performance.md](docs/performance.md) with new performance results.
 - Creating GitHub issues for things that cannot be addressed at the moment.
```

README.md

Lines changed: 3 additions & 3 deletions
````diff
@@ -68,11 +68,11 @@ We look forward to your contributions!
 7. DDP and HSDP
 8. [TorchFT](https://github.com/pytorch/torchft) integration
 9. Checkpointable data-loading, with the C4 dataset pre-configured (144M entries) and support for [custom datasets](docs/datasets.md)
-10. Gradient accumulation, enabled by giving an additional `--training.global_batch_size` argument in configuration
+10. Gradient accumulation, enabled by giving an additional `--training.global_batch_size` argument on the CLI
 11. Flexible learning rate scheduler (warmup-stable-decay)
 12. Loss, GPU memory, throughput (tokens/sec), TFLOPs, and MFU displayed and logged via [Tensorboard or Weights & Biases](/docs/metrics.md)
 13. [Debugging tools](docs/debugging.md) including CPU/GPU profiling, memory profiling, Flight Recorder, etc.
-14. All options easily configured via [toml files](torchtitan/models/llama3/train_configs/)
+14. All options easily configured via [Python config registry](torchtitan/models/llama3/config_registry.py) with `--module` and `--config` CLI flags
 15. [Helper scripts](scripts/) to
     - download tokenizers from Hugging Face
     - convert original Llama 3 checkpoints into the expected DCP format
@@ -142,7 +142,7 @@ python scripts/download_hf_assets.py --repo_id meta-llama/Llama-3.1-8B --assets
 Llama 3 8B model locally on 8 GPUs

 ```bash
-CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh
+MODULE=llama3 CONFIG=llama3_8b ./run_train.sh
 ```

 ### Multi-Node Training
````

benchmarks/README.md

Lines changed: 2 additions & 2 deletions
```diff
@@ -8,8 +8,8 @@ A submission should be a file / files including the following information
 2. The model or theme of benchmarking, e.g. Llama 3.1, Async TP.
 3. The hardware setup, including the types of GPUs, interconnections, etc.
 4. The actual performance report with training configs, e.g. via
-   - `.toml` files / commandline arguments
-   - complete configs, which can be found in the log with [`--print_config`](https://github.com/pytorch/torchtitan/blob/e7c0cae934df78d6e9c2835f42ff1f757dc3fddc/torchtitan/config_manager.py#L47) turned on (preferred as the default value not shown in `.toml` or specified in commandline could change from time to time)
+   - Python config files / commandline arguments
+   - complete configs, which can be found in the log with [`--print_config`](https://github.com/pytorch/torchtitan/blob/e7c0cae934df78d6e9c2835f42ff1f757dc3fddc/torchtitan/config_manager.py#L47) turned on (preferred as the default value not shown in config files or specified in commandline could change from time to time)
 5. The versions and date/time of `torchtitan`, `torch`, `torchao`, or any relevant dependencies.
 6. Other notes which could help reproduce the results.
```

benchmarks/llama3-8b_h200_202506_trainy-whitefiber.md

Lines changed: 1 addition & 2 deletions
````diff
@@ -16,8 +16,7 @@ Each host has

 Runs were invoked with the following, where `NUM_NODES` was `4` and `8`.

-**Warning**: the command here has been updated to use the latest version of torchtitan, which has had API changes since this benchmark was ran.
-To reproduce the results using the original torchtitan commit, change all instances of `quantize.linear.float8` to `float8` in the command below.
+**Warning**: the command below reflects the original invocation at the time of this benchmark. The torchtitan CLI has since changed to use `--module` and `--config` flags instead of `--job.config-file`. See the current [README](/README.md) for up-to-date usage.
 ```
 torchrun \
   --nnodes $NUM_NODES \
````

docs/checkpoint.md

Lines changed: 39 additions & 39 deletions
````diff
@@ -5,52 +5,53 @@ You may want to enable checkpointing in `torchtitan` for better fault tolerance
 ## A general guide to use checkpoints during training

 1. ENABLE CHECKPOINTING
-In your `torchtitan` training config, ensure that under `[checkpoint]`, `enable` is set to True.
-```
-[checkpoint]
-enable = true
-folder = "checkpoint"
-interval = 500
+In your config_registry function, configure the checkpoint settings:
+```python
+checkpoint=CheckpointManager.Config(
+    interval=500,
+),
 ```
+Or via CLI: `--checkpoint.interval 500`

 2. SAVE MODEL ONLY
 By setting `last_save_model_only` to `True`, the checkpoint will only contain the model and exclude the optimizer state and extra train states, resulting in a smaller checkpoint size.
-```
-[checkpoint]
-enable = true
-last_save_model_only = true
+```python
+checkpoint=CheckpointManager.Config(
+    interval=500,
+    last_save_model_only=True,
+),
 ```

 3. CHOOSE DESIRED EXPORT PRECISION
 The default model states are in `float32`. You can choose to export the checkpoint in a lower precision format such as `bfloat16`.
-```
-[checkpoint]
-enable = true
-last_save_model_only = true
-export_dtype = "bfloat16"
+```python
+checkpoint=CheckpointManager.Config(
+    interval=500,
+    last_save_model_only=True,
+    export_dtype="bfloat16",
+),
 ```

 4. EXCLUDING SPECIFIC KEYS FROM CHECKPOINT LOADING
 In some cases, you may want to partially load from a previous-trained checkpoint and modify certain settings, such as the number of GPUs or the current step. To achieve this, you can use the `exclude_from_loading` parameter to specify which keys should be excluded from loading.
-This parameter takes a list of string that should be excluded from loading.
+```python
+checkpoint=CheckpointManager.Config(
+    exclude_from_loading=["data_loader", "lr_scheduler"],
+),
 ```
-[checkpoint]
-enable = true
-exclude_from_loading = ["data_loader", "lr_scheduler"]
-```
-When used in command line, the parameter should be a comma-separated list of strings. For example: `--checkpoint.exclude_from_loading data_loader,lr_scheduler`.
+When used in command line: `--checkpoint.exclude_from_loading data_loader,lr_scheduler`.

 5. EXAMPLE CHECKPOINT CONFIGURATION
-```
-[checkpoint]
-enable = true
-folder = "checkpoint"
-interval = 10
-load_step = 5
-last_save_model_only = true
-export_dtype = "bfloat16"
+```python
+checkpoint=CheckpointManager.Config(
+    interval=10,
+    load_step=5,
+    last_save_model_only=True,
+    export_dtype="bfloat16",
+),
 ```

-A more exhaustive and up-to-date list of checkpoint config options can be found in `torchtitan/config/job_config.py`
+A more exhaustive and up-to-date list of checkpoint config options can be found in `torchtitan/components/checkpoint.py` (`CheckpointManager.Config`).

 ## Creating a seed checkpoint
 Sometimes one needs to create a seed checkpoint to initialize a model from step 0.
@@ -60,15 +61,15 @@ A seed checkpoint does initialization of the model on a single CPU, and can be l
 To create a seed checkpoint, use the same model config as you use for training.
 e.g.
 ```bash
-NGPU=1 CONFIG_FILE=<path_to_model_config> ./run_train.sh --checkpoint.enable --checkpoint.create_seed_checkpoint --parallelism.data_parallel_replicate_degree 1 --parallelism.data_parallel_shard_degree 1 --parallelism.tensor_parallel_degree 1 --parallelism.pipeline_parallel_degree 1 --parallelism.context_parallel_degree 1 --parallelism.expert_parallel_degree 1
+NGPU=1 ./run_train.sh --module <module_name> --config <config_name> --checkpoint.create_seed_checkpoint --parallelism.data_parallel_replicate_degree 1 --parallelism.data_parallel_shard_degree 1 --parallelism.tensor_parallel_degree 1 --parallelism.pipeline_parallel_degree 1 --parallelism.context_parallel_degree 1 --parallelism.expert_parallel_degree 1
 ```

 ## Conversion support

 ### HuggingFace
 `torchtitan` offers two ways to work with Hugging Face models: either by directly saving and loading a Hugging Face checkpoint during training, or by using an example conversion script to directly reformat the model weights on cpu.

-1. You can directly save huggingface model weights during training by using the `--checkpoint.last_save_in_hf` and `--checkpoint.last_save_model_only` options together. To directly load a `torchtitan` training session from a huggingface safetensors file, enable `--checkpoint.initial_load_in_hf`, and set either `--model.hf_assets_path` or `--checkpoint.initial_load_path` to the directory containing the huggingface checkpoint. `--checkpoint.initial_load_path` overrides `--model.hf_assets_path` if both are set.
+1. You can directly save huggingface model weights during training by using the `--checkpoint.last_save_in_hf` and `--checkpoint.last_save_model_only` options together. To directly load a `torchtitan` training session from a huggingface safetensors file, enable `--checkpoint.initial_load_in_hf`, and set either `--hf_assets_path` or `--checkpoint.initial_load_path` to the directory containing the huggingface checkpoint. `--checkpoint.initial_load_path` overrides `--hf_assets_path` if both are set.

 2. To directly reformat the weights without the need to run a training loop, run the corresponding conversion script. The naming scheme is `torchtitan`-centric, e.g. convert_from_hf means convert hf->tt.

@@ -84,13 +85,12 @@ python ./scripts/convert_from_hf.py ~/.cache/huggingface/hub/models--meta-llama-
 This guide will walk you through the steps required to convert a checkpoint from `torchtitan` so that it can be loaded into pt format.

 1. CHECKPOINT CONFIGURATION
-```
-[checkpoint]
-enable = true
-folder = "checkpoint"
-interval = 10
-last_save_model_only = true
-export_dtype = "bfloat16"
+```python
+checkpoint=CheckpointManager.Config(
+    interval=10,
+    last_save_model_only=True,
+    export_dtype="bfloat16",
+),
 ```

 2. SAVE THE FINAL CHECKPOINT\
````

docs/converging.md

Lines changed: 3 additions & 3 deletions
```diff
@@ -12,10 +12,10 @@ This note clarifies the recommended practices to follow when testing the loss co

 ## Guidelines

-To validate the correctness of a distributed training technique, one should try to **keep the determinism in the input data to minimize the differences it could cause**. To make sure the global batch size and in general #tokens per iteration stay the same, one can fix the local batch size (`training.local_batch_size`) in the toml config, and at the same time fix the data parallel degree.
+To validate the correctness of a distributed training technique, one should try to **keep the determinism in the input data to minimize the differences it could cause**. To make sure the global batch size and in general #tokens per iteration stay the same, one can fix the local batch size (`training.local_batch_size`) in the config_registry function, and at the same time fix the data parallel degree.

 If the technique is a parallelism (TP/PP/CP/etc)
-- The control set is a 1D FSDP job on `dp` GPUs (or any other verified setups), with a trusted training config (e.g. those under train_configs).
+- The control set is a 1D FSDP job on `dp` GPUs (or any other verified setups), with a trusted training config (e.g. those in config_registry.py).
 - The minimal test set is a 2D job on `dp*p` GPUs, where `p >= 2` is the degree of the experimented parallelism.
 - For some parallelisms, larger `p` may cause larger discrepancies in numeric due to various reasons. For example, current implementation of CP uses `torch.bfloat16` (under default mixed precision training configs) when accumulating intermediate results. A higher `p` is desired to ensure the parallelism works properly, at the cost of more hardware resources.
 - Certain parallelisms may impose additional requirements on the batch size. For instance, PP requires local batch size to be at least the number of microbatches (or equivalently, the number of pipeline stages) to reduce bubbles. A valid comparison example would be 1D FSDP on N GPUs with local batch size 8, and 2D FSDP + PP on 4N GPUs (DP N, PP 4) with Interleaved 1F1B schedule (also with local batch size 8), where each PP rank gets two pipeline stages.
@@ -39,7 +39,7 @@ This is a series of loss-converging tests on Llama 3.1, covering both parallelis
 Results are obtained on 2025/01/21, with the latest `torch`, `torchao`, and `torchtitan`.

 ### Setup
-- Base config: [torchtitan/models/llama3/train_configs/llama3_8b.toml](../torchtitan/models/llama3/train_configs/llama3_8b.toml)
+- Base config: `llama3_8b` (from [config_registry.py](../torchtitan/models/llama3/config_registry.py))
 - `training.local_batch_size = 4`, which is a minimum for Pipeline Parallel with `pipeline_parallel_degree = 2` and `pipeline_parallel_schedule = "Interleaved1F1B"`
 - `training.data_parallel_shard_degree = 8`, resulting in global batch size 32
 - `training.steps = 3000`, `lr_scheduler.warmup_steps = 600`
```

docs/datasets.md

Lines changed: 5 additions & 3 deletions
````diff
@@ -54,10 +54,12 @@ DATASETS = {
 ```

 ### 4. Configure Your Training
-In your training configuration file (`.toml`), set your dataset:
+In your config_registry function, set your dataset:

-```toml
-dataset = "wikipedia"
+```python
+dataloader=HuggingFaceTextDataLoader.Config(
+    dataset="wikipedia",
+),
 ```

 That's it! Your custom dataset is now ready to use with `torchtitan`.
````
