Skip to content

Commit d4f1e3c

Browse files
committed
Update on "[2/N] Support lora checkpoint on partial save and multi-source load"
### Summary

- **`merge_adapter` config**: Replace `save_format: str` with `merge_adapter: bool` on `LoRAConverter.Config`. When `True`, adapters are folded into base weights (`base + alpha/rank * B @ A`) at end of training. When `False` (default), adapter weights are saved separately — use `checkpoint.last_save_in_hf=True` to save in HuggingFace PEFT format.
- **`finalize()` lifecycle on ModelConverter protocol**: Add an end-of-training hook called before the last checkpoint save. `ModelConvertersContainer` runs finalize in reverse converter order (LoRA merge before quantization CONVERT). A `converter_finalize_fn` closure is attached to each model part during `convert()` so the checkpoint system can invoke it. All existing converters (Float8Linear, Float8GroupedMM, MXFP8) get no-op `finalize()` implementations.
- **Checkpoint integration**: `ModelWrapper` gains converter-aware methods (`state_dict_to_save`, `export_state_dict`, `base_state_dict`, `has_converter_keys`, `converter_save_last_fn`, `converter_load_additional_fn`) to support adapter-only checkpointing. `dcp_load` accepts `checkpoint_ids: list[str]` for loading from multiple sources (base model + adapter weights). A new `additional_load_paths` config field enables multi-source loading. The PEFT save path in `_save_last_step` is gated on `last_save_in_hf`.
- **LoRA PEFT save/load**: `_make_peft_save_fn()` writes `adapter_model.safetensors` + `adapter_config.json` with HF PEFT key naming. `_make_peft_load_fn()` loads and remaps keys back. `remap_lora_keys_to_hf()` / `remap_lora_keys_from_hf()` handle the bidirectional key translation.
- **Configs**: Add `llama3_8b_lora` (rank=128, alpha=32, `last_save_in_hf=True`). Update `llama3_debugmodel_lora` with checkpoint settings for proper resumption.
### Test plan

- [ ] `pytest tests/unit_tests/test_checkpoint.py -x` — new `TestModelWrapperConverterKeys` tests (strict vs partial planner)
- [ ] `pytest tests/unit_tests/test_model_converter.py -x` — new `test_lora_key_remap_roundtrip`
- [ ] `torchrun --nproc_per_node=4 -m torchtitan.train --module llama3 --config llama3_debugmodel_lora` — verify LoRA training runs end-to-end
- [ ] `torchrun --nproc_per_node=4 -m torchtitan.train --module llama3 --config llama3_8b_lora` — verify 8B LoRA config with PEFT save

* #2484

[ghstack-poisoned]
2 parents eae6f72 + ba921e5 commit d4f1e3c

File tree

3 files changed

+21
-15
lines changed

3 files changed

+21
-15
lines changed

torchtitan/components/checkpoint.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -470,6 +470,12 @@ def load_state_dict(state_dict):
470470
sd_adapter is not None
471471
), "checkpoint.last_save_in_hf is True, but sd_adapter is not provided."
472472
self.sd_adapter = sd_adapter
473+
474+
# Inject from_hf_map into converter hooks so save/load fns can remap keys
475+
hooks = self.states[MODEL]._get_hooks()
476+
if hooks is not None and sd_adapter is not None:
477+
hooks.from_hf_map = getattr(sd_adapter, "from_hf_map", None)
478+
473479
self.export_dtype = TORCH_DTYPE_MAP[config.export_dtype]
474480
self.exclude_from_loading = config.exclude_from_loading
475481
self.additional_load_paths = config.additional_load_paths
@@ -684,7 +690,7 @@ def dcp_load(
684690
load_fn(
685691
cid,
686692
self.states[MODEL].model,
687-
self._get_from_hf_map(),
693+
hooks,
688694
)
689695
else:
690696
# DCP: load all available states (model + training info).
@@ -1042,7 +1048,7 @@ def _save_last_step(self, curr_step: int) -> None:
10421048
and save_last_fn is not None
10431049
):
10441050
checkpoint_dir = self._create_checkpoint_id(curr_step)
1045-
save_last_fn(states, checkpoint_dir, self._get_from_hf_map())
1051+
save_last_fn(states, checkpoint_dir, hooks)
10461052
return
10471053

