Skip to content
Draft
Show file tree
Hide file tree
Changes from 6 commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
2ce36ba
lora checkpoint
mori360 Mar 4, 2026
0f8e5e3
Update on "lora checkpoint"
mori360 Mar 4, 2026
e83c45b
Update on "[2/N] Support lora checkpoint on partial save and multi-so…
mori360 Mar 4, 2026
411373d
Update on "[2/N] Support lora checkpoint on partial save and multi-so…
mori360 Mar 5, 2026
541078e
Update on "[2/N] Support lora checkpoint on partial save and multi-so…
mori360 Mar 5, 2026
09c39da
Update on "[2/N] Support lora checkpoint on partial save and multi-so…
mori360 Mar 5, 2026
5c3b0b6
Update on "[2/N] Support lora checkpoint on partial save and multi-so…
mori360 Mar 12, 2026
a242d7c
Update on "[2/N] Support lora checkpoint on partial save and multi-so…
mori360 Mar 12, 2026
7830ae6
Update on "[2/N] Support lora checkpoint on partial save and multi-so…
mori360 Mar 12, 2026
27e605e
Update on "[2/N] Support lora checkpoint on partial save and multi-so…
mori360 Mar 12, 2026
0b83f01
Update on "[2/N] Support lora checkpoint on partial save and multi-so…
mori360 Mar 17, 2026
d4364d1
Update on "[2/N] Support lora checkpoint on partial save and multi-so…
mori360 Mar 17, 2026
b79e7d7
Update on "[2/N] Support lora checkpoint on partial save and multi-so…
mori360 Mar 17, 2026
315e91d
Update on "[2/N] Support lora checkpoint on partial save and multi-so…
mori360 Mar 17, 2026
55bc894
Update on "[2/N] Support lora checkpoint on partial save and multi-so…
mori360 Mar 19, 2026
4cc38d0
Update on "[2/N] Support lora checkpoint on partial save and multi-so…
mori360 Mar 20, 2026
5f2482e
Update on "[2/N] Support lora checkpoint on partial save and multi-so…
mori360 Mar 20, 2026
eae6f72
Update on "[2/N] Support lora checkpoint on partial save and multi-so…
mori360 Mar 20, 2026
d4f1e3c
Update on "[2/N] Support lora checkpoint on partial save and multi-so…
mori360 Mar 20, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 108 additions & 2 deletions tests/unit_tests/test_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import torch
import torch.nn as nn
from torch.distributed.checkpoint.default_planner import DefaultLoadPlanner
from torch.distributed.checkpoint.state_dict_saver import AsyncSaveResponse
from torch.utils.data import DataLoader
from torchtitan.components.checkpoint import CheckpointManager
Expand Down Expand Up @@ -165,7 +166,7 @@ def fake_save(self, state_dict: dict, checkpoint_id: str, storage_writer=None):
sd_to_save[key] = val
torch.save(sd_to_save, os.path.join(checkpoint_id, "state_dict.pt"))

def fake_load(self, states: dict, checkpoint_id=None):
def fake_load(self, states: dict, checkpoint_id=None, **kwargs):
path = os.path.join(checkpoint_id, "state_dict.pt")
loaded = torch.load(path, weights_only="False")
for key, val in loaded.items():
Expand Down Expand Up @@ -748,7 +749,7 @@ def fake_save(state_dict: dict, checkpoint_id: str, storage_writer=None):
self.assertNotIn("optimizer", state_dict)
return

def fake_load(state_dict: dict, checkpoint_id=None):
def fake_load(state_dict: dict, checkpoint_id=None, **kwargs):
self.assertIn("bias", state_dict)
self.assertIn("weight", state_dict)
# No model prefix
Expand Down Expand Up @@ -776,5 +777,110 @@ def fake_load(state_dict: dict, checkpoint_id=None):
manager.load(step=1)


