Skip to content

Commit 8a322c0

Browse files
committed
lora checkpoint
ghstack-source-id: 4c56897 Pull Request resolved: #2485
1 parent 4d3ff36 commit 8a322c0

File tree

3 files changed

+116
-31
lines changed

3 files changed

+116
-31
lines changed

torchtitan/components/checkpoint.py

Lines changed: 93 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from torch.distributed.checkpoint._consolidate_hf_safetensors import (
2727
consolidate_safetensors_files_on_every_rank,
2828
)
29+
from torch.distributed.checkpoint.default_planner import DefaultLoadPlanner
2930
from torch.distributed.checkpoint.staging import DefaultStager, StagingOptions
3031
from torch.distributed.checkpoint.state_dict import (
3132
get_model_state_dict,
@@ -73,8 +74,33 @@ def _get_state_dict(self) -> dict[str, Any]:
7374
}
7475
return state_dict
7576

77+
def _is_converter_key(self, key: str) -> bool:
78+
"""Check if a state dict key was added by a model converter."""
79+
for part in self.model:
80+
fn = getattr(part, "converter_key_filter", None)
81+
if fn is not None and fn(key):
82+
return True
83+
return False
84+
85+
def _save_converter_keys_only(self) -> bool:
86+
"""Check if any model part requests saving only converter-added weights."""
87+
return any(
88+
getattr(part, "save_converter_keys_only", False) for part in self.model
89+
)
90+
91+
def state_dict_to_save(self) -> dict[str, Any]:
    """Return the state dict that should be persisted to a checkpoint.

    When any converter requests saving only its own keys (e.g. LoRA
    adapter-only checkpoints), the full state dict is narrowed to the
    converter-added entries; otherwise it is returned unchanged.
    """
    sd = self._get_state_dict()
    if not self._save_converter_keys_only():
        return sd
    return {key: value for key, value in sd.items() if self._is_converter_key(key)}
96+
97+
def base_state_dict(self) -> dict[str, Any]:
    """Return only the original model keys, dropping converter-added entries."""
    result: dict[str, Any] = {}
    for key, value in self._get_state_dict().items():
        if self._is_converter_key(key):
            continue
        result[key] = value
    return result
101+
76102
def state_dict(self) -> dict[str, Any]:
    """Checkpoint hook: delegate to state_dict_to_save() so converter-only
    saving (e.g. LoRA adapter-only checkpoints) is honored."""
    to_save = self.state_dict_to_save()
    return to_save
78104

