Skip to content

Commit 0a1f1aa

Browse files
committed
lora checkpoint
ghstack-source-id: b9f7bf5 Pull Request resolved: #2485
1 parent bbaeed5 commit 0a1f1aa

File tree

5 files changed

+540
-36
lines changed

5 files changed

+540
-36
lines changed

tests/unit_tests/test_checkpoint.py

Lines changed: 84 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
import torch
1616
import torch.nn as nn
17+
from torch.distributed.checkpoint.default_planner import DefaultLoadPlanner
1718
from torch.distributed.checkpoint.state_dict_saver import AsyncSaveResponse
1819
from torch.utils.data import DataLoader
1920
from torchtitan.components.checkpoint import CheckpointManager
@@ -165,7 +166,7 @@ def fake_save(self, state_dict: dict, checkpoint_id: str, storage_writer=None):
165166
sd_to_save[key] = val
166167
torch.save(sd_to_save, os.path.join(checkpoint_id, "state_dict.pt"))
167168

168-
def fake_load(self, states: dict, checkpoint_id=None):
169+
def fake_load(self, states: dict, checkpoint_id=None, **kwargs):
169170
path = os.path.join(checkpoint_id, "state_dict.pt")
170171
loaded = torch.load(path, weights_only="False")
171172
for key, val in loaded.items():
@@ -748,7 +749,7 @@ def fake_save(state_dict: dict, checkpoint_id: str, storage_writer=None):
748749
self.assertNotIn("optimizer", state_dict)
749750
return
750751

751-
def fake_load(state_dict: dict, checkpoint_id=None):
752+
def fake_load(state_dict: dict, checkpoint_id=None, **kwargs):
752753
self.assertIn("bias", state_dict)
753754
self.assertIn("weight", state_dict)
754755
# No model prefix
@@ -776,5 +777,86 @@ def fake_load(state_dict: dict, checkpoint_id=None):
776777
manager.load(step=1)
777778

778779

780+
class TestModelWrapperConverterKeys(unittest.TestCase):
    """Tests for ModelWrapper.has_converter_keys() and its effect on load planner.

    Each test saves and then loads step 1 through a ``CheckpointManager`` whose
    ``dcp.save`` / ``dcp.load`` are mocked, and inspects the ``planner`` keyword
    argument that reached ``dcp.load``.
    """

    def _create_manager(self, mock_save, mock_load, model, temp_dir):
        """Create a CheckpointManager with mocked dcp.save/load.

        Args:
            mock_save: Mock standing in for ``dcp.save``; it is configured to
                create the directory named by ``checkpoint_id``.
            mock_load: Mock standing in for ``dcp.load``; it is configured to
                do nothing and return ``None``.
            model: The single model part handed to the manager.
            temp_dir: Directory passed to the manager as ``base_folder``.

        Returns:
            The constructed ``CheckpointManager``.
        """

        def _fake_save(*args, **kwargs):
            # ``checkpoint_id`` may arrive as a keyword or as the second
            # positional argument; fall back to "" when neither is present
            # (same lookup as the mocked call sites rely on).
            checkpoint_id = kwargs.get(
                "checkpoint_id", args[1] if len(args) > 1 else ""
            )
            # Only the directory is created; nothing is written into it.
            return os.makedirs(checkpoint_id, exist_ok=True)

        mock_save.side_effect = _fake_save
        mock_load.side_effect = lambda *a, **kw: None

        cfg = CheckpointManager.Config(
            enable=True,
            async_mode="disabled",
            folder="",
            interval=1,
            keep_latest_k=0,
            last_save_model_only=False,
            export_dtype="float32",
            exclude_from_loading=[],
            initial_load_path=None,
            initial_load_model_only=False,
        )
        # Patch process-group creation so no real distributed backend is
        # needed while the manager is being constructed.
        with mock.patch("torch.distributed.new_group", return_value="pg"):
            return CheckpointManager(
                dataloader=FakeDataLoader(),
                model_parts=[model],
                optimizers=FakeOptimizersContainer(),
                lr_schedulers=FakeLRSchedulersContainer(),
                states={},
                config=cfg,
                sd_adapter=None,
                base_folder=temp_dir,
                ft_manager=DummyFTManager(),
            )

    def _planner_passed_to_load(self, mock_save, mock_load, model):
        """Save and load step 1 for ``model`` and return the planner given to dcp.load.

        The temporary checkpoint folder is removed before returning, whether or
        not the save/load calls succeed.
        """
        with tempfile.TemporaryDirectory() as temp_dir:
            manager = self._create_manager(mock_save, mock_load, model, temp_dir)
            manager.save(curr_step=1)
            manager.load(step=1)

        # Fail with a clear message if dcp.load was never reached; otherwise
        # ``call_args`` would be None and the unpacking below would raise an
        # unrelated TypeError.
        mock_load.assert_called()
        _, kwargs = mock_load.call_args
        planner = kwargs.get("planner")
        self.assertIsInstance(planner, DefaultLoadPlanner)
        return planner

    @mock.patch("torch.distributed.get_rank", return_value=0)
    @mock.patch("torchtitan.components.checkpoint.dcp.load")
    @mock.patch("torchtitan.components.checkpoint.dcp.save")
    def test_load_uses_strict_planner_without_converter(
        self, mock_save, mock_load, mock_rank
    ):
        """Without converter keys, dcp.load gets a planner with allow_partial_load=False."""
        model = nn.Linear(2, 2)
        planner = self._planner_passed_to_load(mock_save, mock_load, model)
        self.assertFalse(planner.allow_partial_load)

    @mock.patch("torch.distributed.get_rank", return_value=0)
    @mock.patch("torchtitan.components.checkpoint.dcp.load")
    @mock.patch("torchtitan.components.checkpoint.dcp.save")
    def test_load_uses_partial_planner_with_converter(
        self, mock_save, mock_load, mock_rank
    ):
        """With converter keys on the model, dcp.load gets a planner with allow_partial_load=True."""
        model = nn.Linear(2, 2)
        # object.__setattr__ sidesteps nn.Module.__setattr__ so the filter is
        # stored as a plain instance attribute.
        object.__setattr__(
            model, "converter_key_filter", lambda key: ".lora_a." in key
        )
        planner = self._planner_passed_to_load(mock_save, mock_load, model)
        self.assertTrue(planner.allow_partial_load)