class TestModelWrapperConverterKeys(unittest.TestCase):
    """Tests for ModelWrapper.has_converter_keys() and its effect on load planner.

    Both tests save then load a checkpoint through CheckpointManager with
    dcp.save/dcp.load mocked out, and inspect the ``planner`` keyword passed
    to dcp.load: without converter keys the load must be strict
    (allow_partial_load=False); with a ``converter_key_filter`` attached to
    the model it must allow partial loads.
    """

    @staticmethod
    def _fake_save(*args, **kwargs):
        """Stand-in for dcp.save: just create the checkpoint directory."""
        checkpoint_id = kwargs.get("checkpoint_id", args[1] if len(args) > 1 else "")
        os.makedirs(checkpoint_id, exist_ok=True)

    def _build_manager(self, model: nn.Module, base_folder: str) -> CheckpointManager:
        """Build a CheckpointManager around *model* with fake training containers.

        torch.distributed.new_group is patched so no process group is created.
        """
        cfg = CheckpointManager.Config(
            enable=True,
            async_mode="disabled",
            folder="",
            interval=1,
            keep_latest_k=0,
            last_save_model_only=False,
            export_dtype="float32",
            exclude_from_loading=[],
            initial_load_path=None,
            initial_load_model_only=False,
        )
        with mock.patch("torch.distributed.new_group", return_value="pg"):
            return CheckpointManager(
                dataloader=FakeDataLoader(),
                model_parts=[model],
                optimizers=FakeOptimizersContainer(),
                lr_schedulers=FakeLRSchedulersContainer(),
                states={},
                config=cfg,
                sd_adapter=None,
                base_folder=base_folder,
                ft_manager=DummyFTManager(),
            )

    def _save_load_and_get_planner(self, model: nn.Module, mock_load):
        """Run one save/load round trip and return the planner given to dcp.load."""
        with tempfile.TemporaryDirectory() as temp_dir:
            manager = self._build_manager(model, temp_dir)
            manager.save(curr_step=1)
            manager.load(step=1)
            _, kwargs = mock_load.call_args
            return kwargs.get("planner")

    @mock.patch("torch.distributed.get_rank", return_value=0)
    @mock.patch("torchtitan.components.checkpoint.dcp.load")
    @mock.patch("torchtitan.components.checkpoint.dcp.save")
    def test_load_uses_strict_planner_without_converter(
        self, mock_save, mock_load, mock_rank
    ):
        """Without converter keys, dcp.load is called with allow_partial_load=False."""
        mock_save.side_effect = self._fake_save
        mock_load.side_effect = lambda *a, **kw: None

        planner = self._save_load_and_get_planner(nn.Linear(2, 2), mock_load)
        self.assertIsInstance(planner, DefaultLoadPlanner)
        self.assertFalse(planner.allow_partial_load)

    @mock.patch("torch.distributed.get_rank", return_value=0)
    @mock.patch("torchtitan.components.checkpoint.dcp.load")
    @mock.patch("torchtitan.components.checkpoint.dcp.save")
    def test_load_uses_partial_planner_with_converter(
        self, mock_save, mock_load, mock_rank
    ):
        """With converter keys on the model, dcp.load is called with allow_partial_load=True."""
        mock_save.side_effect = self._fake_save
        mock_load.side_effect = lambda *a, **kw: None

        model = nn.Linear(2, 2)
        # Attach the converter marker the way the LoRA converter does, bypassing
        # nn.Module.__setattr__ so the callable is a plain attribute.
        object.__setattr__(
            model, "converter_key_filter", lambda key: ".lora_a." in key
        )

        planner = self._save_load_and_get_planner(model, mock_load)
        self.assertIsInstance(planner, DefaultLoadPlanner)
        self.assertTrue(planner.allow_partial_load)


# Standard unittest entry point: run this module's tests when executed directly.
if __name__ == "__main__":
    unittest.main()
132 changes: 98 additions & 34 deletions torchtitan/components/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from torch.distributed.checkpoint._consolidate_hf_safetensors import (
consolidate_safetensors_files_on_every_rank,
)
from torch.distributed.checkpoint.default_planner import DefaultLoadPlanner
from torch.distributed.checkpoint.staging import DefaultStager, StagingOptions
from torch.distributed.checkpoint.state_dict import (
get_model_state_dict,
Expand Down Expand Up @@ -65,16 +66,47 @@ class AsyncMode(str, enum.Enum):
class ModelWrapper(Stateful):
def __init__(self, model: nn.Module | list[nn.Module]) -> None:
self.model = [model] if isinstance(model, nn.Module) else model
self.cache_state_dict = self._get_state_dict()

def _get_state_dict(self) -> dict[str, Any]:
    """Merge the state dicts of every model part into one flat mapping.

    Later parts overwrite earlier ones on duplicate keys, matching plain
    dict-comprehension merge semantics.
    """
    merged: dict[str, Any] = {}
    for part in self.model:
        merged.update(get_model_state_dict(part))
    return merged

def _is_converter_key(self, key: str) -> bool:
    """Return True if *key* was added by a model converter.

    A model part opts in by exposing a ``converter_key_filter`` callable;
    the key is a converter key as soon as any part's filter accepts it.
    """
    filters = (getattr(part, "converter_key_filter", None) for part in self.model)
    return any(fn(key) for fn in filters if fn is not None)

def has_converter_keys(self) -> bool:
    """Return True if any model part carries converter-added keys (e.g. LoRA adapters)."""
    for part in self.model:
        # Presence of the filter attribute is the marker; its result is irrelevant here.
        if getattr(part, "converter_key_filter", None) is not None:
            return True
    return False

def _save_converter_keys_only(self) -> bool:
    """Return True if any model part requests saving only converter-added weights."""
    for part in self.model:
        if getattr(part, "save_converter_keys_only", False):
            return True
    return False

def state_dict_to_save(self) -> dict[str, Any]:
    """Return the state dict to persist in a checkpoint.

    When any model part requests adapter-only saving, filter the full state
    dict down to converter-added keys; otherwise save everything.
    """
    full = self._get_state_dict()
    if not self._save_converter_keys_only():
        return full
    return {key: value for key, value in full.items() if self._is_converter_key(key)}

def base_state_dict(self) -> dict[str, Any]:
    """Return the state dict restricted to original model keys (before converters)."""
    base: dict[str, Any] = {}
    for key, value in self._get_state_dict().items():
        if not self._is_converter_key(key):
            base[key] = value
    return base

def state_dict(self) -> dict[str, Any]:
return self.cache_state_dict
return self.state_dict_to_save()

def load_state_dict(self, state_dict: dict[str, Any]) -> None:
func = functools.partial(
Expand All @@ -83,9 +115,6 @@ def load_state_dict(self, state_dict: dict[str, Any]) -> None:
options=StateDictOptions(strict=False),
)
list(map(func, self.model))
# `set_model_state_dict()` does change the keys of the input state_dict,
# we will need to reinitialize the cache_state_dict.
self.cache_state_dict = self._get_state_dict()


class Terminate:
Expand Down Expand Up @@ -321,6 +350,14 @@ class Config(Configurable.Config):
This will load the model only, excluding the specified keys.
"""

additional_load_paths: list[str] = field(default_factory=list)
"""
Additional checkpoint paths to load from after the primary checkpoint.
Useful for loading state dicts from multiple sources, e.g., base model
weights from one checkpoint and LoRA adapter weights from another.
Each path should contain a valid DCP checkpoint directory.
"""

enable_first_step_checkpoint: bool = False
"""
Enable the checkpoint save at first step. This will save a checkpoint immediately
Expand Down Expand Up @@ -445,6 +482,7 @@ def load_state_dict(state_dict):
self.sd_adapter = sd_adapter
self.export_dtype = TORCH_DTYPE_MAP[config.export_dtype]
self.exclude_from_loading = config.exclude_from_loading
self.additional_load_paths = config.additional_load_paths
self.interval = config.interval
self.enable_first_step_checkpoint = config.enable_first_step_checkpoint

Expand Down Expand Up @@ -600,41 +638,61 @@ def dcp_save(
def dcp_load(
self,
state_dict: dict[str, Any],
checkpoint_id: str,
checkpoint_id: str | list[str],
from_hf: bool,
from_quantized: bool,
) -> None:
"""Load the checkpoint with dcp.
"""Load the checkpoint(s) with dcp.

Args:
state_dict (dict): The state dict to load.
checkpoint_id (str): The checkpoint id to load.
from_hf (bool): Whether to load from HuggingFace checkpoint with
its own model definition and safetensors format.
checkpoint_id (str | list[str]): The checkpoint id(s) to load.
The first checkpoint is treated as the primary checkpoint.
Additional checkpoints are always in DCP format.
from_hf (bool): Whether to load the primary checkpoint from
HuggingFace safetensors format.
from_quantized (bool): Whether the HuggingFace checkpoint is quantized.
"""
checkpoint_ids = (
[checkpoint_id] if isinstance(checkpoint_id, str) else checkpoint_id
)
needs_partial = (
len(checkpoint_ids) > 1 or self.states[MODEL].has_converter_keys()
)
planner = DefaultLoadPlanner(allow_partial_load=needs_partial)

if from_hf:
assert (
self.sd_adapter is not None
), "trying to load checkpoint in HF safetensors format, but sd_adapter is not provided."
hf_state_dict = self.sd_adapter.to_hf(state_dict)
hf_storage_reader = self.sd_adapter.get_hf_storage_reader(
checkpoint_id, from_quantized
)

dcp.load(
hf_state_dict,
storage_reader=hf_storage_reader,
)

state_dict = self.sd_adapter.from_hf(hf_state_dict)
self.states[MODEL].load_state_dict(state_dict)
else:
dcp.load(state_dict, checkpoint_id=checkpoint_id)
for i, cid in enumerate(checkpoint_ids):
is_primary = i == 0

# TODO: Since we flatten the model states in state_dict, we need to
# manually call load_state_dict() for the model. Need to fix this.
if MODEL in self.states:
self.states[MODEL].load_state_dict(state_dict)
if is_primary:
if from_hf:
# HF format: model only, training states from additional checkpoints
assert (
self.sd_adapter is not None
), "Trying to load HF safetensors but sd_adapter is not provided."
hf_state_dict = self.sd_adapter.to_hf(
self.states[MODEL].base_state_dict()
)
hf_storage_reader = self.sd_adapter.get_hf_storage_reader(
cid, from_quantized
)
dcp.load(
hf_state_dict,
storage_reader=hf_storage_reader,
planner=planner,
)
converted_sd = self.sd_adapter.from_hf(hf_state_dict)
if MODEL in self.states:
self.states[MODEL].load_state_dict(converted_sd)
else:
dcp.load(state_dict, checkpoint_id=cid, planner=planner)
if MODEL in self.states:
self.states[MODEL].load_state_dict(state_dict)
else:
# Additional checkpoints: always DCP format, load all available states
dcp.load(state_dict, checkpoint_id=cid, planner=planner)
if MODEL in self.states:
self.states[MODEL].load_state_dict(state_dict)

@torch.no_grad()
def save(self, curr_step: int, last_step: bool = False) -> None:
Expand Down Expand Up @@ -737,6 +795,12 @@ def load(self, step: int = -1) -> bool:
if not self.enable:
return False

for path in self.additional_load_paths:
if not os.path.isdir(path):
raise ValueError(
f"checkpoint.additional_load_paths contains invalid path: {path}"
)

model_only = False
from_hf = False
from_quantized = False
Expand Down Expand Up @@ -808,7 +872,7 @@ def load(self, step: int = -1) -> bool:
states = self._states_to_load(model_only)
self.dcp_load(
states,
checkpoint_id=checkpoint_id,
checkpoint_id=[checkpoint_id] + self.additional_load_paths,
from_hf=from_hf,
from_quantized=from_quantized,
)
Expand Down Expand Up @@ -947,7 +1011,7 @@ def _save_last_step(self, curr_step: int) -> None:
# is not the same as the export dtype at the end of the training.

if self.last_save_model_only:
states = self.states[MODEL].state_dict()
states = self.states[MODEL].state_dict_to_save()

if self.export_dtype != torch.float32:
states = {k: v.to(self.export_dtype) for k, v in states.items()}
Expand Down
16 changes: 15 additions & 1 deletion torchtitan/components/lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@

import torch
import torch.nn as nn

from torchtitan.config import Configurable
from torchtitan.tools.logging import logger

Expand Down Expand Up @@ -91,9 +90,15 @@ class Config(Configurable.Config):
alpha: float = 16.0
"""Scaling factor. Output is scaled by alpha/rank."""

save_adapter_only: bool = True
"""If True, only save LoRA adapter weights in checkpoints.
Requires base model to be loaded from HF/initial_load_path on resume.
Set to False to save full model weights for debugging without pretrained base."""

def __init__(self, config: Config, **kwargs):
    """Capture LoRA hyperparameters from *config* and log that LoRA is active."""
    self.rank, self.alpha = config.rank, config.alpha
    self.save_adapter_only = config.save_adapter_only
    logger.info(f"LoRA training active with rank={self.rank}, alpha={self.alpha}")

def convert(self, model: nn.Module) -> None:
Expand All @@ -119,5 +124,14 @@ def new_model_init_weights(*args: Any, **kwargs: Any) -> None:

object.__setattr__(module, "init_weights", new_model_init_weights)

# Expose a key filter and flag on the module so ModelWrapper can
# partition the state dict without knowing about LoRA internals.
def converter_key_filter(key: str) -> bool:
"""Return True if key was added by this converter (LoRA adapter weights)."""
return ".lora_a." in key or ".lora_b." in key

object.__setattr__(module, "converter_key_filter", converter_key_filter)
object.__setattr__(module, "save_converter_keys_only", self.save_adapter_only)

def post_optimizer_hook(self, model: nn.Module | list[nn.Module]) -> None:
    """No-op: this converter requires no work after each optimizer step."""
    pass
Loading
Loading