Skip to content

Commit 9984243

Browse files
JKSenthilfacebook-github-bot
authored andcommitted
fix flaky distributed tests with barriers (#963)
Summary: Pull Request resolved: #963 Reviewed By: anshulverma, diego-urgell Differential Revision: D68466282 fbshipit-source-id: 836ba686237f243823410ae5b60242fedd705a62
1 parent edf6f85 commit 9984243

File tree

6 files changed

+10
-0
lines changed

6 files changed

+10
-0
lines changed

tests/framework/callbacks/test_base_checkpointer.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -659,6 +659,7 @@ def _directory_sync_collective() -> None:
659659
tc.assertTrue("tmp" in dirpath)
660660
tc.assertFalse("foo" in dirpath)
661661
finally:
662+
dist.barrier() # avoid race condition
662663
if get_global_rank() == 0:
663664
shutil.rmtree(temp_dir) # delete temp directory
664665

tests/framework/callbacks/test_dcp_saver.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from unittest.mock import MagicMock, patch
1818

1919
import torch
20+
import torch.distributed as dist
2021
from pyre_extensions import none_throws
2122
from torch import nn
2223
from torch.distributed.checkpoint import FileSystemReader, FileSystemWriter
@@ -309,6 +310,7 @@ def _save_restore_ddp() -> None:
309310
tc, my_new_unit.module.state_dict(), my_unit.module.state_dict()
310311
)
311312
finally:
313+
dist.barrier() # avoid race condition
312314
if get_global_rank() == 0:
313315
shutil.rmtree(temp_dir) # delete temp directory
314316

tests/framework/callbacks/test_dcp_saver_gpu.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,7 @@ def _save_restore_fsdp() -> None:
116116
my_new_unit.optimizer.state_dict(), my_unit.optimizer.state_dict()
117117
)
118118
finally:
119+
dist.barrier() # avoid race condition
119120
if get_global_rank() == 0:
120121
shutil.rmtree(temp_dir) # delete temp directory
121122

@@ -165,5 +166,6 @@ def _save_restore_fsdp_with_id() -> None:
165166
my_new_unit.optimizer.state_dict(), my_unit.optimizer.state_dict()
166167
)
167168
finally:
169+
dist.barrier() # avoid race condition
168170
if get_global_rank() == 0:
169171
shutil.rmtree(temp_dir) # delete temp directory

tests/framework/callbacks/test_torchsnapshot_saver.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from unittest.mock import MagicMock, patch
1818

1919
import torch
20+
import torch.distributed as dist
2021
from torch import nn
2122
from torch.utils.data import DataLoader
2223
from torchsnapshot.test_utils import assert_state_dict_eq, check_state_dict_eq
@@ -322,6 +323,7 @@ def _save_restore_ddp() -> None:
322323
tc, my_new_unit.module.state_dict(), my_unit.module.state_dict()
323324
)
324325
finally:
326+
dist.barrier() # avoid race condition
325327
if get_global_rank() == 0:
326328
shutil.rmtree(temp_dir) # delete temp directory
327329

tests/framework/callbacks/test_torchsnapshot_saver_gpu.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import unittest
1414

1515
import torch
16+
import torch.distributed as dist
1617
from torchtnt.framework._test_utils import DummyAutoUnit, generate_random_dataloader
1718
from torchtnt.framework.callbacks.torchsnapshot_saver import TorchSnapshotSaver
1819
from torchtnt.framework.train import train
@@ -68,5 +69,6 @@ def _save_restore_fsdp() -> None:
6869
my_new_unit.optimizer.state_dict(), my_unit.optimizer.state_dict()
6970
)
7071
finally:
72+
dist.barrier() # avoid race condition
7173
if get_global_rank() == 0:
7274
shutil.rmtree(temp_dir) # delete temp directory

tests/utils/test_checkpoint.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1324,6 +1324,7 @@ def create_tmp_dir() -> str:
13241324
get_checkpoint_dirpaths(temp_dir, metadata_fname=".metadata"), []
13251325
)
13261326
finally:
1327+
dist.barrier() # avoid race condition
13271328
if get_global_rank() == 0:
13281329
shutil.rmtree(temp_dir) # delete temp directory
13291330

0 commit comments

Comments
 (0)