Skip to content

Commit a6d3d91

Browse files
diego-urgell and facebook-github-bot
authored and committed
Simplify async wait hooks in DCPSaver (#888)
Summary: Pull Request resolved: #888 Reviewed By: anshulverma Differential Revision: D61739503 fbshipit-source-id: a0e570f28769e098aa82634c559fd5e3169f551a
1 parent 3d5376e commit a6d3d91

File tree

3 files changed

+13
-62
lines changed

3 files changed

+13
-62
lines changed

torchtnt/framework/callbacks/base_checkpointer.py

Lines changed: 5 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import abc
1010
import logging
1111
from datetime import timedelta
12-
from typing import Any, cast, Dict, Iterable, List, Literal, Optional, Union
12+
from typing import Any, cast, Iterable, List, Literal, Optional, Union
1313

1414
import torch.distributed as dist
1515
from pyre_extensions import none_throws
@@ -21,7 +21,6 @@
2121
from torchtnt.utils.checkpoint import (
2222
BestCheckpointConfig,
2323
CheckpointManager,
24-
CheckpointPath,
2524
get_best_checkpoint_path,
2625
get_latest_checkpoint_path,
2726
MetricData,
@@ -172,7 +171,7 @@ def _generate_checkpoint_and_upkeep(
172171
value=metric_value,
173172
)
174173

175-
checkpoint_path = self._generate_checkpoint_path(
174+
checkpoint_path = self._checkpoint_manager.generate_checkpoint_path(
176175
epoch,
177176
step_mapping,
178177
metric_data,
@@ -185,7 +184,9 @@ def _generate_checkpoint_and_upkeep(
185184

186185
if hook == "on_train_end":
187186
# 2.1) Make sure that last checkpoint does not already exist
188-
if self._does_checkpoint_exist(checkpoint_path, self._process_group):
187+
if self._checkpoint_manager.does_checkpoint_exist(
188+
checkpoint_path, self._process_group
189+
):
189190
rank_zero_warn(
190191
"Final checkpoint already exists, skipping.", logger=logger
191192
)
@@ -220,30 +221,6 @@ def _generate_checkpoint_and_upkeep(
220221

221222
return True
222223

223-
def _does_checkpoint_exist(
224-
self,
225-
checkpoint_path: CheckpointPath,
226-
process_group: Optional[dist.ProcessGroup] = None,
227-
) -> bool:
228-
# Only keep this function as a hook for downstream checkpointer
229-
return self._checkpoint_manager.does_checkpoint_exist(
230-
checkpoint_path, process_group
231-
)
232-
233-
def _generate_checkpoint_path(
234-
self,
235-
epoch: int,
236-
step_mapping: Union[int, Dict[Phase, int]],
237-
metric_data: Optional[MetricData] = None,
238-
process_group: Optional[dist.ProcessGroup] = None,
239-
) -> CheckpointPath:
240-
return self._checkpoint_manager.generate_checkpoint_path(
241-
epoch,
242-
step_mapping,
243-
metric_data,
244-
process_group=process_group,
245-
)
246-
247224
def _get_tracked_metric_value(
248225
self, unit: Union[TTrainUnit, TEvalUnit]
249226
) -> Optional[float]:

torchtnt/framework/callbacks/dcp_saver.py

Lines changed: 7 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -39,12 +39,7 @@
3939
TTrainUnit,
4040
)
4141
from torchtnt.framework.utils import get_timing_context
42-
from torchtnt.utils.checkpoint import (
43-
BestCheckpointConfig,
44-
CheckpointPath,
45-
MetricData,
46-
Phase,
47-
)
42+
from torchtnt.utils.checkpoint import BestCheckpointConfig
4843
from torchtnt.utils.optimizer import init_optim_state
4944
from torchtnt.utils.rank_zero_log import rank_zero_info, rank_zero_warn
5045
from torchtnt.utils.stateful import MultiStateful, Stateful
@@ -194,6 +189,7 @@ def _async_save(
194189

195190
if self._prev_snapshot is not None:
196191
if not self._prev_snapshot.done():
192+
# TODO this is unreachable at this point, since we are waiting on other functions called before _checkpoint_impl.
197193
rank_zero_warn(
198194
(
199195
"Waiting on previous checkpoint to finish... Consider modifying checkpointing "
@@ -401,36 +397,14 @@ def restore_with_id(
401397
f"Restored snapshot for checkpoint_id: {checkpoint_id}", logger=logger
402398
)
403399

404-
def _does_checkpoint_exist(
405-
self,
406-
checkpoint_path: CheckpointPath,
407-
process_group: Optional[dist.ProcessGroup] = None,
400+
def _generate_checkpoint_and_upkeep(
401+
self, state: State, unit: Union[TTrainUnit, TEvalUnit], hook: str
408402
) -> bool:
409-
# if we are still checkpointing, this might cause a collective hang.
410-
# so wait here instead
403+
# if we are still checkpointing, this might cause a collective hang, since several
404+
# operations in the base class use the process group. So wait here instead.
411405
self._wait()
412406

413-
return super()._does_checkpoint_exist(
414-
checkpoint_path=checkpoint_path, process_group=process_group
415-
)
416-
417-
def _generate_checkpoint_path(
418-
self,
419-
epoch: int,
420-
step_mapping: Union[int, Dict[Phase, int]],
421-
metric_data: Optional[MetricData] = None,
422-
process_group: Optional[dist.ProcessGroup] = None,
423-
) -> CheckpointPath:
424-
# if we are still checkpointing, this might cause a collective hang.
425-
# so wait here instead
426-
self._wait()
427-
428-
return super()._generate_checkpoint_path(
429-
epoch=epoch,
430-
step_mapping=step_mapping,
431-
metric_data=metric_data,
432-
process_group=process_group,
433-
)
407+
return super()._generate_checkpoint_and_upkeep(state, unit, hook)
434408

435409
@property
436410
def default_writer_options(self) -> Dict[str, Any]:

torchtnt/utils/checkpoint.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -206,7 +206,7 @@ def _is_phase_aware(self) -> bool:
206206
def newer_than(self, other: "CheckpointPath") -> bool:
207207
"""
208208
Given another CheckpointPath instance, determine if this checkpoint is strictly newer than the other.
209-
Note that recency is determine in terms of the epoch, phase, and number of steps. It is NOT related to
209+
Note that recency is determined in terms of the epoch, phase, and number of steps. It is NOT related to
210210
the timestamp the checkpoint was saved.
211211
212212
Returns:

0 commit comments

Comments
 (0)