Skip to content

Commit e4e7a9d

Browse files
galrotem authored and facebook-github-bot committed
dcp checkpointer - ensure no distributed collectives while checkpoint is ongoing (#870)
Summary: Pull Request resolved: #870 Reviewed By: saumishr, diego-urgell Differential Revision: D60174864 fbshipit-source-id: 69d15c0c889b766aae540c8dfeb5642d7e3ea339
1 parent 745f5cb commit e4e7a9d

File tree

2 files changed

+40
-3
lines changed

2 files changed

+40
-3
lines changed

torchtnt/framework/callbacks/base_checkpointer.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import abc
1010
import logging
1111
from datetime import timedelta
12-
from typing import Any, cast, Iterable, Literal, Optional, Union
12+
from typing import Any, cast, Dict, Iterable, Literal, Optional, Union
1313

1414
import torch.distributed as dist
1515
from pyre_extensions import none_throws
@@ -170,7 +170,7 @@ def _generate_checkpoint_and_upkeep(
170170
value=metric_value,
171171
)
172172

173-
checkpoint_path = self._checkpoint_manager.generate_checkpoint_path(
173+
checkpoint_path = self._generate_checkpoint_path(
174174
epoch,
175175
step_mapping,
176176
metric_data,
@@ -225,6 +225,20 @@ def _does_checkpoint_exist(
225225
checkpoint_path, process_group
226226
)
227227

228+
def _generate_checkpoint_path(
229+
self,
230+
epoch: int,
231+
step_mapping: Union[int, Dict[Phase, int]],
232+
metric_data: Optional[MetricData] = None,
233+
process_group: Optional[dist.ProcessGroup] = None,
234+
) -> CheckpointPath:
235+
return self._checkpoint_manager.generate_checkpoint_path(
236+
epoch,
237+
step_mapping,
238+
metric_data,
239+
process_group=process_group,
240+
)
241+
228242
def _get_tracked_metric_value(
229243
self, unit: Union[TTrainUnit, TEvalUnit]
230244
) -> Optional[float]:

torchtnt/framework/callbacks/dcp_saver.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,12 @@
3838
TTrainUnit,
3939
)
4040
from torchtnt.framework.utils import get_timing_context
41-
from torchtnt.utils.checkpoint import BestCheckpointConfig, CheckpointPath
41+
from torchtnt.utils.checkpoint import (
42+
BestCheckpointConfig,
43+
CheckpointPath,
44+
MetricData,
45+
Phase,
46+
)
4247
from torchtnt.utils.optimizer import init_optim_state
4348
from torchtnt.utils.rank_zero_log import rank_zero_info, rank_zero_warn
4449
from torchtnt.utils.stateful import MultiStateful, Stateful
@@ -385,6 +390,24 @@ def _does_checkpoint_exist(
385390
checkpoint_path=checkpoint_path, process_group=process_group
386391
)
387392

393+
def _generate_checkpoint_path(
394+
self,
395+
epoch: int,
396+
step_mapping: Union[int, Dict[Phase, int]],
397+
metric_data: Optional[MetricData] = None,
398+
process_group: Optional[dist.ProcessGroup] = None,
399+
) -> CheckpointPath:
400+
# if we are still checkpointing, this might cause a collective hang.
401+
# so wait here instead
402+
self._wait()
403+
404+
return super()._generate_checkpoint_path(
405+
epoch=epoch,
406+
step_mapping=step_mapping,
407+
metric_data=metric_data,
408+
process_group=process_group,
409+
)
410+
388411
@property
389412
def default_writer_options(self) -> Dict[str, Any]:
390413
# defaults are picked to to match TSS defaults

0 commit comments

Comments (0)