Commit d4f1e3c
Update on "[2/N] Support lora checkpoint on partial save and multi-source load"
### Summary
- **`merge_adapter` config**: Replace `save_format: str` with `merge_adapter: bool` on
`LoRAConverter.Config`. When `True`, adapters are folded into base weights (`base + alpha/rank * B @ A`) at
the end of training. When `False` (default), adapter weights are saved separately; use
`checkpoint.last_save_in_hf=True` to save in HuggingFace PEFT format.
- **`finalize()` lifecycle on ModelConverter protocol**: Add an end-of-training hook called before the last
checkpoint save. `ModelConvertersContainer` runs finalize in reverse converter order (LoRA merge before
quantization conversion). A `converter_finalize_fn` closure is attached to each model part during `convert()`
so the checkpoint system can invoke it. All existing converters (Float8Linear, Float8GroupedMM, MXFP8) get
no-op `finalize()` implementations.
- **Checkpoint integration**: `ModelWrapper` gains converter-aware methods (`state_dict_to_save`,
`export_state_dict`, `base_state_dict`, `has_converter_keys`, `converter_save_last_fn`,
`converter_load_additional_fn`) to support adapter-only checkpointing. `dcp_load` accepts `checkpoint_ids:
list[str]` for loading from multiple sources (base model + adapter weights). A new `additional_load_paths`
config field enables multi-source loading. The PEFT save path in `_save_last_step` is gated on
`last_save_in_hf`.
- **LoRA PEFT save/load**: `_make_peft_save_fn()` writes `adapter_model.safetensors` + `adapter_config.json`
with HF PEFT key naming. `_make_peft_load_fn()` loads and remaps keys back. `remap_lora_keys_to_hf()` /
`remap_lora_keys_from_hf()` handle the bidirectional key translation.
- **Configs**: Add `llama3_8b_lora` (rank=128, alpha=32, `last_save_in_hf=True`). Update
`llama3_debugmodel_lora` with checkpoint settings for proper resumption.
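The `merge_adapter=True` fold described above can be sketched as follows (a minimal illustration of the `base + alpha/rank * B @ A` formula with made-up shapes, not torchtitan's actual implementation):

```python
import torch

def merge_lora(base: torch.Tensor, A: torch.Tensor, B: torch.Tensor,
               alpha: float, rank: int) -> torch.Tensor:
    # Fold the low-rank adapter into the base weight:
    #   W' = W + (alpha / rank) * B @ A
    # After this, the adapter tensors can be dropped and the merged
    # weight is saved as a plain base checkpoint.
    return base + (alpha / rank) * (B @ A)

# Example: a 16x16 weight with a rank-4 adapter (shapes are illustrative)
base = torch.zeros(16, 16)
B = torch.randn(16, 4)
A = torch.randn(4, 16)
merged = merge_lora(base, A, B, alpha=32.0, rank=4)
```

With `base` zeroed out, `merged` is exactly `(alpha / rank) * B @ A`, which makes the scaling easy to verify in a unit test.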
### Test plan
- [ ] `pytest tests/unit_tests/test_checkpoint.py -x` — new `TestModelWrapperConverterKeys` tests (strict vs
partial planner)
- [ ] `pytest tests/unit_tests/test_model_converter.py -x` — new `test_lora_key_remap_roundtrip`
- [ ] `torchrun --nproc_per_node=4 -m torchtitan.train --module llama3 --config llama3_debugmodel_lora` —
verify LoRA training runs end-to-end
- [ ] `torchrun --nproc_per_node=4 -m torchtitan.train --module llama3 --config llama3_8b_lora` — verify 8B
LoRA config with PEFT save
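The `test_lora_key_remap_roundtrip` case above exercises the bidirectional key translation. A rough sketch of what such a remap pair looks like (the prefix `base_model.model.` follows HF PEFT conventions, but the internal `lora_a`/`lora_b` naming here is an assumption, not torchtitan's exact scheme):

```python
def remap_lora_keys_to_hf(state_dict: dict) -> dict:
    # HF PEFT expects keys like "base_model.model.<module>.lora_A.weight";
    # assume internal keys look like "<module>.lora_a" (illustrative naming).
    out = {}
    for key, value in state_dict.items():
        module, param = key.rsplit(".", 1)
        param = param.replace("lora_a", "lora_A").replace("lora_b", "lora_B")
        out[f"base_model.model.{module}.{param}.weight"] = value
    return out

def remap_lora_keys_from_hf(state_dict: dict) -> dict:
    # Inverse mapping: strip the PEFT prefix/suffix and lowercase the
    # adapter-matrix tag so keys match the internal state dict again.
    out = {}
    for key, value in state_dict.items():
        stripped = key.removeprefix("base_model.model.").removesuffix(".weight")
        module, param = stripped.rsplit(".", 1)
        param = param.replace("lora_A", "lora_a").replace("lora_B", "lora_b")
        out[f"{module}.{param}"] = value
    return out
```

A roundtrip (`to_hf` then `from_hf`) should return the original dict unchanged, which is exactly the invariant a roundtrip test asserts.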
* #2484
[ghstack-poisoned]