Commit ba921e5

Update base for Update on "[2/N] Support lora checkpoint on partial save and multi-source load"
### Summary

- **`merge_adapter` config**: Replace `save_format: str` with `merge_adapter: bool` on `LoRAConverter.Config`. When `True`, adapters are folded into the base weights (`base + alpha/rank * B @ A`) at the end of training. When `False` (the default), adapter weights are saved separately; use `checkpoint.last_save_in_hf=True` to save them in HuggingFace PEFT format.
- **`finalize()` lifecycle on the ModelConverter protocol**: Add an end-of-training hook called before the last checkpoint save. `ModelConvertersContainer` runs `finalize()` in reverse converter order (LoRA merge before quantization). A `converter_finalize_fn` closure is attached to each model part during `convert()` so the checkpoint system can invoke it. All existing converters (Float8Linear, Float8GroupedMM, MXFP8) get no-op `finalize()` implementations.
- **Checkpoint integration**: `ModelWrapper` gains converter-aware methods (`state_dict_to_save`, `export_state_dict`, `base_state_dict`, `has_converter_keys`, `converter_save_last_fn`, `converter_load_additional_fn`) to support adapter-only checkpointing. `dcp_load` accepts `checkpoint_ids: list[str]` for loading from multiple sources (base model + adapter weights). A new `additional_load_paths` config field enables multi-source loading. The PEFT save path in `_save_last_step` is gated on `last_save_in_hf`.
- **LoRA PEFT save/load**: `_make_peft_save_fn()` writes `adapter_model.safetensors` + `adapter_config.json` with HF PEFT key naming. `_make_peft_load_fn()` loads and remaps keys back. `remap_lora_keys_to_hf()` / `remap_lora_keys_from_hf()` handle the bidirectional key translation.
- **Configs**: Add `llama3_8b_lora` (rank=128, alpha=32, `last_save_in_hf=True`). Update `llama3_debugmodel_lora` with checkpoint settings for proper resumption.
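The merge performed when `merge_adapter=True` can be sketched as follows. This is a minimal numpy sketch of the `base + alpha/rank * B @ A` fold; the real code operates on torch module parameters, and the function name here is illustrative:

```python
import numpy as np

def merge_lora(base, A, B, alpha, rank):
    """Fold a LoRA adapter into the base weight: base + (alpha / rank) * B @ A.

    Shapes: base is (out, in); A is (rank, in); B is (out, rank).
    """
    return base + (alpha / rank) * (B @ A)

rng = np.random.default_rng(0)
base = rng.standard_normal((8, 4))
A = rng.standard_normal((2, 4))      # low-rank down-projection
B = rng.standard_normal((8, 2))      # low-rank up-projection
merged = merge_lora(base, A, B, alpha=32, rank=2)
assert merged.shape == base.shape    # merged weight replaces base in place
```

After the merge, the checkpoint contains only plain base-shaped weights, so no adapter-aware loading is needed downstream.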
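The reverse-order `finalize()` pass can be sketched like this. The class bodies and the dict-based "model" are illustrative stand-ins, not the real torchtitan APIs; only the reversed iteration order is the point:

```python
class ModelConverter:
    """Sketch of the converter protocol with the new finalize() hook."""

    def convert(self, model):
        ...

    def finalize(self, model):
        # No-op default, matching Float8Linear / Float8GroupedMM / MXFP8.
        pass

class LoRAMergeConverter(ModelConverter):
    def finalize(self, model):
        model.setdefault("log", []).append("lora_merge")

class QuantizeConverter(ModelConverter):
    def finalize(self, model):
        model.setdefault("log", []).append("quantize")

class ModelConvertersContainer:
    def __init__(self, converters):
        self.converters = converters

    def finalize(self, model):
        # Run in reverse converter order so the LoRA merge happens
        # before quantization finalization.
        for converter in reversed(self.converters):
            converter.finalize(model)

model = {}
ModelConvertersContainer([QuantizeConverter(), LoRAMergeConverter()]).finalize(model)
assert model["log"] == ["lora_merge", "quantize"]
```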
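Multi-source loading (base model + adapter weights) reduces to reading each checkpoint id in order and layering the results. A toy sketch, assuming later sources override earlier ones; `read_fn` stands in for whatever backend reads a single checkpoint:

```python
def load_from_sources(checkpoint_ids, read_fn):
    """Merge state dicts from several checkpoint sources, in order.

    checkpoint_ids: list of source identifiers (base model first,
    adapter weights second); read_fn(cid) returns that source's state dict.
    """
    state = {}
    for cid in checkpoint_ids:
        state.update(read_fn(cid))  # later sources win on key collisions
    return state

stores = {
    "base": {"wq.weight": "base_wq", "wo.weight": "base_wo"},
    "adapter": {"wq.lora_a": "A", "wq.lora_b": "B"},
}
merged = load_from_sources(["base", "adapter"], stores.__getitem__)
assert set(merged) == {"wq.weight", "wo.weight", "wq.lora_a", "wq.lora_b"}
```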
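The bidirectional key translation is essentially a pair of inverse renames. A hypothetical sketch: the `.lora_a` / `.lora_b` suffixes and the `base_model.model.` prefix are assumptions for illustration, and the real HF PEFT naming conventions differ in detail:

```python
HF_PREFIX = "base_model.model."  # assumed PEFT-style prefix

def remap_lora_keys_to_hf(sd):
    """Rename internal LoRA keys to a PEFT-style scheme."""
    out = {}
    for k, v in sd.items():
        if k.endswith(".lora_a"):
            out[HF_PREFIX + k[: -len(".lora_a")] + ".lora_A.weight"] = v
        elif k.endswith(".lora_b"):
            out[HF_PREFIX + k[: -len(".lora_b")] + ".lora_B.weight"] = v
        else:
            out[k] = v
    return out

def remap_lora_keys_from_hf(sd):
    """Inverse of remap_lora_keys_to_hf."""
    out = {}
    for k, v in sd.items():
        if k.startswith(HF_PREFIX) and k.endswith(".lora_A.weight"):
            out[k[len(HF_PREFIX): -len(".lora_A.weight")] + ".lora_a"] = v
        elif k.startswith(HF_PREFIX) and k.endswith(".lora_B.weight"):
            out[k[len(HF_PREFIX): -len(".lora_B.weight")] + ".lora_b"] = v
        else:
            out[k] = v
    return out

sd = {"layers.0.attn.wq.lora_a": 1, "layers.0.attn.wq.lora_b": 2}
assert remap_lora_keys_from_hf(remap_lora_keys_to_hf(sd)) == sd
```

The roundtrip property in the last line is exactly what `test_lora_key_remap_roundtrip` checks.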
### Test plan

- [ ] `pytest tests/unit_tests/test_checkpoint.py -x` — new `TestModelWrapperConverterKeys` tests (strict vs. partial planner)
- [ ] `pytest tests/unit_tests/test_model_converter.py -x` — new `test_lora_key_remap_roundtrip`
- [ ] `torchrun --nproc_per_node=4 -m torchtitan.train --module llama3 --config llama3_debugmodel_lora` — verify LoRA training runs end-to-end
- [ ] `torchrun --nproc_per_node=4 -m torchtitan.train --module llama3 --config llama3_8b_lora` — verify the 8B LoRA config with PEFT save

* #2484

[ghstack-poisoned]
1 parent 9df38df commit ba921e5

0 files changed (+0 −0)

