[2/N] Support lora checkpoint on partial save and multi-source load #2485
mori360 wants to merge 19 commits into gh/mori360/2/base
Conversation
tianyu-l
left a comment
The new truth table is way too complicated. I have doubts about both
- whether we need to support all these combinations
- if so, whether this is the best flow to support them

Let's start with the following questions:

When users train LoRA, do they ever need to store anything other than the adapters?
For fault tolerance or changing hyperparameters, users might want to save a checkpoint and resume training from it.
- At small scale, loading main weights in HF format and LoRA weights in HF / DCP is good enough.
- At large scale, DCP may have an advantage over HF for the main weights, but at large scale why would users use LoRA in the first place?

For inference, I think vLLM supports loading LoRA adapters, so again we only need to save the LoRA weights. Please check what format vLLM accepts.

So, is it true that we only need to differentiate two cases:
- using LoRA to train, where it's the first time we load an HF checkpoint
- continued LoRA training, where we need to load an HF checkpoint and LoRA weights

Do we have the numbers for this? IIRC, the conversion is not cheap, though it may be acceptable on a single machine. And this doesn't simplify the logic, as we will still need to save trainer/optimizer states for fault tolerance; the original design coupled that with the model weights, which are frozen now. If we are okay with always loading in HF format, even for fault tolerance, we should treat the HF folder as the
### Summary
- Add converter-aware checkpoint save/load to `ModelWrapper` so LoRA adapter weights can be saved separately from base model weights
- `ModelWrapper` gains `has_converter_keys()`, `state_dict_to_save()`, and `base_state_dict()` to partition state dicts based on the converter's `converter_key_filter`
- `dcp_load` supports multi-source loading (e.g., base model from HF + LoRA adapters from DCP) with `DefaultLoadPlanner(allow_partial_load=True)` when converters are present
- Add `additional_load_paths` config for loading from multiple checkpoint sources
- Remove dead `cache_state_dict` from `ModelWrapper`
- Add `save_adapter_only` and `converter_key_filter` to `LoRAConverter`

### Test Plan
- `test_load_uses_strict/partial_planner_with_converter`: creates a `CheckpointManager` with a plain `nn.Linear` model (no converters), saves step 1, then loads step 1. Inspects the `planner` kwarg passed to `dcp.load` — asserts it's a `DefaultLoadPlanner` with `allow_partial_load=False/True` as expected.

`torchrun --nproc_per_node=4 --module torchtitan.train --module llama3 --config llama3_8b_lora --checkpoint.enable --checkpoint.initial_load_path /path/to/load/model --checkpoint.additional_load_paths /path/to/load/adapter --checkpoint.additional_load_in_hf`

### Description
torchtitan currently accepts only a single source for save and load. For the LoRA use case, the requirements differ:
1. save only the adapter states
2. load model state from a first source (e.g., HuggingFace) and the adapter from a second source

This requires the checkpoint system to accept multi-source load, and the logic can get complicated.

Current logic:
1. if there's a local folder, load locally
2. otherwise, if there's a HuggingFace path, load from HuggingFace

New logic: introduce an additional path (for the adapter). Option A:
1. load from the primary folder first
2. if not present, load from HuggingFace
3. then load from the additional path

<img width="660" height="604" alt="Screenshot 2026-03-06 at 1 29 27 PM" src="https://github.com/user-attachments/assets/5b3abc0d-0378-497f-bce3-bc037b2d64ae" />
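The partitioning behind `state_dict_to_save()` / `base_state_dict()` amounts to a predicate split over state-dict keys. A minimal sketch follows; the `.lora_a.` / `.lora_b.` key names stand in for the converter's `converter_key_filter` and are assumptions for illustration, not torchtitan's exact keys:

```python
# Hypothetical predicate standing in for converter_key_filter.
def is_adapter_key(key: str) -> bool:
    return ".lora_a." in key or ".lora_b." in key

def split_state_dict(state_dict: dict) -> tuple[dict, dict]:
    """Partition into (adapter-only, base-only) dicts by key predicate."""
    adapter = {k: v for k, v in state_dict.items() if is_adapter_key(k)}
    base = {k: v for k, v in state_dict.items() if not is_adapter_key(k)}
    return adapter, base

sd = {
    "layers.0.attention.wq.weight": 0,          # frozen base weight
    "layers.0.attention.wq.lora_a.weight": 1,   # converter-added
    "layers.0.attention.wq.lora_b.weight": 2,   # converter-added
}
adapter, base = split_state_dict(sd)
```

Periodic saves then write only the `adapter` dict, while a base-model HF load targets only the `base` keys (with `allow_partial_load=True` so the missing adapter keys are tolerated).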
### Summary
- **`merge_adapter` config**: Replace `save_format: str` with `merge_adapter: bool` on `LoRAConverter.Config`. When `True`, adapters are folded into base weights (`base + alpha/rank * B @ A`) at end of training. When `False` (default), adapter weights are saved separately — use `checkpoint.last_save_in_hf=True` to save in HuggingFace PEFT format.
- **`finalize()` lifecycle on ModelConverter protocol**: Add an end-of-training hook called before the last checkpoint save. `ModelConvertersContainer` runs finalize in reverse converter order (LoRA merge before quantization CONVERT). A `converter_finalize_fn` closure is attached to each model part during `convert()` so the checkpoint system can invoke it. All existing converters (Float8Linear, Float8GroupedMM, MXFP8) get no-op `finalize()` implementations.
- **Checkpoint integration**: `ModelWrapper` gains converter-aware methods (`state_dict_to_save`, `export_state_dict`, `base_state_dict`, `has_converter_keys`, `converter_save_last_fn`, `converter_load_additional_fn`) to support adapter-only checkpointing. `dcp_load` accepts `checkpoint_ids: list[str]` for loading from multiple sources (base model + adapter weights). A new `additional_load_paths` config field enables multi-source loading. The PEFT save path in `_save_last_step` is gated on `last_save_in_hf`.
- **LoRA PEFT save/load**: `_make_peft_save_fn()` writes `adapter_model.safetensors` + `adapter_config.json` with HF PEFT key naming. `_make_peft_load_fn()` loads and remaps keys back. `remap_lora_keys_to_hf()` / `remap_lora_keys_from_hf()` handle the bidirectional key translation.
- **Configs**: Add `llama3_8b_lora` (rank=128, alpha=32, `last_save_in_hf=True`). Update `llama3_debugmodel_lora` with checkpoint settings for proper resumption.

### Test plan
- [ ] `pytest tests/unit_tests/test_checkpoint.py -x` — new `TestModelWrapperConverterKeys` tests (strict vs partial planner)
- [ ] `pytest tests/unit_tests/test_model_converter.py -x` — new `test_lora_key_remap_roundtrip`
- [ ] `torchrun --nproc_per_node=4 -m torchtitan.train --module llama3 --config llama3_debugmodel_lora` — verify LoRA training runs end-to-end
- [ ] `torchrun --nproc_per_node=4 -m torchtitan.train --module llama3 --config llama3_8b_lora` — verify 8B LoRA config with PEFT save

* #2484
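The `merge_adapter` fold (`base + alpha/rank * B @ A`) can be checked numerically with a dependency-free sketch. Shapes and values here are illustrative, not tied to any real model:

```python
def matmul(X, Y):
    """Plain-Python matrix multiply: (m, k) @ (k, n) -> (m, n)."""
    return [[sum(X[i][k] * Y[k][j] for k in range(len(Y)))
             for j in range(len(Y[0]))] for i in range(len(X))]

def merge_lora(base, A, B, alpha, rank):
    """Fold a LoRA adapter into the frozen weight: base + (alpha/rank) * B @ A."""
    scale = alpha / rank
    delta = matmul(B, A)  # (out, r) @ (r, in) -> (out, in)
    return [[base[i][j] + scale * delta[i][j]
             for j in range(len(base[0]))] for i in range(len(base))]

base = [[1.0, 0.0], [0.0, 1.0]]  # (out=2, in=2) frozen weight
A = [[1.0, 2.0]]                 # (r=1, in=2)
B = [[0.5], [0.25]]              # (out=2, r=1)
merged = merge_lora(base, A, B, alpha=2.0, rank=1)
# merged == [[2.0, 2.0], [0.5, 2.0]]
```

After the fold, the checkpoint contains only merged dense weights, so it loads like an ordinary fine-tuned model with no adapter machinery.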
Summary
Consolidate converter-checkpoint integration from 5 scattered dynamic attributes into a single `ConverterCheckpointHooks` dataclass, simplifying `ModelWrapper` and eliminating closure factories in LoRA.

Design: ConverterCheckpointHooks
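A minimal sketch of what such a dataclass could look like. The field names come from this PR; the callable signatures are assumptions for illustration, not the exact torchtitan definitions:

```python
from dataclasses import dataclass
from typing import Any, Callable, Optional

@dataclass
class ConverterCheckpointHooks:
    """Hypothetical shape of the consolidated converter-checkpoint hooks."""
    # Returns True for converter-added keys (e.g. ".lora_a.", ".lora_b.")
    key_filter: Optional[Callable[[str], bool]] = None
    # Writes the final checkpoint in a converter-specific format (e.g. PEFT)
    save_last_fn: Optional[Callable[[dict, str], None]] = None
    # Loads a converter-specific checkpoint from a secondary path on resume
    load_additional_fn: Optional[Callable[[str], dict]] = None
    # End-of-training transform (e.g. adapter merge) before the last save
    finalize_fn: Optional[Callable[[Any], None]] = None

# A converter that only needs key partitioning sets just one field:
hooks = ConverterCheckpointHooks(
    key_filter=lambda k: ".lora_a." in k or ".lora_b." in k,
)
```

Grouping the hooks in one object means `ModelWrapper` can hold a single optional attribute and check fields for `None`, instead of probing five dynamic attributes on each model part.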
- `key_filter` — Distinguishes converter-added keys (e.g. LoRA's `.lora_a.`, `.lora_b.`) from base model keys. Used by `ModelWrapper.state_dict()` to save only adapter keys in periodic checkpoints (small, fast saves), and by `base_state_dict()` to exclude them when loading base weights from HF. Without this, every periodic checkpoint would save the full model.
- `save_last_fn` — Converter-specific save format for the final checkpoint. LoRA uses this to write PEFT-compatible `adapter_model.safetensors` + `adapter_config.json` so the result can be loaded directly by HuggingFace PEFT/vLLM. Only called when `last_save_in_hf=True`. Without this, the final save would use DCP format, which external tools can't read.
- `load_additional_fn` — Loads converter-specific checkpoint formats during resume. LoRA uses this to load PEFT safetensors and remap keys back to torchtitan naming. Called for secondary checkpoint paths (the adapter checkpoint, separate from the base model HF checkpoint). Without this, resuming from a PEFT save would fail.
- `finalize_fn` — End-of-training model transformation before the last save. Runs `ModelConvertersContainer.finalize()` in reverse order: LoRA merge (fold adapters into base weights), then QAT CONVERT (replace fake-quantized modules with real quantized ones). This is set by `ModelConvertersContainer.convert()`, not by individual converters. Without this, the final checkpoint would contain training-time artifacts (LoRA adapters, fake quantization) instead of a clean deployable model.

Changes
Test plan
- `pytest tests/unit_tests/test_checkpoint.py -x` — new `TestModelWrapperConverterKeys` tests (strict vs partial planner)
- `pytest tests/unit_tests/test_model_converter.py -x` — new `test_lora_key_remap_roundtrip`
- `torchrun --nproc_per_node=4 -m torchtitan.train --module llama3 --config llama3_debugmodel_lora` — verify LoRA training runs end-to-end
- `torchrun --nproc_per_node=4 -m torchtitan.train --module llama3 --config llama3_8b_lora` — verify 8B LoRA config with PEFT save
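The roundtrip property that `test_lora_key_remap_roundtrip` exercises can be illustrated with a toy remapper. The actual rules in `remap_lora_keys_to_hf()` / `remap_lora_keys_from_hf()` are not shown in this PR excerpt, so the naming convention below (a PEFT-style `base_model.model.` prefix and `lora_A`/`lora_B` capitalization) is an assumption for illustration only:

```python
# Hypothetical torchtitan -> HF PEFT key translations.
TO_HF = {".lora_a.": ".lora_A.", ".lora_b.": ".lora_B."}

def remap_to_hf(sd: dict) -> dict:
    """Rename adapter keys to a PEFT-style convention."""
    out = {}
    for k, v in sd.items():
        for src, dst in TO_HF.items():
            k = k.replace(src, dst)
        out["base_model.model." + k] = v
    return out

def remap_from_hf(sd: dict) -> dict:
    """Invert remap_to_hf so load-after-save restores original keys."""
    out = {}
    for k, v in sd.items():
        k = k.removeprefix("base_model.model.")
        for src, dst in TO_HF.items():
            k = k.replace(dst, src)
        out[k] = v
    return out

titan_sd = {"layers.0.attention.wq.lora_a.weight": 1}
assert remap_from_hf(remap_to_hf(titan_sd)) == titan_sd  # lossless roundtrip
```

Whatever the real key convention is, the invariant the test pins down is the same: `from_hf(to_hf(sd)) == sd`, so a PEFT save can always be resumed.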
Stack from ghstack (oldest at bottom):