Skip to content

Commit d1695d2

Browse files
Saiteja64facebook-github-bot
authored andcommitted
OSS Zero Overhead Checkpointing Implementation (#1010)
Summary: Pull Request resolved: #1010 X-link: pytorch/pytorch#156207 This diff updates DCP driver code/APIs to support Zero Overhead Checkpointing Reviewed By: diego-urgell Differential Revision: D72391401 fbshipit-source-id: 3e872e0faa9558a61e6bcf9026dab001f4f52f0b
1 parent fa02938 commit d1695d2

File tree

2 files changed

+2
-2
lines changed

2 files changed

+2
-2
lines changed

torchtnt/framework/callbacks/dcp_saver.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -193,13 +193,14 @@ def _checkpoint_impl(
193193
with get_timing_context(state, f"{self.__class__.__name__}.async_save"):
194194
# Redundant check for safety
195195
self._wait(log_warning=True)
196-
self._prev_snapshot = dcp.async_save(
196+
prev_snapshot = dcp.async_save(
197197
state_dict={"app_state": MultiStateful(app_state)},
198198
checkpoint_id=checkpoint_id,
199199
process_group=self._process_group,
200200
storage_writer=storage_writer,
201201
planner=planner,
202202
)
203+
self._prev_snapshot = cast(Future, prev_snapshot)
203204
if curr_snapshot_wait:
204205
self._wait(log_warning=False)
205206
else:

torchtnt/utils/prepare_module.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
from typing import (
1414
Any,
1515
Callable,
16-
cast,
1716
Collection,
1817
ContextManager,
1918
Dict,

0 commit comments

Comments
 (0)