
[2/N] Support lora checkpoint on partial save and multi-source load#2485

Draft
mori360 wants to merge 19 commits into gh/mori360/2/base from gh/mori360/2/head

Conversation


@mori360 mori360 commented Mar 4, 2026

Summary

Consolidate converter-checkpoint integration from 5 scattered dynamic attributes into a single ConverterCheckpointHooks dataclass, simplifying ModelWrapper and eliminating closure factories in LoRA.

Design: ConverterCheckpointHooks

  from dataclasses import dataclass
  from typing import Callable

  @dataclass
  class ConverterCheckpointHooks:
      key_filter: Callable[[str], bool] | None = None
      save_last_fn: Callable | None = None
      load_additional_fn: Callable | None = None
      finalize_fn: Callable | None = None
  • key_filter — Distinguishes converter-added keys (e.g. LoRA's .lora_a., .lora_b.) from base model keys. Used by
    ModelWrapper.state_dict() to save only adapter keys in periodic checkpoints (small, fast saves), and by base_state_dict() to exclude them when loading base weights from HF. Without this, every periodic checkpoint would save the full model.
  • save_last_fn — Converter-specific save format for the final checkpoint. LoRA uses this to write PEFT-compatible adapter_model.safetensors + adapter_config.json so the result can be loaded directly by HuggingFace PEFT/vLLM. Only called when last_save_in_hf=True. Without this, the final save would use DCP format which external tools can't read.
  • load_additional_fn — Loads converter-specific checkpoint formats during resume. LoRA uses this to load PEFT safetensors and remap keys back to torchtitan naming. Called for secondary checkpoint paths (the adapter checkpoint, separate from the base model HF checkpoint). Without this, resuming from a PEFT save would fail.
  • finalize_fn — End-of-training model transformation before the last save. Runs ModelConvertersContainer.finalize() in reverse order: LoRA merge (fold adapters into base weights) then QAT CONVERT (replace fake-quantized modules with real quantized ones). This is set by ModelConvertersContainer.convert(), not by individual converters. Without this, the final checkpoint would contain training-time artifacts (LoRA adapters, fake quantization) instead of a clean deployable model.
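As a minimal sketch of how `key_filter` partitions a state dict (the module path in the keys is illustrative; only the `.lora_a.` / `.lora_b.` markers come from this PR):

```python
# LoRA tags adapter parameters with ".lora_a." / ".lora_b." in the state-dict
# key; key_filter lets the checkpoint layer partition on that.
def lora_key_filter(key: str) -> bool:
    return ".lora_a." in key or ".lora_b." in key

state_dict = {
    "layers.0.attention.wq.weight": 0,          # base model weight
    "layers.0.attention.wq.lora_a.weight": 1,   # adapter A matrix
    "layers.0.attention.wq.lora_b.weight": 2,   # adapter B matrix
}

# Periodic checkpoint: save only the small adapter subset.
adapter_only = {k: v for k, v in state_dict.items() if lora_key_filter(k)}
# Loading base weights from HF: exclude the adapter keys.
base_only = {k: v for k, v in state_dict.items() if not lora_key_filter(k)}
```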

Changes

  • model_converter.py: Added ConverterCheckpointHooks dataclass. ModelConvertersContainer.convert() merges into existing hooks.
  • lora.py: Replaced closure factories (_make_peft_save_fn, _make_peft_load_fn, _make_merge_fn) with regular methods (_save_peft, _load_peft). convert() builds hooks directly. finalize() cleans up hooks after merge.
  • checkpoint.py: ModelWrapper simplified from 7 methods to 5. Added get_hooks() as single access point. Removed scattered getattr(part, "converter*") reads.
  • qat.py: Extracted apply_qat_prepare() shared by QATConverter.convert() and LoRAConverter._apply_adapter_qat().

Test plan

  • pytest tests/unit_tests/test_checkpoint.py -x — new TestModelWrapperConverterKeys tests (strict vs partial planner)
  • pytest tests/unit_tests/test_model_converter.py -x — new test_lora_key_remap_roundtrip
  • torchrun --nproc_per_node=4 -m torchtitan.train --module llama3 --config llama3_debugmodel_lora — verify LoRA training runs end-to-end
  • torchrun --nproc_per_node=4 -m torchtitan.train --module llama3 --config llama3_8b_lora — verify 8B LoRA config with PEFT save

Stack from ghstack (oldest at bottom):

mori360 added a commit that referenced this pull request Mar 4, 2026
ghstack-source-id: b46526e
Pull Request resolved: #2485
@meta-cla meta-cla bot added the CLA Signed This label is managed by the Meta Open Source bot. label Mar 4, 2026
mori360 added a commit that referenced this pull request Mar 4, 2026
ghstack-source-id: 8eaef4c
Pull Request resolved: #2485
@mori360 mori360 changed the title lora checkpoint [2/N] Support lora checkpoint on partial save and multi-source load Mar 4, 2026
@mori360 mori360 marked this pull request as draft March 6, 2026 21:34
@tianyu-l tianyu-l left a comment

The new truth table is way too complicated. I have doubts in both

  • whether we need to support all various combinations
  • if so, whether this is the best flow to support them

Let's start with the following questions:
When users train LoRA, do they ever need to store anything other than the adapters?

For fault tolerance or changing hyperparameters, users might want to save checkpoint and resume training from previous checkpoint.

  • At small scale, loading main weights in HF format and LoRA weights in HF / DCP is good enough.
  • At large scale, DCP may have an advantage over HF for the main weights, but at that scale why would users use LoRA in the first place?

For inference, I think vllm supports loading LoRA adapters, so again we only need to save lora weights. Please check what format vllm accepts.

So, is it true that we only need to differentiate two cases:

  • using LoRA to train, but it's the first time we load an HF checkpoint
  • continued LoRA training, where we need to load HF checkpoint and LoRA weights


fegin commented Mar 9, 2026

With small scale, loading main weights in HF format, and loading lora weight in HF / DCP is good enough.

Do we have the number for this? iirc, the conversion is not cheap though it may be acceptable on a single machine.

And this doesn't simplify the logic, as we will still need to save trainer/optimizer states for fault tolerance purposes. The original design coupled that with the model weights, which are now frozen.

If we are okay with always loading with HF format, even for the fault tolerance ones, we should treat the HF folder as the additional folder. lora adapter + trainer + optimizer are in the original checkpoint folder. cc., @mori360

mori360 added 4 commits March 12, 2026 12:45

### Summary

  - Add converter-aware checkpoint save/load to ModelWrapper so LoRA adapter weights can be saved separately from base model weights
  - ModelWrapper gains has_converter_keys(), state_dict_to_save(), and base_state_dict() to partition state dicts based on the converter's
  converter_key_filter
  - dcp_load supports multi-source loading (e.g., base model from HF + LoRA adapters from DCP) with DefaultLoadPlanner(allow_partial_load=True) when
  converters are present
  - Add additional_load_paths config for loading from multiple checkpoint sources
  - Remove dead cache_state_dict from ModelWrapper
  - Add save_adapter_only and converter_key_filter to LoRAConverter

### Test Plan
 -  test_load_uses_strict/partial_planner_with_converter: creates a CheckpointManager with a plain nn.Linear model (no converters). Saves step 1, then loads step 1. Inspects the planner kwarg passed to dcp.load — asserts it's a DefaultLoadPlanner with allow_partial_load=False/True as expected.

### Description
torchtitan currently only accepts single-source save and load.
However, the LoRA use case is different:
1. only the adapter states are saved
2. the model state is loaded from a 1st source (e.g. HuggingFace) and the adapter from a 2nd source
This requires the checkpoint system to accept multi-source load, and the logic can get complicated.

Current logic:
1. if there's a local folder, load locally
2. if there's a HuggingFace path, load from HuggingFace

New logic:
introduce an additional path (for the adapter)
Option A:
1. load from the primary folder first
2. if it is absent, load from HuggingFace
3. then load from the additional path

<img width="660" height="604" alt="Screenshot 2026-03-06 at 1 29 27 PM" src="https://github.com/user-attachments/assets/5b3abc0d-0378-497f-bce3-bc037b2d64ae" />







@mori360 mori360 mentioned this pull request Mar 13, 2026
mori360 added 8 commits March 16, 2026 20:14
### Test Plan

`torchrun --nproc_per_node=4 --module torchtitan.train --module llama3 --config llama3_8b_lora --checkpoint.enable --checkpoint.initial_load_path /path/to/load/model --checkpoint.additional_load_paths /path/to/load/adapter --checkpoint.additional_load_in_hf`

 ### Summary

  - **`merge_adapter` config**: Replace `save_format: str` with `merge_adapter: bool` on
  `LoRAConverter.Config`. When `True`, adapters are folded into base weights (`base + alpha/rank * B @ A`) at
  end of training. When `False` (default), adapter weights are saved separately — use
  `checkpoint.last_save_in_hf=True` to save in HuggingFace PEFT format.
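A tiny worked example of the fold, with plain Python lists standing in for tensors (shapes and values are illustrative):

```python
def merge_lora(base, lora_a, lora_b, alpha, rank):
    """Fold a LoRA adapter into the base weight: W' = W + (alpha/rank) * B @ A."""
    scale = alpha / rank
    out_dim, in_dim = len(base), len(base[0])
    return [
        [base[i][j] + scale * sum(lora_b[i][r] * lora_a[r][j] for r in range(rank))
         for j in range(in_dim)]
        for i in range(out_dim)
    ]

# Toy shapes: W is 2x3, rank-1 adapter, alpha=32.
base = [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]
lora_a = [[1.0, 2.0, 3.0]]    # A: rank x in  (1 x 3)
lora_b = [[1.0], [2.0]]       # B: out x rank (2 x 1)
merged = merge_lora(base, lora_a, lora_b, alpha=32, rank=1)
# scale = 32; B @ A has rows [1, 2, 3] and [2, 4, 6]
assert merged == [[32.0, 64.0, 96.0], [64.0, 128.0, 192.0]]
```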

  - **`finalize()` lifecycle on ModelConverter protocol**: Add an end-of-training hook called before the last
  checkpoint save. `ModelConvertersContainer` runs finalize in reverse converter order (LoRA merge before
  quantization CONVERT). A `converter_finalize_fn` closure is attached to each model part during `convert()`
  so the checkpoint system can invoke it. All existing converters (Float8Linear, Float8GroupedMM, MXFP8) get
  no-op `finalize()` implementations.

  - **Checkpoint integration**: `ModelWrapper` gains converter-aware methods (`state_dict_to_save`,
  `export_state_dict`, `base_state_dict`, `has_converter_keys`, `converter_save_last_fn`,
  `converter_load_additional_fn`) to support adapter-only checkpointing. `dcp_load` accepts `checkpoint_ids:
  list[str]` for loading from multiple sources (base model + adapter weights). A new `additional_load_paths`
  config field enables multi-source loading. The PEFT save path in `_save_last_step` is gated on
  `last_save_in_hf`.

  - **LoRA PEFT save/load**: `_make_peft_save_fn()` writes `adapter_model.safetensors` + `adapter_config.json`
   with HF PEFT key naming. `_make_peft_load_fn()` loads and remaps keys back. `remap_lora_keys_to_hf()` /
  `remap_lora_keys_from_hf()` handle the bidirectional key translation.

  - **Configs**: Add `llama3_8b_lora` (rank=128, alpha=32, `last_save_in_hf=True`). Update
  `llama3_debugmodel_lora` with checkpoint settings for proper resumption.
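A sketch of the bidirectional key translation; the exact PEFT prefix and `lora_A`/`lora_B` casing used here are assumptions, and the PR's `remap_lora_keys_to_hf()` / `remap_lora_keys_from_hf()` may differ in detail:

```python
def remap_lora_key_to_hf(key: str) -> str:
    # Hypothetical mapping: PEFT prefixes the module path and capitalizes A/B.
    key = key.replace(".lora_a.", ".lora_A.").replace(".lora_b.", ".lora_B.")
    return "base_model.model." + key

def remap_lora_key_from_hf(key: str) -> str:
    # Inverse of the above, so a save/load roundtrip restores torchtitan naming.
    key = key.removeprefix("base_model.model.")
    return key.replace(".lora_A.", ".lora_a.").replace(".lora_B.", ".lora_b.")

k = "layers.0.attention.wq.lora_a.weight"
assert remap_lora_key_from_hf(remap_lora_key_to_hf(k)) == k
```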

  ### Test plan

  - [ ] `pytest tests/unit_tests/test_checkpoint.py -x` — new `TestModelWrapperConverterKeys` tests (strict vs
   partial planner)
  - [ ] `pytest tests/unit_tests/test_model_converter.py -x` — new `test_lora_key_remap_roundtrip`
  - [ ] `torchrun --nproc_per_node=4 -m torchtitan.train --module llama3 --config llama3_debugmodel_lora` —
  verify LoRA training runs end-to-end
  - [ ] `torchrun --nproc_per_node=4 -m torchtitan.train --module llama3 --config llama3_8b_lora` — verify 8B
  LoRA config with PEFT save


* #2484
