Commit f03fd9b

saumishr authored and facebook-github-bot committed
Update Checkpoint docs with DCP based checkpointer (#904)
Summary:
Pull Request resolved: #904

Update Checkpoint docs with DCP based checkpointer

Reviewed By: JKSenthil

Differential Revision: D63278746

fbshipit-source-id: 0c01e21fb516996001d3c12fe74504f65f5ed783
1 parent 843835c commit f03fd9b

1 file changed: +15 −15 lines


docs/source/checkpointing.rst

Lines changed: 15 additions & 15 deletions
@@ -1,24 +1,24 @@
 Checkpointing
 ================================
 
-TorchTNT offers checkpointing via the :class:`~torchtnt.framework.callbacks.TorchSnapshotSaver` which uses `TorchSnapshot <https://pytorch.org/torchsnapshot/main/>`_ under the hood.
+TorchTNT offers checkpointing via :class:`~torchtnt.framework.callbacks.DistributedCheckpointSaver` which uses `DCP <https://github.com/pytorch/pytorch/tree/main/torch/distributed/checkpoint>`_ under the hood.
 
 .. code-block:: python
 
   module = nn.Linear(input_dim, 1)
   unit = MyUnit(module=module)
-  tss = TorchSnapshotSaver(
+  dcp = DistributedCheckpointSaver(
       dirpath=your_dirpath_here,
       save_every_n_train_steps=100,
       save_every_n_epochs=2,
   )
   # loads latest checkpoint, if it exists
   if latest_checkpoint_dir:
-      tss.restore_from_latest(your_dirpath_here, unit, train_dataloader=dataloader)
+      dcp.restore_from_latest(your_dirpath_here, unit, train_dataloader=dataloader)
   train(
       unit,
       dataloader,
-      callbacks=[tss]
+      callbacks=[dcp]
   )
 
 There is built-in support for saving and loading distributed models (DDP, FSDP).
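
For the DDP case, which the hunks here don't show directly (the next hunk covers FSDP), a minimal sketch under stated assumptions: `MyUnit`, `input_dim`, `your_dirpath_here`, and `dataloader` are the placeholders from the snippet above, the TorchTNT import paths are assumed, and the process group is assumed to be initialized before the module is wrapped.

.. code-block:: python

    import torch.nn as nn
    from torch.nn.parallel import DistributedDataParallel as DDP

    # assumed import paths; adjust to your TorchTNT version
    from torchtnt.framework import train
    from torchtnt.framework.callbacks import DistributedCheckpointSaver

    # assumes torch.distributed.init_process_group(...) has already been called
    module = DDP(nn.Linear(input_dim, 1))
    unit = MyUnit(module=module)  # MyUnit as in the snippet above

    dcp = DistributedCheckpointSaver(
        dirpath=your_dirpath_here,
        save_every_n_train_steps=100,
    )

    # checkpoints are written collectively across ranks by DCP
    train(unit, dataloader, callbacks=[dcp])

The only DDP-specific piece is wrapping the module; the checkpointer call sites stay the same as in the non-distributed example above.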
@@ -37,15 +37,15 @@ The state dict type to be used for checkpointing FSDP modules can be specified i
   )
   module = prepare_fsdp(module, strategy=fsdp_strategy)
   unit = MyUnit(module=module)
-  tss = TorchSnapshotSaver(
+  dcp = DistributedCheckpointSaver(
       dirpath=your_dirpath_here,
       save_every_n_epochs=2,
   )
   train(
       unit,
       dataloader,
       # checkpointer callback will use state dict type specified in FSDPStrategy
-      callbacks=[tss]
+      callbacks=[dcp]
   )
 
 Or you can manually set this using `FSDP.set_state_dict_type <https://pytorch.org/docs/stable/fsdp.html#torch.distributed.fsdp.FullyShardedDataParallel.set_state_dict_type>`_.
@@ -56,14 +56,14 @@ Or you can manually set this using `FSDP.set_state_dict_type <https://pytorch.or
   module = FSDP(module, ....)
   FSDP.set_state_dict_type(module, StateDictType.SHARDED_STATE_DICT)
   unit = MyUnit(module=module, ...)
-  tss = TorchSnapshotSaver(
+  dcp = DistributedCheckpointSaver(
       dirpath=your_dirpath_here,
       save_every_n_epochs=2,
   )
   train(
       unit,
       dataloader,
-      callbacks=[tss]
+      callbacks=[dcp]
   )
 
 
@@ -74,15 +74,15 @@ When finetuning your models, you can pass RestoreOptions to avoid loading optimi
 
 .. code-block:: python
 
-  tss = TorchSnapshotSaver(
+  dcp = DistributedCheckpointSaver(
       dirpath=your_dirpath_here,
       save_every_n_train_steps=100,
       save_every_n_epochs=2,
   )
 
   # loads latest checkpoint, if it exists
   if latest_checkpoint_dir:
-      tss.restore_from_latest(
+      dcp.restore_from_latest(
           your_dirpath_here,
           your_unit,
           train_dataloader=dataloader,
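
The hunk is truncated before the finetuning-specific argument mentioned in its header. A hedged sketch of how `RestoreOptions` might be passed, assuming it is importable from `torchtnt.framework.callbacks`, exposes `restore_optimizers` / `restore_lr_schedulers` flags, and is accepted by `restore_from_latest` via a `restore_options` keyword (all three are assumptions, not shown in this diff):

.. code-block:: python

    # assumed import location and field names
    from torchtnt.framework.callbacks import RestoreOptions

    if latest_checkpoint_dir:
        dcp.restore_from_latest(
            your_dirpath_here,
            your_unit,
            train_dataloader=dataloader,
            # skip optimizer and LR-scheduler state when finetuning
            restore_options=RestoreOptions(
                restore_optimizers=False,
                restore_lr_schedulers=False,
            ),
        )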
@@ -99,7 +99,7 @@ Sometimes it may be helpful to keep track of how models perform. This can be don
 
   module = nn.Linear(input_dim, 1)
   unit = MyUnit(module=module)
-  tss = TorchSnapshotSaver(
+  dcp = DistributedCheckpointSaver(
       dirpath=your_dirpath_here,
       save_every_n_epochs=1,
       best_checkpoint_config=BestCheckpointConfig(
@@ -111,7 +111,7 @@ Sometimes it may be helpful to keep track of how models perform. This can be don
   train(
       unit,
       dataloader,
-      callbacks=[tss]
+      callbacks=[dcp]
   )
 
 By specifying the monitored metric to be "train_loss", the checkpointer will expect the :class:`~torchtnt.framework.unit.TrainUnit` to have a "train_loss" attribute at the time of checkpointing, and it will cast this value to a float and append the value to the checkpoint path name. This attribute is expected to be computed and kept up to date appropriately in the unit by the user.
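
A sketch of a unit that keeps such an attribute current, assuming the standard `TrainUnit` subclassing pattern and import paths; the loss bookkeeping and optimizer choice are illustrative, not part of this diff.

.. code-block:: python

    from typing import Tuple

    import torch
    from torchtnt.framework.state import State
    from torchtnt.framework.unit import TrainUnit

    Batch = Tuple[torch.Tensor, torch.Tensor]

    class MyUnit(TrainUnit[Batch]):
        def __init__(self, module: torch.nn.Module) -> None:
            super().__init__()
            self.module = module
            self.optimizer = torch.optim.SGD(module.parameters(), lr=1e-2)
            # attribute the checkpointer reads when monitored_metric="train_loss"
            self.train_loss: float = float("inf")

        def train_step(self, state: State, data: Batch) -> None:
            inputs, targets = data
            loss = torch.nn.functional.mse_loss(self.module(inputs), targets)
            loss.backward()
            self.optimizer.step()
            self.optimizer.zero_grad()
            # keep the monitored metric up to date so it lands in the checkpoint name
            self.train_loss = loss.item()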
@@ -120,13 +120,13 @@ Later on, the best checkpoint can be loaded via
 
 .. code-block:: python
 
-  TorchSnapshotSaver.restore_from_best(your_dirpath_here, unit, metric_name="train_loss", mode="min")
+  DistributedCheckpointSaver.restore_from_best(your_dirpath_here, unit, metric_name="train_loss", mode="min")
 
 If you'd like to monitor a validation metric (say validation loss after each eval epoch during :py:func:`~torchtnt.framework.fit.fit`), you can use the `save_every_n_eval_epochs` flag instead, like so
 
 .. code-block:: python
 
-  tss = TorchSnapshotSaver(
+  dcp = DistributedCheckpointSaver(
       dirpath=your_dirpath_here,
       save_every_n_eval_epochs=1,
       best_checkpoint_config=BestCheckpointConfig(
@@ -139,7 +139,7 @@ And to save only the top three performing models, you can use the existing `keep
 
 .. code-block:: python
 
-  tss = TorchSnapshotSaver(
+  dcp = DistributedCheckpointSaver(
       dirpath=your_dirpath_here,
       save_every_n_eval_epochs=1,
       keep_last_n_checkpoints=3,
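
The hunk ends mid-config; for the top-three behaviour described in its header, the saver presumably pairs `keep_last_n_checkpoints=3` with a `best_checkpoint_config`. A hedged sketch, with the metric name and import path as assumptions:

.. code-block:: python

    # assumed import location
    from torchtnt.framework.callbacks import BestCheckpointConfig, DistributedCheckpointSaver

    dcp = DistributedCheckpointSaver(
        dirpath=your_dirpath_here,
        save_every_n_eval_epochs=1,
        keep_last_n_checkpoints=3,  # with a best_checkpoint_config, retains the top 3
        best_checkpoint_config=BestCheckpointConfig(
            monitored_metric="eval_loss",  # assumed attribute name on the unit
            mode="min",
        ),
    )

    # later, reload the best of the retained checkpoints
    DistributedCheckpointSaver.restore_from_best(
        your_dirpath_here, unit, metric_name="eval_loss", mode="min"
    )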
