docs/source/checkpointing.rst
15 additions & 15 deletions
@@ -1,24 +1,24 @@
 Checkpointing
 ================================

-TorchTNT offers checkpointing via the :class:`~torchtnt.framework.callbacks.TorchSnapshotSaver` which uses `TorchSnapshot <https://pytorch.org/torchsnapshot/main/>`_ under the hood.
+TorchTNT offers checkpointing via :class:`~torchtnt.framework.callbacks.DistributedCheckpointSaver` which uses `DCP <https://github.com/pytorch/pytorch/tree/main/torch/distributed/checkpoint>`_ under the hood.

 .. code-block:: python

   module = nn.Linear(input_dim, 1)
   unit = MyUnit(module=module)
-  tss=TorchSnapshotSaver(
+  dcp=DistributedCheckpointSaver(
     dirpath=your_dirpath_here,
     save_every_n_train_steps=100,
     save_every_n_epochs=2,
   )
   # loads latest checkpoint, if it exists
   if latest_checkpoint_dir:
-    tss.restore_from_latest(your_dirpath_here, unit, train_dataloader=dataloader)
+    dcp.restore_from_latest(your_dirpath_here, unit, train_dataloader=dataloader)
   train(
     unit,
     dataloader,
-    callbacks=[tss]
+    callbacks=[dcp]
   )

 There is built-in support for saving and loading distributed models (DDP, FSDP).
@@ -37,15 +37,15 @@ The state dict type to be used for checkpointing FSDP modules can be specified i
     # checkpointer callback will use state dict type specified in FSDPStrategy
-    callbacks=[tss]
+    callbacks=[dcp]
   )

 Or you can manually set this using `FSDP.set_state_dict_type <https://pytorch.org/docs/stable/fsdp.html#torch.distributed.fsdp.FullyShardedDataParallel.set_state_dict_type>`_.
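The earlier lines of this hunk are not expanded above, so the FSDP example is only partially visible. As a point of reference, a minimal sketch of how a strategy with a checkpointing state dict type might be wired up is shown below; the FSDPStrategy field name, its accepted value, and the MyAutoUnit constructor are assumptions, not lines taken from this diff.

.. code-block:: python

   # Sketch only: state_dict_type's name/value and MyAutoUnit's signature are assumed.
   from torchtnt.utils.prepare_module import FSDPStrategy

   strategy = FSDPStrategy(state_dict_type="SHARDED_STATE_DICT")
   unit = MyAutoUnit(module=module, strategy=strategy)
   dcp = DistributedCheckpointSaver(dirpath=your_dirpath_here, save_every_n_epochs=2)
   train(
       unit,
       dataloader,
       # checkpointer callback will use state dict type specified in FSDPStrategy
       callbacks=[dcp],
   )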
@@ -56,14 +56,14 @@ Or you can manually set this using `FSDP.set_state_dict_type <https://pytorch.or
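This hunk's body is collapsed above. For reference, manually setting the state dict type through PyTorch's public FSDP API looks roughly like the following; the choice of SHARDED_STATE_DICT is purely illustrative, and `module` is assumed to already be FSDP-wrapped.

.. code-block:: python

   # module is assumed to already be wrapped with FSDP.
   from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
   from torch.distributed.fsdp import StateDictType

   FSDP.set_state_dict_type(module, StateDictType.SHARDED_STATE_DICT)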
@@ -74,15 +74,15 @@ When finetuning your models, you can pass RestoreOptions to avoid loading optimi

 .. code-block:: python

-  tss=TorchSnapshotSaver(
+  dcp=DistributedCheckpointSaver(
     dirpath=your_dirpath_here,
     save_every_n_train_steps=100,
     save_every_n_epochs=2,
   )

   # loads latest checkpoint, if it exists
   if latest_checkpoint_dir:
-    tss.restore_from_latest(
+    dcp.restore_from_latest(
       your_dirpath_here,
       your_unit,
       train_dataloader=dataloader,
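The hunk above cuts off before the RestoreOptions argument mentioned in its header. A hedged sketch of how the finetuning restore likely ends is shown below; the `restore_options` keyword, the import path, and the `restore_optimizers` field are assumptions rather than lines from this diff.

.. code-block:: python

   # Sketch only: RestoreOptions' import path and field names are assumed.
   from torchtnt.framework.callbacks import RestoreOptions

   dcp.restore_from_latest(
       your_dirpath_here,
       your_unit,
       train_dataloader=dataloader,
       restore_options=RestoreOptions(restore_optimizers=False),
   )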
@@ -99,7 +99,7 @@ Sometimes it may be helpful to keep track of how models perform. This can be don

   module = nn.Linear(input_dim, 1)
   unit = MyUnit(module=module)
-  tss=TorchSnapshotSaver(
+  dcp=DistributedCheckpointSaver(
     dirpath=your_dirpath_here,
     save_every_n_epochs=1,
     best_checkpoint_config=BestCheckpointConfig(
@@ -111,7 +111,7 @@ Sometimes it may be helpful to keep track of how models perform. This can be don
   train(
     unit,
     dataloader,
-    callbacks=[tss]
+    callbacks=[dcp]
   )

 By specifying the monitored metric to be "train_loss", the checkpointer will expect the :class:`~torchtnt.framework.unit.TrainUnit` to have a "train_loss" attribute at the time of checkpointing, and it will cast this value to a float and append the value to the checkpoint path name. This attribute is expected to be computed and kept up to date appropriately in the unit by the user.
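To make the mechanism concrete, here is a minimal sketch of a TrainUnit that keeps such a "train_loss" attribute up to date; the training logic is simplified, and the BestCheckpointConfig is assumed to use monitored_metric="train_loss" with mode="min", matching the restore_from_best call shown below.

.. code-block:: python

   from typing import Tuple

   import torch
   from torchtnt.framework.state import State
   from torchtnt.framework.unit import TrainUnit

   Batch = Tuple[torch.Tensor, torch.Tensor]

   class MyUnit(TrainUnit[Batch]):
       def __init__(self, module: torch.nn.Module) -> None:
           super().__init__()
           self.module = module
           self.optimizer = torch.optim.SGD(module.parameters(), lr=0.01)
           # Monitored attribute: read and cast to float by the checkpointer.
           self.train_loss = float("inf")

       def train_step(self, state: State, data: Batch) -> None:
           inputs, targets = data
           loss = torch.nn.functional.mse_loss(self.module(inputs), targets)
           loss.backward()
           self.optimizer.step()
           self.optimizer.zero_grad()
           # Keep the monitored metric current so checkpoints are named correctly.
           self.train_loss = loss.item()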
@@ -120,13 +120,13 @@ Later on, the best checkpoint can be loaded via

 .. code-block:: python

-  TorchSnapshotSaver.restore_from_best(your_dirpath_here, unit, metric_name="train_loss", mode="min")
+  DistributedCheckpointSaver.restore_from_best(your_dirpath_here, unit, metric_name="train_loss", mode="min")

 If you'd like to monitor a validation metric (say validation loss after each eval epoch during :py:func:`~torchtnt.framework.fit.fit`), you can use the `save_every_n_eval_epochs` flag instead, like so

 .. code-block:: python

-  tss=TorchSnapshotSaver(
+  dcp=DistributedCheckpointSaver(
     dirpath=your_dirpath_here,
     save_every_n_eval_epochs=1,
     best_checkpoint_config=BestCheckpointConfig(
@@ -139,7 +139,7 @@ And to save only the top three performing models, you can use the existing `keep
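This final hunk is not expanded, and the flag name in its header is truncated after "keep". Purely for illustration, combining a best-checkpoint config with a retention limit might look like the sketch below; the keep_last_n_checkpoints parameter name and the BestCheckpointConfig field names are assumptions, not taken from this diff.

.. code-block:: python

   # Sketch only: keep_last_n_checkpoints and the config field names are assumed.
   dcp = DistributedCheckpointSaver(
       dirpath=your_dirpath_here,
       save_every_n_epochs=1,
       keep_last_n_checkpoints=3,
       best_checkpoint_config=BestCheckpointConfig(
           monitored_metric="train_loss",
           mode="min",
       ),
   )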