Skip to content

Commit 3345ed9

Browse files
diego-urgellfacebook-github-bot
authored andcommitted
Recover logging dead code in DCP _async_save (#889)
Summary: Pull Request resolved: #889 Reviewed By: JKSenthil Differential Revision: D61739504 fbshipit-source-id: 1f71578ffd7c29dcbb1b74d8e051e13d5d458fba
1 parent a6d3d91 commit 3345ed9

File tree

1 file changed

+43
-23
lines changed

1 file changed

+43
-23
lines changed

torchtnt/framework/callbacks/dcp_saver.py

Lines changed: 43 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
import torch
1616
import torch.distributed as dist
17+
from pyre_extensions import none_throws
1718
from torch.distributed import checkpoint as dcp
1819
from torch.distributed.checkpoint.default_planner import (
1920
DefaultLoadPlanner,
@@ -160,7 +161,7 @@ def _checkpoint_impl(
160161
checkpoint_id, app_state, planner, storage_writer
161162
)
162163
if curr_snapshot_wait:
163-
self._wait()
164+
self._wait(log_warning=False)
164165
else:
165166
with get_timing_context(state, f"{self.__class__.__name__}.save"):
166167
checkpoint_success = self._save(
@@ -169,9 +170,42 @@ def _checkpoint_impl(
169170

170171
return checkpoint_success
171172

172-
def _wait(self) -> None:
173-
if self._prev_snapshot is not None:
174-
self._prev_snapshot.result()
173+
def _wait(self, log_warning: bool = True) -> None:
174+
"""
175+
If the previous async checkpoint is still running, wait for it to finish before continuing. Otherwise,
176+
distributed collectives that use the checkpointing process group will result in a stuck job. This also
177+
computes and logs the time spent waiting on the previous checkpoint to finish, and a toggable warning
178+
for the user to modify checkpointing frequency.
179+
180+
If the previous checkpoing has already finished, this is a no-op.
181+
182+
Args:
183+
log_warning: Toggle for logging a warning to the user to modify checkpointing frequency. Sometimes
184+
this is not up to the user (e.g. on_exception, on_train_end).
185+
"""
186+
if self._prev_snapshot is None:
187+
return
188+
189+
if self._prev_snapshot.done():
190+
none_throws(self._prev_snapshot).result()
191+
return
192+
193+
if log_warning:
194+
rank_zero_warn(
195+
(
196+
"Waiting on previous checkpoint to finish... Consider modifying checkpointing "
197+
f"frequency if this is an issue. Current value (current {self._save_every_n_train_steps})"
198+
),
199+
logger=logger,
200+
)
201+
202+
t0 = time.monotonic()
203+
none_throws(self._prev_snapshot).result()
204+
205+
rank_zero_warn(
206+
f"Waiting on previous checkpoint for {time.monotonic()-t0:.3f} seconds",
207+
logger=logger,
208+
)
175209

176210
def _async_save(
177211
self,
@@ -187,24 +221,8 @@ def _async_save(
187221
if storage_writer is None:
188222
storage_writer = Writer(checkpoint_id, **self.default_writer_options)
189223

190-
if self._prev_snapshot is not None:
191-
if not self._prev_snapshot.done():
192-
# TODO this is unreachable at this point, since we are waiting on other functions called before _checkpoint_impl.
193-
rank_zero_warn(
194-
(
195-
"Waiting on previous checkpoint to finish... Consider modifying checkpointing "
196-
f"frequency if this is an issue. Current value (current {self._save_every_n_train_steps})"
197-
),
198-
logger=logger,
199-
)
200-
t0 = time.monotonic()
201-
self._wait()
202-
rank_zero_warn(
203-
f"Waiting on previous checkpoint for {time.monotonic()-t0:.3f} seconds",
204-
logger=logger,
205-
)
206-
else:
207-
self._wait()
224+
# Redundant check for safety
225+
self._wait(log_warning=True)
208226

209227
self._prev_snapshot = dcp.async_save(
210228
state_dict={"app_state": MultiStateful(app_state)},
@@ -257,7 +275,8 @@ def on_exception(
257275
unit: Union[TTrainUnit, TEvalUnit, TPredictUnit],
258276
exc: BaseException,
259277
) -> None:
260-
self._wait()
278+
rank_zero_info("Ensuring previous async checkpoint finished before exiting.")
279+
self._wait(log_warning=False)
261280

262281
@staticmethod
263282
def restore(
@@ -404,6 +423,7 @@ def _generate_checkpoint_and_upkeep(
404423
# operations in the base class use the process group. So wait here instead.
405424
self._wait()
406425

426+
# Note that every async checkpoint will be completed at this point.
407427
return super()._generate_checkpoint_and_upkeep(state, unit, hook)
408428

409429
@property

0 commit comments

Comments
 (0)