Skip to content

Commit 86b2598

Browse files
JKSenthilfacebook-github-bot
authored andcommitted
sync entire dirpath prior to checkpointing (#867)
Summary: Pull Request resolved: #867 Reviewed By: diego-urgell Differential Revision: D59929876 fbshipit-source-id: 542eee83b2fd69f55e14234657fdb6a8e61dc2d7
1 parent 3a0bf42 commit 86b2598

File tree

4 files changed

+74
-1
lines changed

4 files changed

+74
-1
lines changed

tests/framework/callbacks/test_base_checkpointer.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1045,6 +1045,48 @@ def test_multi_phase_e2e(self) -> None:
10451045
min_optim, [str(x) for x in metric_cb._checkpoint_manager._ckpt_paths]
10461046
)
10471047

1048+
@skip_if_not_distributed
1049+
def test_directory_path_synced(self) -> None:
1050+
spawn_multi_process(
1051+
2,
1052+
"gloo",
1053+
self._test_directory_path_synced,
1054+
)
1055+
1056+
@staticmethod
1057+
def _test_directory_path_synced() -> None:
1058+
init_from_env()
1059+
tc = unittest.TestCase()
1060+
1061+
temp_dir = tempfile.mkdtemp() if get_global_rank() == 0 else ""
1062+
bcs = BaseCheckpointSaver(
1063+
temp_dir,
1064+
save_every_n_epochs=1,
1065+
)
1066+
1067+
try:
1068+
state = get_dummy_train_state()
1069+
my_train_unit = MyTrainLossUnit()
1070+
1071+
if dist.get_rank() == 0:
1072+
my_train_unit.train_progress._num_epochs_completed = 10
1073+
else:
1074+
my_train_unit.train_progress._num_epochs_completed = 3
1075+
1076+
bcs.on_train_epoch_end(state, my_train_unit)
1077+
tc.assertEqual(len(bcs._checkpoint_manager._ckpt_paths), 1)
1078+
tc.assertEqual(
1079+
str(bcs._checkpoint_manager._ckpt_paths[0]),
1080+
os.path.join(bcs.dirpath, "epoch_10_train_step_0"),
1081+
)
1082+
tc.assertEqual(
1083+
os.listdir(bcs.dirpath),
1084+
["epoch_10_train_step_0"],
1085+
)
1086+
finally:
1087+
if get_global_rank() == 0:
1088+
shutil.rmtree(temp_dir) # delete temp directory
1089+
10481090

10491091
class MyValLossUnit(TrainUnit[Batch]):
10501092
def __init__(self) -> None:

tests/utils/test_checkpoint.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -608,6 +608,32 @@ def test_generate_checkpoint_path(self) -> None:
608608
):
609609
ckpt_manager.generate_checkpoint_path(1, 2, MetricData("val_loss", 3.5))
610610

611+
@skip_if_not_distributed
612+
def test_generate_checkpoint_path_distributed(self) -> None:
613+
spawn_multi_process(
614+
world_size=2,
615+
backend="gloo",
616+
method=self._test_generate_checkpoint_path_distributed,
617+
)
618+
619+
@staticmethod
620+
def _test_generate_checkpoint_path_distributed() -> None:
621+
tc = unittest.TestCase()
622+
623+
init_from_env()
624+
625+
ckpt_manager = CheckpointManager("foo")
626+
627+
if dist.get_rank() == 0:
628+
path = ckpt_manager.generate_checkpoint_path(1, 1).path
629+
else:
630+
path = ckpt_manager.generate_checkpoint_path(3, 41).path
631+
632+
tc.assertEqual(
633+
path,
634+
"foo/epoch_1_step_1",
635+
)
636+
611637
def test_append_checkpoint_by_recency(self) -> None:
612638
ckpt_manager = CheckpointManager("foo", keep_last_n_checkpoints=3)
613639
ckpt_manager._ckpt_paths = [CheckpointPath("foo", 0, 0)]

torchtnt/framework/callbacks/base_checkpointer.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,10 @@ def _generate_checkpoint_and_upkeep(
171171
)
172172

173173
checkpoint_path = self._checkpoint_manager.generate_checkpoint_path(
174-
epoch, step_mapping, metric_data
174+
epoch,
175+
step_mapping,
176+
metric_data,
177+
process_group=self._process_group,
175178
)
176179

177180
# 2) Determine if we should save checkpoint

torchtnt/utils/checkpoint.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -400,11 +400,13 @@ def prune_surplus_checkpoints(self) -> None:
400400
for _ in range(len(self._ckpt_paths) - keep_last_n_checkpoints):
401401
self.remove_checkpoint()
402402

403+
@rank_zero_read_and_broadcast
403404
def generate_checkpoint_path(
404405
self,
405406
epoch: int,
406407
step: Union[int, Dict[Phase, int]],
407408
metric_data: Optional[MetricData] = None,
409+
process_group: Optional[dist.ProcessGroup] = None,
408410
) -> CheckpointPath:
409411
"""
410412
Given the current epoch, step, and possibly a metric_data value, determine the path

0 commit comments

Comments
 (0)