import torch
import torch.distributed as dist
+from pyre_extensions import none_throws
from torch.distributed import checkpoint as dcp
from torch.distributed.checkpoint.default_planner import (
    DefaultLoadPlanner,
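
The new `none_throws` import comes from `pyre_extensions` and is used below to narrow an `Optional` future before calling `.result()`. Roughly, it behaves like the following sketch (an illustrative equivalent, not the library's actual implementation):

```python
from typing import Optional, TypeVar

T = TypeVar("T")


def none_throws_equivalent(value: Optional[T], message: str = "Unexpected None") -> T:
    # Narrow Optional[T] to T for the type checker; raise if the value is None.
    if value is None:
        raise AssertionError(message)
    return value
```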
@@ -160,7 +161,7 @@ def _checkpoint_impl(
                    checkpoint_id, app_state, planner, storage_writer
                )
                if curr_snapshot_wait:
-                    self._wait()
+                    self._wait(log_warning=False)
        else:
            with get_timing_context(state, f"{self.__class__.__name__}.save"):
                checkpoint_success = self._save(
@@ -169,9 +170,42 @@ def _checkpoint_impl(

        return checkpoint_success

-    def _wait(self) -> None:
-        if self._prev_snapshot is not None:
-            self._prev_snapshot.result()
+    def _wait(self, log_warning: bool = True) -> None:
+        """
+        If the previous async checkpoint is still running, wait for it to finish before continuing. Otherwise,
+        distributed collectives that use the checkpointing process group will result in a stuck job. This also
+        computes and logs the time spent waiting on the previous checkpoint to finish, and a toggleable warning
+        for the user to modify checkpointing frequency.
+
+        If the previous checkpoint has already finished, this is a no-op.
+
+        Args:
+            log_warning: Toggle for logging a warning to the user to modify checkpointing frequency. Sometimes
+                this is not up to the user (e.g. on_exception, on_train_end).
+        """
+        if self._prev_snapshot is None:
+            return
+
+        if self._prev_snapshot.done():
+            none_throws(self._prev_snapshot).result()
+            return
+
+        if log_warning:
+            rank_zero_warn(
+                (
+                    "Waiting on previous checkpoint to finish... Consider modifying checkpointing "
+                    f"frequency if this is an issue. Current value (current {self._save_every_n_train_steps})"
+                ),
+                logger=logger,
+            )
+
+        t0 = time.monotonic()
+        none_throws(self._prev_snapshot).result()
+
+        rank_zero_warn(
+            f"Waiting on previous checkpoint for {time.monotonic() - t0:.3f} seconds",
+            logger=logger,
+        )

    def _async_save(
        self,
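
The new `_wait` body is essentially the "block on the previous async-save handle before touching the process group again" pattern, plus timing and an optional warning. A minimal standalone sketch of that pattern, using `concurrent.futures` as a stand-in for the DCP future and plain logging instead of `rank_zero_warn` (names here are illustrative, not the callback's API):

```python
import logging
import time
from concurrent.futures import Future, ThreadPoolExecutor
from typing import Optional

logger = logging.getLogger(__name__)


def wait_for_previous_save(prev: Optional[Future], log_warning: bool = True) -> None:
    # No checkpoint in flight: nothing to do.
    if prev is None:
        return
    # Already finished: surface any exception raised by the background save.
    if prev.done():
        prev.result()
        return
    if log_warning:
        logger.warning("Waiting on previous checkpoint to finish...")
    t0 = time.monotonic()
    prev.result()  # block until the in-flight save completes
    logger.warning("Waited on previous checkpoint for %.3f seconds", time.monotonic() - t0)


if __name__ == "__main__":
    with ThreadPoolExecutor(max_workers=1) as pool:
        fake_save = pool.submit(time.sleep, 0.5)  # stand-in for an async checkpoint
        wait_for_previous_save(fake_save)
```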
@@ -187,24 +221,8 @@ def _async_save(
        if storage_writer is None:
            storage_writer = Writer(checkpoint_id, **self.default_writer_options)

-        if self._prev_snapshot is not None:
-            if not self._prev_snapshot.done():
-                # TODO this is unreachable at this point, since we are waiting on other functions called before _checkpoint_impl.
-                rank_zero_warn(
-                    (
-                        "Waiting on previous checkpoint to finish... Consider modifying checkpointing "
-                        f"frequency if this is an issue. Current value (current {self._save_every_n_train_steps})"
-                    ),
-                    logger=logger,
-                )
-                t0 = time.monotonic()
-                self._wait()
-                rank_zero_warn(
-                    f"Waiting on previous checkpoint for {time.monotonic() - t0:.3f} seconds",
-                    logger=logger,
-                )
-            else:
-                self._wait()
+        # Redundant check for safety
+        self._wait(log_warning=True)

        self._prev_snapshot = dcp.async_save(
            state_dict={"app_state": MultiStateful(app_state)},
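
With the manual bookkeeping gone, `_async_save` reduces to: drain anything still in flight, call `dcp.async_save`, and keep the returned future as `_prev_snapshot`. A hedged sketch of that flow outside the callback (assumes a recent PyTorch where `torch.distributed.checkpoint.async_save` is available; the model, checkpoint path, and single-rank gloo setup are placeholders just to make the sketch self-contained):

```python
import torch
import torch.distributed as dist
import torch.distributed.checkpoint as dcp


def save_async(model: torch.nn.Module, step: int, prev_snapshot=None):
    # Drain any still-running save first: the async save uses the process
    # group, and overlapping collectives can deadlock the job.
    if prev_snapshot is not None:
        prev_snapshot.result()

    # Kick off the next checkpoint in the background and hand back the future
    # so the caller can wait on it before the next save (or at shutdown).
    return dcp.async_save(
        state_dict={"model": model.state_dict()},
        checkpoint_id=f"/tmp/checkpoints/step_{step}",  # placeholder path
    )


if __name__ == "__main__":
    # Single-process setup purely to make the sketch runnable.
    dist.init_process_group("gloo", init_method="tcp://localhost:29500", rank=0, world_size=1)
    snapshot = save_async(torch.nn.Linear(4, 4), step=0)
    snapshot.result()  # wait before tearing down the process group
    dist.destroy_process_group()
```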
@@ -257,7 +275,8 @@ def on_exception(
        unit: Union[TTrainUnit, TEvalUnit, TPredictUnit],
        exc: BaseException,
    ) -> None:
-        self._wait()
+        rank_zero_info("Ensuring previous async checkpoint finished before exiting.")
+        self._wait(log_warning=False)

    @staticmethod
    def restore(
@@ -404,6 +423,7 @@ def _generate_checkpoint_and_upkeep(
        # operations in the base class use the process group. So wait here instead.
        self._wait()

+        # Note that every async checkpoint will be completed at this point.
        return super()._generate_checkpoint_and_upkeep(state, unit, hook)

    @property