 import torch
 import torch.distributed as dist
 from torch.distributed import checkpoint as dcp
+from torch.distributed.checkpoint.default_planner import DefaultSavePlanner
+from torch.distributed.checkpoint.planner import SavePlanner
+from torch.distributed.checkpoint.storage import StorageWriter
 
 from torchtnt.framework.callbacks._checkpoint_utils import (
     _prepare_app_state_for_checkpoint,
@@ -127,6 +130,8 @@ def _checkpoint_impl(
         *,
         checkpoint_path: str,
         hook: str,
+        planner: Optional[SavePlanner] = None,
+        storage_writer: Optional[StorageWriter] = None,
     ) -> bool:
         if hook not in ["on_train_step_end", "on_train_epoch_end", "on_train_end"]:
             raise RuntimeError(f"Unexpected hook encountered '{hook}'")
@@ -142,20 +147,36 @@ def _checkpoint_impl(
                 # since this is async checkpointed, so in
                 # future, add logic to set successful flag
                 # only when checkpoint is fully written
-                checkpoint_success = self._async_save(checkpoint_path, app_state)
+                checkpoint_success = self._async_save(
+                    checkpoint_path, app_state, planner, storage_writer
+                )
                 if curr_snapshot_wait:
                     self._wait()
         else:
             with get_timing_context(state, f"{self.__class__.__name__}.save"):
-                checkpoint_success = self._save(checkpoint_path, app_state)
+                checkpoint_success = self._save(
+                    checkpoint_path, app_state, planner, storage_writer
+                )
 
         return checkpoint_success
 
     def _wait(self) -> None:
         if self._prev_snapshot is not None:
             self._prev_snapshot.result()
 
-    def _async_save(self, checkpoint_id: str, app_state: Dict[str, Stateful]) -> bool:
+    def _async_save(
+        self,
+        checkpoint_id: str,
+        app_state: Dict[str, Stateful],
+        planner: Optional[SavePlanner] = None,
+        storage_writer: Optional[StorageWriter] = None,
+    ) -> bool:
+
+        if planner is None:
+            planner = DefaultSavePlanner()
+
+        if storage_writer is None:
+            storage_writer = Writer(checkpoint_id, **self.default_writer_options)
 
         if self._prev_snapshot is not None:
             if not self._prev_snapshot.done():
@@ -177,24 +198,42 @@ def _async_save(self, checkpoint_id: str, app_state: Dict[str, Stateful]) -> bool:
 
         self._prev_snapshot = dcp.async_save(
             state_dict={"app_state": MultiStateful(app_state)},
+            checkpoint_id=checkpoint_id,
             process_group=self._process_group,
-            storage_writer=Writer(checkpoint_id, **self.default_writer_options),
+            storage_writer=storage_writer,
+            planner=planner,
         )
 
         return True
 
-    def _save(self, checkpoint_id: str, app_state: Dict[str, Stateful]) -> bool:
+    def _save(
+        self,
+        checkpoint_id: str,
+        app_state: Dict[str, Stateful],
+        planner: Optional[SavePlanner] = None,
+        storage_writer: Optional[StorageWriter] = None,
+    ) -> bool:
+        # Initialize DefaultSavePlanner and FsspecWriter if not provided
+        if planner is None:
+            planner = DefaultSavePlanner()
+
+        if storage_writer is None:
+            storage_writer = Writer(checkpoint_id, **self.default_writer_options)
+
         try:
             dcp.save(
                 state_dict={"app_state": MultiStateful(app_state)},
+                checkpoint_id=checkpoint_id,
                 process_group=self._process_group,
-                storage_writer=Writer(checkpoint_id, **self.default_writer_options),
+                storage_writer=storage_writer,
+                planner=planner,
             )
         except AttributeError:
             dcp.save_state_dict(
                 state_dict={"app_state": MultiStateful(app_state)},
                 process_group=self._process_group,
-                storage_writer=Writer(checkpoint_id, **self.default_writer_options),
+                storage_writer=storage_writer,
+                planner=planner,
             )
 
         return True
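
For orientation (not part of this diff): when no planner or storage_writer is supplied, the save path above reduces to a plain torch.distributed.checkpoint save with the default planner and a writer rooted at the checkpoint directory. A minimal standalone sketch of that call, using the built-in FileSystemWriter in place of the fsspec-based Writer configured via default_writer_options (the path and toy state are illustrative):

# Standalone sketch (not from this PR): roughly what _save does with its defaults.
# FileSystemWriter stands in for the fsspec-based Writer; the path is hypothetical.
import torch
import torch.distributed.checkpoint as dcp
from torch.distributed.checkpoint import FileSystemWriter
from torch.distributed.checkpoint.default_planner import DefaultSavePlanner

checkpoint_id = "/tmp/my_checkpoint"  # hypothetical checkpoint directory
app_state = {"model": torch.nn.Linear(4, 4).state_dict()}  # toy stand-in for the app state

dcp.save(
    state_dict={"app_state": app_state},
    checkpoint_id=checkpoint_id,
    storage_writer=FileSystemWriter(checkpoint_id),
    planner=DefaultSavePlanner(),
)
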
@@ -229,13 +268,8 @@ def restore(
             process_group: The process group on which the ranks will communicate on. default: ``None`` (the entire world) Note:
                 If torch.distributed is available and a process group is initialized, dcp assumes the intention is to save/load checkpoints in distributed fashion.
             restore_options: Controls what to filter when restoring the state.
-            knob_options: Option is kept for legacy reasons but ignored in DCP
+            knob_options: Additional keyword options for StorageWriter and StorageReader
         """
-        if knob_options is not None:
-            rank_zero_warn(
-                "Ignoring `knob_options` which was passed to DistributedCheckpointSaver.restore, but is not supported."
-            )
-
         storage_reader = Reader(path)
 
         restore_options = restore_options or RestoreOptions()
@@ -250,6 +284,7 @@ def restore(
         # request to restore the dataloader state only if
         # the persisted snapshot state includes the dataloader entry
         metadata = storage_reader.read_metadata()
+
         for key in metadata.state_dict_metadata.keys():
             if _TRAIN_DL_STATE_KEY in key:
                 app_state[_TRAIN_DL_STATE_KEY] = train_dataloader
@@ -272,6 +307,7 @@ def restore(
         try:
             dcp.load(
                 {"app_state": MultiStateful(app_state)},
+                checkpoint_id=path,
                 storage_reader=storage_reader,
                 process_group=process_group,
             )
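
On the load side, the checkpoint_id now forwarded by restore maps onto the checkpoint_id argument of dcp.load. A hedged, self-contained sketch using the built-in FileSystemReader and the in-place load pattern from the DCP docs (path and toy module are illustrative, mirroring the save sketch above):

# Standalone sketch (not from this PR): loading back the checkpoint written above.
import torch
import torch.distributed.checkpoint as dcp
from torch.distributed.checkpoint import FileSystemReader

path = "/tmp/my_checkpoint"  # hypothetical checkpoint directory
model = torch.nn.Linear(4, 4)
state_dict = {"app_state": {"model": model.state_dict()}}

# dcp.load restores tensors in place into the provided state_dict.
dcp.load(
    state_dict,
    checkpoint_id=path,
    storage_reader=FileSystemReader(path),
)
model.load_state_dict(state_dict["app_state"]["model"])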