859+
860+
779861
if __name__ == "__main__":
780862
unittest.main()

tests/unit_tests/test_model_converter.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -173,3 +173,32 @@ def test_lora_trains_base_frozen():
173173
if name in lora_before
174174
)
175175
assert any_lora_changed, "No LoRA param changed after 5 training steps"
176+
177+
178+
def test_lora_key_remap_roundtrip():
    """Check that LoRA keys survive a torchtitan -> HF -> torchtitan remap.

    The forward remap must produce the ``base_model.model.``-prefixed keys
    with ``lora_A`` / ``lora_B`` naming asserted below, and the reverse remap
    must give back exactly the original key set with unchanged tensors.
    """
    from torchtitan.components.lora import (
        remap_lora_keys_from_hf,
        remap_lora_keys_to_hf,
    )

    # Name templates keyed by the HF-style name, valued by the torchtitan name
    # ("{}" stands for the layer index).
    hf_to_tt_templates = {
        "model.layers.{}.self_attn.q_proj.weight": "layers.{}.attention.wq.weight",
        "model.layers.{}.mlp.gate_proj.weight": "layers.{}.feed_forward.w1.weight",
    }

    original = {
        "layers.0.attention.wq.lora_a.weight": torch.randn(8, 64),
        "layers.0.attention.wq.lora_b.weight": torch.randn(64, 8),
        "layers.2.feed_forward.w1.lora_a.weight": torch.randn(8, 64),
    }

    exported = remap_lora_keys_to_hf(original, hf_to_tt_templates)
    expected_hf_keys = (
        "base_model.model.model.layers.0.self_attn.q_proj.lora_A.weight",
        "base_model.model.model.layers.0.self_attn.q_proj.lora_B.weight",
        "base_model.model.model.layers.2.mlp.gate_proj.lora_A.weight",
    )
    for hf_key in expected_hf_keys:
        assert hf_key in exported

    restored = remap_lora_keys_from_hf(exported, hf_to_tt_templates)
    assert set(restored.keys()) == set(original.keys())
    # Tensors must come back bit-for-bit identical under their original names.
    for name, tensor in original.items():
        assert torch.equal(restored[name], tensor)

0 commit comments

Comments
 (0)