Skip to content

Commit b27d916

Browse files
diego-urgellfacebook-github-bot
authored andcommitted
Move torchsnapshot_saver GPU test to dedicate file (#760)
Summary: Pull Request resolved: #760 Reviewed By: galrotem Differential Revision: D55327868 fbshipit-source-id: 45e1dae2dc7ee1304a01cfe3a5b9a102dec02e15
1 parent 249eea3 commit b27d916

File tree

2 files changed

+73
-51
lines changed

2 files changed

+73
-51
lines changed

tests/framework/callbacks/test_torchsnapshot_saver.py

Lines changed: 1 addition & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636
from torchtnt.framework.train import train
3737
from torchtnt.utils.distributed import get_global_rank, spawn_multi_process
3838
from torchtnt.utils.env import seed
39-
from torchtnt.utils.test_utils import skip_if_not_distributed, skip_if_not_gpu
39+
from torchtnt.utils.test_utils import skip_if_not_distributed
4040

4141

4242
class TorchSnapshotSaverTest(unittest.TestCase):
@@ -227,56 +227,6 @@ def test_save_restore_no_lr_scheduler_restore(
227227
app_state = mock_torchsnapshot.Snapshot().restore.call_args.args[0]
228228
self.assertIn("lr_scheduler", app_state)
229229

230-
@skip_if_not_distributed
231-
@skip_if_not_gpu
232-
def test_save_restore_fsdp(self) -> None:
233-
spawn_multi_process(
234-
2,
235-
"nccl",
236-
self._save_restore_fsdp,
237-
)
238-
239-
@staticmethod
240-
def _save_restore_fsdp() -> None:
241-
input_dim = 2
242-
dataset_len = 10
243-
batch_size = 2
244-
max_epochs = 2
245-
save_every_n_epochs = 1
246-
247-
my_unit = DummyAutoUnit(module=torch.nn.Linear(input_dim, 2), strategy="fsdp")
248-
dataloader = generate_random_dataloader(dataset_len, input_dim, batch_size)
249-
if get_global_rank() == 0:
250-
temp_dir = tempfile.mkdtemp()
251-
else:
252-
temp_dir = ""
253-
254-
snapshot_cb = TorchSnapshotSaver(
255-
temp_dir,
256-
save_every_n_epochs=save_every_n_epochs,
257-
replicated=["**"],
258-
)
259-
temp_dir = snapshot_cb.dirpath
260-
train(my_unit, dataloader, max_epochs=max_epochs, callbacks=[snapshot_cb])
261-
262-
tc = unittest.TestCase()
263-
try:
264-
my_new_unit = DummyAutoUnit(
265-
module=torch.nn.Linear(input_dim, 2), strategy="fsdp"
266-
)
267-
tc.assertNotEqual(
268-
my_new_unit.optimizer.state_dict(), my_unit.optimizer.state_dict()
269-
)
270-
# get latest checkpoint
271-
ckpt_path = os.path.join(temp_dir, f"epoch_{max_epochs}_step_10")
272-
snapshot_cb.restore(ckpt_path, my_new_unit)
273-
tc.assertEqual(
274-
my_new_unit.optimizer.state_dict(), my_unit.optimizer.state_dict()
275-
)
276-
finally:
277-
if get_global_rank() == 0:
278-
shutil.rmtree(temp_dir) # delete temp directory
279-
280230
@skip_if_not_distributed
281231
def test_save_restore_ddp(self) -> None:
282232
spawn_multi_process(
Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
#!/usr/bin/env python3
2+
# Copyright (c) Meta Platforms, Inc. and affiliates.
3+
# All rights reserved.
4+
#
5+
# This source code is licensed under the BSD-style license found in the
6+
# LICENSE file in the root directory of this source tree.
7+
8+
# pyre-strict
9+
10+
import os
11+
import shutil
12+
import tempfile
13+
import unittest
14+
15+
import torch
16+
from torchtnt.framework._test_utils import DummyAutoUnit, generate_random_dataloader
17+
from torchtnt.framework.callbacks.torchsnapshot_saver import TorchSnapshotSaver
18+
from torchtnt.framework.train import train
19+
from torchtnt.utils.distributed import get_global_rank, spawn_multi_process
20+
from torchtnt.utils.test_utils import skip_if_not_distributed, skip_if_not_gpu
21+
22+
23+
class TorchSnapshotSaverGPUTest(unittest.TestCase):
24+
@skip_if_not_distributed
25+
@skip_if_not_gpu
26+
def test_save_restore_fsdp(self) -> None:
27+
spawn_multi_process(
28+
2,
29+
"nccl",
30+
self._save_restore_fsdp,
31+
)
32+
33+
@staticmethod
34+
def _save_restore_fsdp() -> None:
35+
input_dim = 2
36+
dataset_len = 10
37+
batch_size = 2
38+
max_epochs = 2
39+
save_every_n_epochs = 1
40+
41+
my_unit = DummyAutoUnit(module=torch.nn.Linear(input_dim, 2), strategy="fsdp")
42+
dataloader = generate_random_dataloader(dataset_len, input_dim, batch_size)
43+
if get_global_rank() == 0:
44+
temp_dir = tempfile.mkdtemp()
45+
else:
46+
temp_dir = ""
47+
48+
snapshot_cb = TorchSnapshotSaver(
49+
temp_dir,
50+
save_every_n_epochs=save_every_n_epochs,
51+
replicated=["**"],
52+
)
53+
temp_dir = snapshot_cb.dirpath
54+
train(my_unit, dataloader, max_epochs=max_epochs, callbacks=[snapshot_cb])
55+
56+
tc = unittest.TestCase()
57+
try:
58+
my_new_unit = DummyAutoUnit(
59+
module=torch.nn.Linear(input_dim, 2), strategy="fsdp"
60+
)
61+
tc.assertNotEqual(
62+
my_new_unit.optimizer.state_dict(), my_unit.optimizer.state_dict()
63+
)
64+
# get latest checkpoint
65+
ckpt_path = os.path.join(temp_dir, f"epoch_{max_epochs}_step_10")
66+
snapshot_cb.restore(ckpt_path, my_new_unit)
67+
tc.assertEqual(
68+
my_new_unit.optimizer.state_dict(), my_unit.optimizer.state_dict()
69+
)
70+
finally:
71+
if get_global_rank() == 0:
72+
shutil.rmtree(temp_dir) # delete temp directory

0 commit comments

Comments
 (0)