From 7be41bf77551aac660c8109067e51fff8ebd17b5 Mon Sep 17 00:00:00 2001 From: Saiteja Samudrala Date: Tue, 5 Aug 2025 23:50:30 -0700 Subject: [PATCH] Split up Simple Staging and Legacy Staging Logic (#1022) Summary: Pull Request resolved: https://github.com/pytorch/tnt/pull/1022 Currently base_async_storage_writer is quite complicated and has a bunch of branching between simple_staging and legacy staging. We are looking to introduce yet another simple staging solution. So refactor as follows: Last half, we introduced stager as an argument to DCP async_save. 1. Introduce LegacyStager and move legacy staging logic into it. 2. Introduce SimpleStager and move simple staging logic into it. 3. Based on JKs in checkpoint_dist_client, initialize the right stager to use. 4. Remove Staging Logic From BaseAsyncStorageWriter. Differential Revision: D79434609 --- torchtnt/framework/callbacks/dcp_saver.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/torchtnt/framework/callbacks/dcp_saver.py b/torchtnt/framework/callbacks/dcp_saver.py index 2629fb6c96..6b7e1ec1f9 100644 --- a/torchtnt/framework/callbacks/dcp_saver.py +++ b/torchtnt/framework/callbacks/dcp_saver.py @@ -26,6 +26,7 @@ DefaultSavePlanner, ) from torch.distributed.checkpoint.planner import LoadPlanner, SavePlanner +from torch.distributed.checkpoint.staging import AsyncStager from torchtnt.framework.state import EntryPoint try: @@ -161,6 +162,7 @@ def _checkpoint_impl( hook: str, planner: Optional[SavePlanner] = None, storage_writer: Optional[StorageWriter] = None, + stager: Optional[AsyncStager] = None, ) -> bool: if hook not in [ "on_train_step_end", @@ -199,6 +201,7 @@ def _checkpoint_impl( process_group=self._process_group, storage_writer=storage_writer, planner=planner, + async_stager=stager, ) self._prev_snapshot = cast(Future, prev_snapshot) if curr_snapshot_wait: