Skip to content

Commit bc2bf15

Browse files
diego-urgellfacebook-github-bot
authored andcommitted
Move DCP Saver GPU test to a dedicated test file (#752)
Summary: Pull Request resolved: #752 Reviewed By: JKSenthil Differential Revision: D55151767 fbshipit-source-id: 127c2b2e4c5a5b086ad45534d1c6ef2f97493c34
1 parent e6b7933 commit bc2bf15

File tree

2 files changed

+73
-51
lines changed

2 files changed

+73
-51
lines changed

tests/framework/callbacks/test_dcp_saver.py

Lines changed: 1 addition & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
from torch import nn
2121
from torch.utils.data import DataLoader
2222
from torchsnapshot.test_utils import assert_state_dict_eq, check_state_dict_eq
23-
2423
from torchtnt.framework._test_utils import (
2524
DummyAutoUnit,
2625
DummyTrainUnit,
@@ -31,7 +30,7 @@
3130
from torchtnt.framework.train import train
3231
from torchtnt.utils.distributed import get_global_rank, spawn_multi_process
3332
from torchtnt.utils.env import seed
34-
from torchtnt.utils.test_utils import skip_if_not_distributed, skip_if_not_gpu
33+
from torchtnt.utils.test_utils import skip_if_not_distributed
3534

3635

3736
class DistributedCheckpointSaverTest(unittest.TestCase):
@@ -222,55 +221,6 @@ def test_save_restore_no_lr_scheduler_restore(
222221
app_state = mock_dist_cp.load_state_dict.call_args.args[0]["app_state"]
223222
self.assertIn("lr_scheduler", app_state)
224223

225-
@skip_if_not_distributed
226-
@skip_if_not_gpu
227-
def test_save_restore_fsdp(self) -> None:
228-
spawn_multi_process(
229-
2,
230-
"nccl",
231-
self._save_restore_fsdp,
232-
)
233-
234-
@staticmethod
235-
def _save_restore_fsdp() -> None:
236-
input_dim = 2
237-
dataset_len = 10
238-
batch_size = 2
239-
max_epochs = 2
240-
save_every_n_epochs = 1
241-
242-
my_unit = DummyAutoUnit(module=torch.nn.Linear(input_dim, 2), strategy="fsdp")
243-
dataloader = generate_random_dataloader(dataset_len, input_dim, batch_size)
244-
if get_global_rank() == 0:
245-
temp_dir = tempfile.mkdtemp()
246-
else:
247-
temp_dir = ""
248-
249-
dcp_cb = DistributedCheckpointSaver(
250-
temp_dir,
251-
save_every_n_epochs=save_every_n_epochs,
252-
)
253-
temp_dir = dcp_cb.dirpath
254-
train(my_unit, dataloader, max_epochs=max_epochs, callbacks=[dcp_cb])
255-
256-
tc = unittest.TestCase()
257-
try:
258-
my_new_unit = DummyAutoUnit(
259-
module=torch.nn.Linear(input_dim, 2), strategy="fsdp"
260-
)
261-
tc.assertNotEqual(
262-
my_new_unit.optimizer.state_dict(), my_unit.optimizer.state_dict()
263-
)
264-
# get latest checkpoint
265-
ckpt_path = os.path.join(temp_dir, f"epoch_{max_epochs}_step_10")
266-
dcp_cb.restore(ckpt_path, my_new_unit)
267-
tc.assertEqual(
268-
my_new_unit.optimizer.state_dict(), my_unit.optimizer.state_dict()
269-
)
270-
finally:
271-
if get_global_rank() == 0:
272-
shutil.rmtree(temp_dir) # delete temp directory
273-
274224
@skip_if_not_distributed
275225
def test_save_restore_ddp(self) -> None:
276226
spawn_multi_process(
Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
#!/usr/bin/env python3
2+
# Copyright (c) Meta Platforms, Inc. and affiliates.
3+
# All rights reserved.
4+
#
5+
# This source code is licensed under the BSD-style license found in the
6+
# LICENSE file in the root directory of this source tree.
7+
8+
# pyre-strict
9+
10+
import os
11+
import shutil
12+
import tempfile
13+
import unittest
14+
15+
import torch
16+
17+
from torchtnt.framework._test_utils import DummyAutoUnit, generate_random_dataloader
18+
from torchtnt.framework.callbacks.dcp_saver import DistributedCheckpointSaver
19+
from torchtnt.framework.train import train
20+
from torchtnt.utils.distributed import get_global_rank, spawn_multi_process
21+
from torchtnt.utils.test_utils import skip_if_not_distributed, skip_if_not_gpu
22+
23+
24+
class DistributedCheckpointSaverGPUTest(unittest.TestCase):
25+
@skip_if_not_distributed
26+
@skip_if_not_gpu
27+
def test_save_restore_fsdp(self) -> None:
28+
spawn_multi_process(
29+
2,
30+
"nccl",
31+
self._save_restore_fsdp,
32+
)
33+
34+
@staticmethod
35+
def _save_restore_fsdp() -> None:
36+
input_dim = 2
37+
dataset_len = 10
38+
batch_size = 2
39+
max_epochs = 2
40+
save_every_n_epochs = 1
41+
42+
my_unit = DummyAutoUnit(module=torch.nn.Linear(input_dim, 2), strategy="fsdp")
43+
dataloader = generate_random_dataloader(dataset_len, input_dim, batch_size)
44+
if get_global_rank() == 0:
45+
temp_dir = tempfile.mkdtemp()
46+
else:
47+
temp_dir = ""
48+
49+
dcp_cb = DistributedCheckpointSaver(
50+
temp_dir,
51+
save_every_n_epochs=save_every_n_epochs,
52+
)
53+
temp_dir = dcp_cb.dirpath
54+
train(my_unit, dataloader, max_epochs=max_epochs, callbacks=[dcp_cb])
55+
56+
tc = unittest.TestCase()
57+
try:
58+
my_new_unit = DummyAutoUnit(
59+
module=torch.nn.Linear(input_dim, 2), strategy="fsdp"
60+
)
61+
tc.assertNotEqual(
62+
my_new_unit.optimizer.state_dict(), my_unit.optimizer.state_dict()
63+
)
64+
# get latest checkpoint
65+
ckpt_path = os.path.join(temp_dir, f"epoch_{max_epochs}_step_10")
66+
dcp_cb.restore(ckpt_path, my_new_unit)
67+
tc.assertEqual(
68+
my_new_unit.optimizer.state_dict(), my_unit.optimizer.state_dict()
69+
)
70+
finally:
71+
if get_global_rank() == 0:
72+
shutil.rmtree(temp_dir) # delete temp directory

0 commit comments

Comments
 (0)