79105
def load_state_dict(self, state_dict: dict[str, Any]) -> None:
80106
func = functools.partial(
@@ -321,6 +347,14 @@ class Config(Configurable.Config):
321347
This will load the model only, excluding the specified keys.
322348
"""
323349

350+
additional_load_paths: list[str] = field(default_factory=list)
351+
"""
352+
Additional checkpoint paths to load from after the primary checkpoint.
353+
Useful for loading state dicts from multiple sources, e.g., base model
354+
weights from one checkpoint and LoRA adapter weights from another.
355+
Each path should contain a valid DCP checkpoint directory.
356+
"""
357+
324358
enable_first_step_checkpoint: bool = False
325359
"""
326360
Enable the checkpoint save at first step. This will save a checkpoint immediately
@@ -445,6 +479,7 @@ def load_state_dict(state_dict):
445479
self.sd_adapter = sd_adapter
446480
self.export_dtype = TORCH_DTYPE_MAP[config.export_dtype]
447481
self.exclude_from_loading = config.exclude_from_loading
482+
self.additional_load_paths = config.additional_load_paths
448483
self.interval = config.interval
449484
self.enable_first_step_checkpoint = config.enable_first_step_checkpoint
450485

@@ -600,41 +635,63 @@ def dcp_save(
600635
def dcp_load(
    self,
    state_dict: dict[str, Any],
    checkpoint_id: str | list[str],
    from_hf: bool,
    from_quantized: bool,
) -> None:
    """Load the checkpoint(s) with dcp.

    Args:
        state_dict (dict): The state dict to load.
        checkpoint_id (str | list[str]): The checkpoint id(s) to load.
            The first checkpoint is treated as the primary checkpoint.
            Additional checkpoints are always in DCP format.
        from_hf (bool): Whether to load the primary checkpoint from
            HuggingFace safetensors format.
        from_quantized (bool): Whether the HuggingFace checkpoint is quantized.
    """
    checkpoint_ids = (
        [checkpoint_id] if isinstance(checkpoint_id, str) else checkpoint_id
    )
    # Partial loads are always allowed: each checkpoint in the list may
    # cover only a subset of the requested state dict (e.g. base model
    # weights from one path, LoRA adapter weights from another).
    planner = DefaultLoadPlanner(allow_partial_load=True)

    for i, cid in enumerate(checkpoint_ids):
        if i == 0 and from_hf:
            # Primary checkpoint in HF safetensors format: model weights only;
            # training states come from the additional DCP checkpoints.
            assert (
                self.sd_adapter is not None
            ), "Trying to load HF safetensors but sd_adapter is not provided."
            # Only the base (pre-converter) keys exist in the HF checkpoint.
            hf_state_dict = self.sd_adapter.to_hf(
                self.states[MODEL].base_state_dict()
            )
            hf_storage_reader = self.sd_adapter.get_hf_storage_reader(
                cid, from_quantized
            )
            dcp.load(
                hf_state_dict,
                storage_reader=hf_storage_reader,
                planner=planner,
            )
            converted_sd = self.sd_adapter.from_hf(hf_state_dict)
            if MODEL in self.states:
                self.states[MODEL].load_state_dict(converted_sd)
        else:
            # Primary DCP checkpoint, or any additional checkpoint
            # (additional checkpoints are always in DCP format).
            dcp.load(state_dict, checkpoint_id=cid, planner=planner)
            # TODO: Since we flatten the model states in state_dict, we need
            # to manually call load_state_dict() for the model.
            if MODEL in self.states:
                self.states[MODEL].load_state_dict(state_dict)
638695

639696
@torch.no_grad()
640697
def save(self, curr_step: int, last_step: bool = False) -> None:
@@ -737,6 +794,12 @@ def load(self, step: int = -1) -> bool:
737794
if not self.enable:
738795
return False
739796

797+
for path in self.additional_load_paths:
798+
if not os.path.isdir(path):
799+
raise ValueError(
800+
f"checkpoint.additional_load_paths contains invalid path: {path}"
801+
)
802+
740803
model_only = False
741804
from_hf = False
742805
from_quantized = False
@@ -808,7 +871,7 @@ def load(self, step: int = -1) -> bool:
808871
states = self._states_to_load(model_only)
809872
self.dcp_load(
810873
states,
811-
checkpoint_id=checkpoint_id,
874+
checkpoint_id=[checkpoint_id] + self.additional_load_paths,
812875
from_hf=from_hf,
813876
from_quantized=from_quantized,
814877
)
@@ -947,7 +1010,7 @@ def _save_last_step(self, curr_step: int) -> None:
9471010
# is not the same as the export dtype at the end of the training.
9481011

9491012
if self.last_save_model_only:
950-
states = self.states[MODEL].state_dict()
1013+
states = self.states[MODEL].state_dict_to_save()
9511014

9521015
if self.export_dtype != torch.float32:
9531016
states = {k: v.to(self.export_dtype) for k, v in states.items()}

torchtitan/components/lora.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010

1111
import torch
1212
import torch.nn as nn
13-
1413
from torchtitan.config import Configurable
1514
from torchtitan.tools.logging import logger
1615

@@ -92,9 +91,15 @@ class Config(Configurable.Config):
9291
alpha: float = 16.0
9392
"""Scaling factor. Output is scaled by alpha/rank."""
9493

94+
save_adapter_only: bool = True
95+
"""If True, only save LoRA adapter weights in checkpoints.
96+
Requires base model to be loaded from HF/initial_load_path on resume.
97+
Set to False to save full model weights for debugging without pretrained base."""
98+
9599
def __init__(self, config: Config, **kwargs):
    # Low-rank dimension of the LoRA adapter matrices.
    self.rank = config.rank
    # Scaling factor; per the Config docstring, output is scaled by alpha/rank.
    self.alpha = config.alpha
    # When True, checkpoints keep only the adapter weights; the base model
    # must then be reloaded from HF/initial_load_path on resume.
    self.save_adapter_only = config.save_adapter_only
    logger.info(f"LoRA training active with rank={self.rank}, alpha={self.alpha}")
99104

100105
def convert(self, model: nn.Module) -> None:
@@ -120,6 +125,15 @@ def new_model_init_weights(*args: Any, **kwargs: Any) -> None:
120125

121126
object.__setattr__(module, "init_weights", new_model_init_weights)
122127

128+
# Expose a key filter and flag on the module so ModelWrapper can
129+
# partition the state dict without knowing about LoRA internals.
130+
def converter_key_filter(key: str) -> bool:
    """Return True if key was added by this converter (LoRA adapter weights)."""
    return any(marker in key for marker in (".lora_a.", ".lora_b."))
133+
134+
object.__setattr__(module, "converter_key_filter", converter_key_filter)
135+
object.__setattr__(module, "save_converter_keys_only", self.save_adapter_only)
136+
123137
def post_optimizer_hook(self, model: nn.Module | list[nn.Module]) -> None:
124138
pass
125139

torchtitan/models/llama3/config_registry.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,14 @@ def llama3_debugmodel_lora() -> Trainer.Config:
119119
),
120120
],
121121
)
122+
# For LoRA finetuning, set initial_load_in_hf=True to enable proper
123+
# checkpoint resumption (load base model from HF, then load LoRA adapters)
124+
config.checkpoint = CheckpointManager.Config(
125+
interval=500,
126+
initial_load_in_hf=True,
127+
initial_load_model_only=True,
128+
last_save_model_only=False,
129+
)
122130
return config
123131

124132

0 commit comments

Comments
 (0)