@@ -137,26 +137,38 @@ def _checkpoint_impl(
137
137
intra_epoch = hook == "on_train_step_end"
138
138
curr_snapshot_wait = hook == "on_train_end"
139
139
140
+ if planner is None :
141
+ planner = DefaultSavePlanner ()
142
+
143
+ if storage_writer is None :
144
+ storage_writer = Writer (checkpoint_id , ** self .default_writer_options )
145
+
140
146
app_state = _prepare_app_state_for_checkpoint (state , unit , intra_epoch )
141
147
# TODO: evaluate whether we need to implement the equivalent of torchsnapshot.RNGState()
142
148
if self ._async_checkpoint :
143
149
with get_timing_context (state , f"{ self .__class__ .__name__ } .async_save" ):
144
- # TODO checkpoint is not truly successful
145
- # since this is async checkpointed, so in
146
- # future, add logic to set successful flag
147
- # only when checkpoint is fully written
148
- checkpoint_success = self ._async_save (
149
- checkpoint_id , app_state , planner , storage_writer
150
+ # Redundant check for safety
151
+ self ._wait (log_warning = True )
152
+ self ._prev_snapshot = dcp .async_save (
153
+ state_dict = {"app_state" : MultiStateful (app_state )},
154
+ checkpoint_id = checkpoint_id ,
155
+ process_group = self ._process_group ,
156
+ storage_writer = storage_writer ,
157
+ planner = planner ,
150
158
)
151
159
if curr_snapshot_wait :
152
160
self ._wait (log_warning = False )
153
161
else :
154
162
with get_timing_context (state , f"{ self .__class__ .__name__ } .save" ):
155
- checkpoint_success = self ._save (
156
- checkpoint_id , app_state , planner , storage_writer
163
+ dcp .save (
164
+ state_dict = {"app_state" : MultiStateful (app_state )},
165
+ checkpoint_id = checkpoint_id ,
166
+ process_group = self ._process_group ,
167
+ storage_writer = storage_writer ,
168
+ planner = planner ,
157
169
)
158
170
159
- return checkpoint_success
171
+ return True
160
172
161
173
def _wait (self , log_warning : bool = True ) -> None :
162
174
"""
@@ -195,57 +207,6 @@ def _wait(self, log_warning: bool = True) -> None:
195
207
logger = logger ,
196
208
)
197
209
198
- def _async_save (
199
- self ,
200
- checkpoint_id : str ,
201
- app_state : Dict [str , Stateful ],
202
- planner : Optional [SavePlanner ] = None ,
203
- storage_writer : Optional [StorageWriter ] = None ,
204
- ) -> bool :
205
-
206
- if planner is None :
207
- planner = DefaultSavePlanner ()
208
-
209
- if storage_writer is None :
210
- storage_writer = Writer (checkpoint_id , ** self .default_writer_options )
211
-
212
- # Redundant check for safety
213
- self ._wait (log_warning = True )
214
-
215
- self ._prev_snapshot = dcp .async_save (
216
- state_dict = {"app_state" : MultiStateful (app_state )},
217
- checkpoint_id = checkpoint_id ,
218
- process_group = self ._process_group ,
219
- storage_writer = storage_writer ,
220
- planner = planner ,
221
- )
222
-
223
- return True
224
-
225
- def _save (
226
- self ,
227
- checkpoint_id : str ,
228
- app_state : Dict [str , Stateful ],
229
- planner : Optional [SavePlanner ] = None ,
230
- storage_writer : Optional [StorageWriter ] = None ,
231
- ) -> bool :
232
- # Initialize DefaultSavePlanner and FsspecWriter if not provided
233
- if planner is None :
234
- planner = DefaultSavePlanner ()
235
-
236
- if storage_writer is None :
237
- storage_writer = Writer (checkpoint_id , ** self .default_writer_options )
238
-
239
- dcp .save (
240
- state_dict = {"app_state" : MultiStateful (app_state )},
241
- checkpoint_id = checkpoint_id ,
242
- process_group = self ._process_group ,
243
- storage_writer = storage_writer ,
244
- planner = planner ,
245
- )
246
-
247
- return True
248
-
249
210
def on_exception (
250
211
self ,
251
212
state : State ,
0 commit comments