@@ -20,7 +20,6 @@
 from torch import nn
 from torch.utils.data import DataLoader
 from torchsnapshot.test_utils import assert_state_dict_eq, check_state_dict_eq
-
 from torchtnt.framework._test_utils import (
     DummyAutoUnit,
     DummyTrainUnit,
@@ -31,7 +30,7 @@
 from torchtnt.framework.train import train
 from torchtnt.utils.distributed import get_global_rank, spawn_multi_process
 from torchtnt.utils.env import seed
-from torchtnt.utils.test_utils import skip_if_not_distributed, skip_if_not_gpu
+from torchtnt.utils.test_utils import skip_if_not_distributed


 class DistributedCheckpointSaverTest(unittest.TestCase):
@@ -222,55 +221,6 @@ def test_save_restore_no_lr_scheduler_restore(
         app_state = mock_dist_cp.load_state_dict.call_args.args[0]["app_state"]
         self.assertIn("lr_scheduler", app_state)

-    @skip_if_not_distributed
-    @skip_if_not_gpu
-    def test_save_restore_fsdp(self) -> None:
-        spawn_multi_process(
-            2,
-            "nccl",
-            self._save_restore_fsdp,
-        )
-
-    @staticmethod
-    def _save_restore_fsdp() -> None:
-        input_dim = 2
-        dataset_len = 10
-        batch_size = 2
-        max_epochs = 2
-        save_every_n_epochs = 1
-
-        my_unit = DummyAutoUnit(module=torch.nn.Linear(input_dim, 2), strategy="fsdp")
-        dataloader = generate_random_dataloader(dataset_len, input_dim, batch_size)
-        if get_global_rank() == 0:
-            temp_dir = tempfile.mkdtemp()
-        else:
-            temp_dir = ""
-
-        dcp_cb = DistributedCheckpointSaver(
-            temp_dir,
-            save_every_n_epochs=save_every_n_epochs,
-        )
-        temp_dir = dcp_cb.dirpath
-        train(my_unit, dataloader, max_epochs=max_epochs, callbacks=[dcp_cb])
-
-        tc = unittest.TestCase()
-        try:
-            my_new_unit = DummyAutoUnit(
-                module=torch.nn.Linear(input_dim, 2), strategy="fsdp"
-            )
-            tc.assertNotEqual(
-                my_new_unit.optimizer.state_dict(), my_unit.optimizer.state_dict()
-            )
-            # get latest checkpoint
-            ckpt_path = os.path.join(temp_dir, f"epoch_{max_epochs}_step_10")
-            dcp_cb.restore(ckpt_path, my_new_unit)
-            tc.assertEqual(
-                my_new_unit.optimizer.state_dict(), my_unit.optimizer.state_dict()
-            )
-        finally:
-            if get_global_rank() == 0:
-                shutil.rmtree(temp_dir)  # delete temp directory
-
     @skip_if_not_distributed
     def test_save_restore_ddp(self) -> None:
         spawn_multi_process(
|