diff --git a/torchtnt/framework/callbacks/dcp_saver.py b/torchtnt/framework/callbacks/dcp_saver.py index 2629fb6c96..6b7e1ec1f9 100644 --- a/torchtnt/framework/callbacks/dcp_saver.py +++ b/torchtnt/framework/callbacks/dcp_saver.py @@ -26,6 +26,7 @@ DefaultSavePlanner, ) from torch.distributed.checkpoint.planner import LoadPlanner, SavePlanner +from torch.distributed.checkpoint.staging import AsyncStager from torchtnt.framework.state import EntryPoint try: @@ -161,6 +162,7 @@ def _checkpoint_impl( hook: str, planner: Optional[SavePlanner] = None, storage_writer: Optional[StorageWriter] = None, + stager: Optional[AsyncStager] = None, ) -> bool: if hook not in [ "on_train_step_end", @@ -199,6 +201,7 @@ def _checkpoint_impl( process_group=self._process_group, storage_writer=storage_writer, planner=planner, + async_stager=stager, ) self._prev_snapshot = cast(Future, prev_snapshot) if curr_snapshot_wait: