Commit d3262b8

Saiteja64 authored and facebook-github-bot committed
Split up Simple Staging and Legacy Staging Logic
Summary: base_async_storage_writer is currently quite complicated and branches heavily between simple_staging and legacy staging. We are looking to introduce yet another simple staging solution. Last half, we introduced stager as an argument to DCP async_save, so refactor as follows:

1. Introduce LegacyStager and move the legacy staging logic into it.
2. Introduce SimpleStager and move the simple staging logic into it.
3. Based on JKs in checkpoint_dist_client, initialize the right stager to use.
4. Remove the staging logic from BaseAsyncStorageWriter.

Differential Revision: D79434609
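The refactor described in the summary can be sketched roughly as follows. This is a minimal illustration, not the code from the referenced differential: the names LegacyStager and SimpleStager come from the summary, but their bodies, the `Stager` protocol, the `use_simple_staging` flag, and the `Writer` class are all hypothetical placeholders.

```python
import copy
from typing import Any, Dict, Protocol

StateDict = Dict[str, Any]


class Stager(Protocol):
    """Hypothetical stand-in for the stager interface (not the real AsyncStager)."""

    name: str

    def stage(self, state_dict: StateDict) -> StateDict: ...


class LegacyStager:
    """Class name from the commit summary; the body is a placeholder."""

    name = "legacy"

    def stage(self, state_dict: StateDict) -> StateDict:
        # Placeholder: the real legacy staging logic would live here.
        return copy.deepcopy(state_dict)


class SimpleStager:
    """Class name from the commit summary; the body is a placeholder."""

    name = "simple"

    def stage(self, state_dict: StateDict) -> StateDict:
        # Placeholder: the real simple staging logic would live here.
        return copy.deepcopy(state_dict)


def select_stager(use_simple_staging: bool) -> Stager:
    """Pick a stager from a flag (hypothetical stand-in for the JK check)."""
    return SimpleStager() if use_simple_staging else LegacyStager()


class Writer:
    """Hypothetical writer that no longer branches on the staging mode."""

    def __init__(self, stager: Stager) -> None:
        self._stager = stager

    def write(self, state_dict: StateDict) -> StateDict:
        # All staging decisions are delegated to the injected stager.
        return self._stager.stage(state_dict)
```

With this shape, adding another staging solution would mean adding one more class and one more branch in `select_stager`, while the writer stays unchanged.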
1 parent b285385 commit d3262b8

1 file changed: +3 -0 lines changed

torchtnt/framework/callbacks/dcp_saver.py

Lines changed: 3 additions & 0 deletions
@@ -26,6 +26,7 @@
     DefaultSavePlanner,
 )
 from torch.distributed.checkpoint.planner import LoadPlanner, SavePlanner
+from torch.distributed.checkpoint.staging import AsyncStager
 from torchtnt.framework.state import EntryPoint
 
 try:
@@ -161,6 +162,7 @@ def _checkpoint_impl(
         hook: str,
         planner: Optional[SavePlanner] = None,
         storage_writer: Optional[StorageWriter] = None,
+        stager: Optional[AsyncStager] = None,
     ) -> bool:
         if hook not in [
             "on_train_step_end",
@@ -199,6 +201,7 @@ def _checkpoint_impl(
                     process_group=self._process_group,
                     storage_writer=storage_writer,
                     planner=planner,
+                    async_stager=stager,
                 )
                 self._prev_snapshot = cast(Future, prev_snapshot)
                 if curr_snapshot_wait:
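The diff above only threads an optional stager through to the async save call. A minimal sketch of that pass-through, using plain dictionaries so it runs without PyTorch, might look like the following. `CopyingStager`, `fake_async_save`, and `checkpoint_impl` are hypothetical; the method names on the stager follow my understanding of the `AsyncStager` protocol and should be checked against the installed PyTorch version.

```python
import copy
from typing import Any, Dict, Optional

StateDict = Dict[str, Any]


class CopyingStager:
    """Hypothetical stager; method names assumed to mirror AsyncStager."""

    def __init__(self) -> None:
        self.staged_count = 0

    @property
    def should_synchronize_after_execute(self) -> bool:
        # Staging in this sketch finishes inside stage(), so no later
        # synchronization is requested.
        return False

    def stage(self, state_dict: StateDict) -> StateDict:
        # Copy the state so the caller can keep mutating the original.
        self.staged_count += 1
        return copy.deepcopy(state_dict)

    def synchronize_staging(self) -> None:
        # Nothing to wait for: stage() is synchronous here.
        return None


def fake_async_save(
    state_dict: StateDict, *, async_stager: Optional[CopyingStager] = None
) -> StateDict:
    """Stub standing in for the real async save call; NOT the DCP function."""
    if async_stager is None:
        # Assumed default for this sketch: stage with a plain deep copy.
        return copy.deepcopy(state_dict)
    return async_stager.stage(state_dict)


def checkpoint_impl(
    state_dict: StateDict, *, stager: Optional[CopyingStager] = None
) -> StateDict:
    # Mirrors the diff: the optional `stager` is forwarded as `async_stager`.
    return fake_async_save(state_dict, async_stager=stager)
```

Because `stager` defaults to `None`, existing callers of `checkpoint_impl` keep their current behavior, which matches the additive nature of the three-line change.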