10481054
self.dcp_save(
@@ -1053,12 +1059,6 @@ def _save_last_step(self, curr_step: int) -> None:
10531059
to_hf=self.last_save_in_hf,
10541060
)
10551061

1056-
def _get_from_hf_map(self) -> dict[str, str | None] | None:
1057-
"""Return from_hf_map from sd_adapter, or None if unavailable."""
1058-
if self.sd_adapter is None:
1059-
return None
1060-
return getattr(self.sd_adapter, "from_hf_map", None)
1061-
10621062
def _should_save(self, curr_step: int, last_step: bool = False) -> bool:
10631063
if not self.enable or self.load_only:
10641064
return False

torchtitan/components/lora.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ def _save_peft(
119119
self,
120120
state_dict: dict[str, Any],
121121
checkpoint_dir: str,
122-
from_hf_map: dict[str, str | None] | None,
122+
hooks: "ConverterCheckpointHooks",
123123
) -> None:
124124
"""Save adapter weights in PEFT format.
125125
@@ -138,8 +138,8 @@ def _save_peft(
138138
cpu_states[k] = v.cpu() if isinstance(v, torch.Tensor) else v
139139

140140
# Remap keys to HF PEFT naming
141-
if from_hf_map is not None:
142-
hf_states = remap_lora_keys_to_hf(cpu_states, from_hf_map)
141+
if hooks.from_hf_map is not None:
142+
hf_states = remap_lora_keys_to_hf(cpu_states, hooks.from_hf_map)
143143
else:
144144
logger.warning(
145145
"No from_hf_map available; saving PEFT with torchtitan keys."
@@ -181,7 +181,7 @@ def _load_peft(
181181
self,
182182
path: str,
183183
model_parts: list[nn.Module],
184-
from_hf_map: dict[str, str | None] | None,
184+
hooks: "ConverterCheckpointHooks",
185185
) -> None:
186186
"""Load adapter weights from a PEFT directory.
187187
@@ -194,8 +194,8 @@ def _load_peft(
194194

195195
safetensors_path = os.path.join(path, "adapter_model.safetensors")
196196
adapter_sd = load_file(safetensors_path)
197-
if from_hf_map is not None:
198-
adapter_sd = remap_lora_keys_from_hf(adapter_sd, from_hf_map)
197+
if hooks.from_hf_map is not None:
198+
adapter_sd = remap_lora_keys_from_hf(adapter_sd, hooks.from_hf_map)
199199
func = functools.partial(
200200
set_model_state_dict,
201201
model_state_dict=adapter_sd,
@@ -243,8 +243,13 @@ def finalize(self, model: nn.Module) -> None:
243243
for name, mod in list(model.named_modules()):
244244
if not (hasattr(mod, "lora_a") and hasattr(mod, "lora_b")):
245245
continue
246+
assert isinstance(mod, nn.Linear)
247+
lora_a = mod.lora_a
248+
lora_b = mod.lora_b
249+
assert isinstance(lora_a, nn.Linear)
250+
assert isinstance(lora_b, nn.Linear)
246251
with torch.no_grad():
247-
mod.weight.add_(scaling * (mod.lora_b.weight @ mod.lora_a.weight))
252+
mod.weight.add_(scaling * (lora_b.weight @ lora_a.weight))
248253
del mod.lora_a, mod.lora_b
249254
if hasattr(mod, "_lora_scaling"):
250255
del mod._lora_scaling

torchtitan/protocols/model_converter.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ class ConverterCheckpointHooks:
2323
save_last_fn: Callable | None = None
2424
load_additional_fn: Callable | None = None
2525
finalize_fn: Callable | None = None
26+
from_hf_map: dict[str, str | None] | None = None
2627

2728

2829
class ModelConverter(Protocol):

0 commit comments

Comments
 (0)