@@ -54,10 +65,12 @@ class DistributedCheckpointSaver(BaseCheckpointer):
         keep_last_n_checkpoints: Number of most recent checkpoints to keep. If None, all checkpoints are kept. If an excess of existing checkpoints are present, the oldest ones will be deleted to clean the difference. If best checkpoint config is enabled, this param will manage the top n checkpoints instead.
         best_checkpoint_config: Configuration for saving the best checkpoint based on a monitored metric. The metric is read off the attribute of the unit prior to checkpoint.
         process_group: The process group on which the ranks will communicate on. default: ``None`` (the entire world)
+        async_checkpoint: Whether to perform asynchronous checkpointing. Default: ``True``.
+        knob_options: Additional keyword options for StorageWriter. <https://pytorch.org/docs/stable/distributed.checkpoint.html#torch.distributed.checkpoint.StorageWriter/>

     Note:
-        If torch.distributed is available and default process group is initialized, dcp's `no_dist <https://pytorch.org/docs/stable/distributed.checkpoint.html#torch.distributed.checkpoint.load_state_dict/>`_
-        argument is automatically set to False. Otherwise it's set to True.
+        If torch.distributed is available and a process group is initialized, dcp assumes the intention is to save/load checkpoints in distributed fashion.
+        Additionally, a gloo process group must be initialized for async_checkpoint. For workloads that require nccl, the recommended initialization is 'cpu:gloo,cuda:nccl'

     Note:
         If checkpointing FSDP model, you can set state_dict type calling `set_state_dict_type <https://pytorch.org/docs/stable/fsdp.html#torch.distributed.fsdp.FullyShardedDataParallel.set_state_dict_type>`_ prior to starting training.
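The new note about async_checkpoint is easiest to see in context. A minimal sketch, assuming the constructor takes an output directory as its first argument plus the keyword arguments documented in this diff; the directory path and retention count are placeholder values, not part of the PR:

import torch.distributed as dist
from torchtnt.framework.callbacks import DistributedCheckpointSaver

# Register both backends so async checkpointing (which needs gloo) and GPU
# collectives (nccl) can share one default process group, per the note above.
dist.init_process_group(backend="cpu:gloo,cuda:nccl")

dcp_saver = DistributedCheckpointSaver(
    "/tmp/checkpoints",          # placeholder output directory
    keep_last_n_checkpoints=3,   # keep only the 3 most recent checkpoints
    async_checkpoint=True,       # relies on the gloo process group above
)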
@@ -67,6 +80,8 @@ class DistributedCheckpointSaver(BaseCheckpointer):
             appropriately. For example, if logging validation accuracy, the unit must be responsible for maintaining the value and resetting it when the epoch ends.
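As a rough illustration of the "metric is read off the attribute of the unit" contract for best_checkpoint_config, a unit can expose the monitored value as a plain attribute and refresh it once per epoch. The class below is a hypothetical sketch; the metric name, helper, and unit internals are not from this PR:

from torchtnt.framework.unit import TrainUnit

class MyUnit(TrainUnit):
    def __init__(self, module, optimizer):
        super().__init__()
        self.module = module
        self.optimizer = optimizer
        # Attribute the checkpointer reads when best_checkpoint_config monitors
        # "val_accuracy" (a hypothetical metric name).
        self.val_accuracy = 0.0

    def train_step(self, state, data):
        ...  # forward/backward/optimizer step elided

    def on_train_epoch_end(self, state):
        # The unit maintains and resets the monitored value each epoch.
        self.val_accuracy = compute_val_accuracy(self.module)  # hypothetical helper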
"""Utility method to restore dcp checkpoint from a path.
128
203
@@ -133,10 +208,16 @@ def restore(
             path: Path of the snapshot to restore.
             unit: An instance of :class:`~torchtnt.framework.unit.TrainUnit`, :class:`~torchtnt.framework.unit.EvalUnit`, or :class:`~torchtnt.framework.unit.PredictUnit` containing states to restore.
             train_dataloader: An optional train dataloader to restore.
-            process_group: The process group on which the ranks will communicate on. default: ``None`` (the entire world)
+            process_group: The process group on which the ranks will communicate on. default: ``None`` (the entire world) Note:
+                If torch.distributed is available and a process group is initialized, dcp assumes the intention is to save/load checkpoints in distributed fashion.
             restore_options: Controls what to filter when restoring the state.
-            no_dist: Set to true if loading in non-distributed setting
+            knob_options: Option is kept for legacy reasons but ignored in DCP
         """
+        if knob_options is not None:
+            rank_zero_warn(
+                "Ignoring `knob_options` which was passed to DistributedCheckpointSaver.restore, but is not supported."
+            )
+
         storage_reader = FsspecReader(path)

         restore_options = restore_options or RestoreOptions()
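For reference, the new guard above only fires when a caller still passes the legacy knob_options argument; a plain restore call skips it entirely. A usage sketch, assuming restore is invoked as a class-level method with the signature shown in this hunk; the checkpoint path and unit are placeholders:

from torchtnt.framework.callbacks import DistributedCheckpointSaver

# Restores the unit's tracked state in place from a DCP checkpoint directory.
# If a legacy `knob_options` argument were passed here, it would be ignored and
# the rank-zero warning above emitted instead.
DistributedCheckpointSaver.restore("/tmp/checkpoints/epoch_2_step_100", my_unit)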
@@ -161,13 +242,37 @@ def restore(
                 "train_dataloader was passed to `restore` but no train dataloader exists in the Snapshot"