@@ -14,6 +14,7 @@
 
 import torch
 import torch.nn as nn
+from torch.distributed.checkpoint.default_planner import DefaultLoadPlanner
 from torch.distributed.checkpoint.state_dict_saver import AsyncSaveResponse
 from torch.utils.data import DataLoader
 from torchtitan.components.checkpoint import CheckpointManager
@@ -165,7 +166,7 @@ def fake_save(self, state_dict: dict, checkpoint_id: str, storage_writer=None):
             sd_to_save[key] = val
         torch.save(sd_to_save, os.path.join(checkpoint_id, "state_dict.pt"))
 
-    def fake_load(self, states: dict, checkpoint_id=None):
+    def fake_load(self, states: dict, checkpoint_id=None, **kwargs):
         path = os.path.join(checkpoint_id, "state_dict.pt")
         loaded = torch.load(path, weights_only=False)
         for key, val in loaded.items():
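
Both fake_load stubs gain **kwargs because CheckpointManager.load now forwards a planner keyword to dcp.load (the new tests below assert exactly that), and a fake with a fixed signature would raise TypeError on the unexpected argument. A minimal sketch of the call shape being stubbed, assuming the strict default; the exact keywords used inside torchtitan may differ:

    import torch.distributed.checkpoint as dcp
    from torch.distributed.checkpoint.default_planner import DefaultLoadPlanner

    # Strict load: every key requested in state_dict must be present in the
    # checkpoint, so a missing entry is an error rather than silently skipped.
    dcp.load(
        state_dict,
        checkpoint_id=checkpoint_id,
        planner=DefaultLoadPlanner(allow_partial_load=False),
    )
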
@@ -748,7 +749,7 @@ def fake_save(state_dict: dict, checkpoint_id: str, storage_writer=None):
             self.assertNotIn("optimizer", state_dict)
             return
 
-        def fake_load(state_dict: dict, checkpoint_id=None):
+        def fake_load(state_dict: dict, checkpoint_id=None, **kwargs):
             self.assertIn("bias", state_dict)
             self.assertIn("weight", state_dict)
             # No model prefix
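
The assertions in this stub pin down the model-only save format, presumably the last_save_model_only path configured later in this diff: the manager writes the module's flattened state dict, so keys arrive as bare parameter names rather than nested under a "model" component. Roughly, for the small linear model used in these tests (tensor values elided):

    # Full checkpoint: components nested under their top-level names.
    {"model": {"weight": ..., "bias": ...}, "optimizer": {...}}

    # Model-only checkpoint: the bare module keys, no "model." prefix.
    {"weight": ..., "bias": ...}
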
@@ -776,5 +777,86 @@ def fake_load(state_dict: dict, checkpoint_id=None):
         manager.load(step=1)
 
 
+class TestModelWrapperConverterKeys(unittest.TestCase):
+    """Tests for ModelWrapper.has_converter_keys() and its effect on load planner."""
+
+    def _create_manager(self, mock_save, mock_load, model, temp_dir):
+        """Create a CheckpointManager with mocked dcp.save/load."""
+        mock_save.side_effect = lambda *a, **kw: os.makedirs(
+            kw.get("checkpoint_id", a[1] if len(a) > 1 else ""), exist_ok=True
+        )
+        mock_load.side_effect = lambda *a, **kw: None
+
+        cfg = CheckpointManager.Config(
+            enable=True,
+            async_mode="disabled",
+            folder="",
+            interval=1,
+            keep_latest_k=0,
+            last_save_model_only=False,
+            export_dtype="float32",
+            exclude_from_loading=[],
+            initial_load_path=None,
+            initial_load_model_only=False,
+        )
+        with mock.patch("torch.distributed.new_group", return_value="pg"):
+            return CheckpointManager(
+                dataloader=FakeDataLoader(),
+                model_parts=[model],
+                optimizers=FakeOptimizersContainer(),
+                lr_schedulers=FakeLRSchedulersContainer(),
+                states={},
+                config=cfg,
+                sd_adapter=None,
+                base_folder=temp_dir,
+                ft_manager=DummyFTManager(),
+            )
+
+    @mock.patch("torch.distributed.get_rank", return_value=0)
+    @mock.patch("torchtitan.components.checkpoint.dcp.load")
+    @mock.patch("torchtitan.components.checkpoint.dcp.save")
+    def test_load_uses_strict_planner_without_converter(
+        self, mock_save, mock_load, mock_rank
+    ):
+        """Without converter keys, dcp.load is called with allow_partial_load=False."""
+        temp_dir = tempfile.mkdtemp()
+        try:
+            model = nn.Linear(2, 2)
+            manager = self._create_manager(mock_save, mock_load, model, temp_dir)
+            manager.save(curr_step=1)
+            manager.load(step=1)
+
+            _, kwargs = mock_load.call_args
+            planner = kwargs.get("planner")
+            self.assertIsInstance(planner, DefaultLoadPlanner)
+            self.assertFalse(planner.allow_partial_load)
+        finally:
+            shutil.rmtree(temp_dir)
+
+    @mock.patch("torch.distributed.get_rank", return_value=0)
+    @mock.patch("torchtitan.components.checkpoint.dcp.load")
+    @mock.patch("torchtitan.components.checkpoint.dcp.save")
+    def test_load_uses_partial_planner_with_converter(
+        self, mock_save, mock_load, mock_rank
+    ):
+        """With converter keys on the model, dcp.load is called with allow_partial_load=True."""
+        temp_dir = tempfile.mkdtemp()
+        try:
+            model = nn.Linear(2, 2)
+            object.__setattr__(
+                model, "converter_key_filter", lambda key: ".lora_a." in key
+            )
+            manager = self._create_manager(mock_save, mock_load, model, temp_dir)
+            manager.save(curr_step=1)
+            manager.load(step=1)
+
+            _, kwargs = mock_load.call_args
+            planner = kwargs.get("planner")
+            self.assertIsInstance(planner, DefaultLoadPlanner)
+            self.assertTrue(planner.allow_partial_load)
+        finally:
+            shutil.rmtree(temp_dir)
+
+
 if __name__ == "__main__":
     unittest.main()
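
The two tests fix the contract but the diff does not show the implementation behind it. A sketch of the selection logic the tests imply, with has_converter_keys and the converter_key_filter attribute inferred from the assertions above rather than taken from torchtitan source:

    from torch.distributed.checkpoint.default_planner import DefaultLoadPlanner

    def _make_load_planner(model_parts) -> DefaultLoadPlanner:
        # Hypothetical: a model part "has converter keys" when it carries a
        # converter_key_filter marking parameters (e.g. LoRA weights) that a
        # converter produces instead of the checkpoint storing them.
        has_converter_keys = any(
            getattr(m, "converter_key_filter", None) is not None
            for m in model_parts
        )
        # Converter-produced keys will be absent from the checkpoint, so the
        # load must tolerate partial coverage; otherwise stay strict.
        return DefaultLoadPlanner(allow_partial_load=has_converter_keys)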