Commit a941fb0

type updates
1 parent 31f1066 commit a941fb0

File tree: 1 file changed, +10 -8 lines

recipes_source/distributed_async_checkpoint_recipe.rst
Lines changed: 10 additions & 8 deletions
@@ -1,8 +1,8 @@
 Asynchronous Saving with Distributed Checkpoint (DCP)
 =====================================================

-Checkpointing is often a bottle-neck in the critical distributed training workloads, incurring larger and larger costs as both model and world sizes grow.
-One excellent strategy to offsetting this cost is to checkpoint in parallel, asynchronously. Below, we expand the save example
+Checkpointing is often a bottle-neck in the critical path for distributed training workloads, incurring larger and larger costs as both model and world sizes grow.
+One excellent strategy for offsetting this cost is to checkpoint in parallel, asynchronously. Below, we expand the save example
 from the `Getting Started with Distributed Checkpoint Tutorial <https://github.com/pytorch/tutorials/blob/main/recipes_source/distributed_checkpoint_recipe.rst>`__
 to show how this can be integrated quite easily with `torch.distributed.checkpoint.async_save`.

@@ -111,18 +111,19 @@ Specifically:
         model(torch.rand(8, 16, device="cuda")).sum().backward()
         optimizer.step()

-        state_dict = { "app": AppState(model, optimizer) }
+        # waits for checkpointing to finish if one exists, avoiding queuing more than one checkpoint request at a time
         if checkpoint_future is not None:
-            # waits for checkpointing to finish, avoiding queuing more than one checkpoint request at a time
             checkpoint_future.result()
-        dcp.async_save(state_dict, checkpoint_id=f"{CHECKPOINT_DIR}_step{step}")
+
+        state_dict = { "app": AppState(model, optimizer) }
+        checkpoint_future = dcp.async_save(state_dict, checkpoint_id=f"{CHECKPOINT_DIR}_step{step}")

     cleanup()


 if __name__ == "__main__":
     world_size = torch.cuda.device_count()
-    print(f"Running fsdp checkpoint example on {world_size} devices.")
+    print(f"Running async checkpoint example on {world_size} devices.")
     mp.spawn(
         run_fsdp_checkpoint_save_example,
         args=(world_size,),
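
For readers who want the resulting pattern outside the diff view, the following is a minimal, self-contained sketch of the loop this hunk lands on. It is an approximation rather than the tutorial's exact code: it assumes a single-process "gloo" group and a plain nn.Linear in place of the tutorial's FSDP-wrapped model and AppState wrapper, and CHECKPOINT_DIR is a placeholder path::

    # Hedged sketch: approximates the tutorial's loop with a single-process
    # "gloo" group and a plain nn.Linear; CHECKPOINT_DIR is a placeholder.
    import os

    import torch
    import torch.distributed as dist
    import torch.distributed.checkpoint as dcp
    import torch.nn as nn

    CHECKPOINT_DIR = "checkpoint"  # placeholder output prefix


    def main():
        # DCP needs an initialized process group, even for a single process.
        os.environ.setdefault("MASTER_ADDR", "localhost")
        os.environ.setdefault("MASTER_PORT", "12355")
        dist.init_process_group("gloo", rank=0, world_size=1)

        model = nn.Linear(16, 16)
        optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

        checkpoint_future = None
        for step in range(3):
            optimizer.zero_grad()
            model(torch.rand(8, 16)).sum().backward()
            optimizer.step()

            # Wait for the previous async checkpoint, if any, so at most one
            # checkpoint request is ever in flight.
            if checkpoint_future is not None:
                checkpoint_future.result()

            state_dict = {"model": model.state_dict()}
            checkpoint_future = dcp.async_save(
                state_dict, checkpoint_id=f"{CHECKPOINT_DIR}_step{step}"
            )

        # Drain the last request before tearing down the process group.
        if checkpoint_future is not None:
            checkpoint_future.result()
        dist.destroy_process_group()


    if __name__ == "__main__":
        main()

The key point the commit makes is visible here: the future returned by dcp.async_save is kept and waited on before the next save, so training never queues overlapping checkpoint requests.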
@@ -133,8 +134,9 @@ Specifically:

 Even more performance with Pinned Memory
 -----------------------------------------
-If the above optimization is still not performant enough for a use case, PyTorch offers an additional optimization for GPU models by utilizing a pinned memory buffer.
-This optimization attacks the main overhead of asynchronous checkpointing, which is the in-memory copying to checkpointing buffers.
+If the above optimization is still not performant enough, users may wish to take advantage of an additional optimization for GPU models which utilizes a pinned memory buffer for checkpoint staging.
+Specifically, this optimization attacks the main overhead of asynchronous checkpointing, which is the in-memory copying to checkpointing buffers. By maintaining a pinned memory buffer between
+checkpoint requests, users can take advantage of direct memory access to speed up this copy.

 Note: The main drawback of this optimization is the persistence of the buffer in between checkpointing steps. Without the pinned memory optimization (as demonstrated above),
 any checkpointing buffers are released as soon as checkpointing is finished. With the pinned memory implementation, this buffer is maintained in between steps, leading to the same
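
As a rough sketch of how such a persistent staging buffer could be wired in, assuming a PyTorch release where torch.distributed.checkpoint.FileSystemWriter exposes a cache_staged_state_dict flag (check your version's documentation), the changes relative to the loop sketched above amount to constructing one writer up front and passing it to async_save::

    # Hedged sketch: assumes FileSystemWriter supports cache_staged_state_dict
    # in your PyTorch version; CHECKPOINT_DIR is a placeholder.
    import torch.distributed.checkpoint as dcp
    from torch.distributed.checkpoint import FileSystemWriter as StorageWriter

    CHECKPOINT_DIR = "checkpoint"  # placeholder output prefix

    # Build the writer once, outside the training loop, so the staging buffer
    # it allocates persists across checkpoint requests instead of being
    # re-allocated for every save.
    writer = StorageWriter(path=CHECKPOINT_DIR, cache_staged_state_dict=True)


    def save_async(state_dict, step, checkpoint_future=None):
        # Same at-most-one-request-in-flight discipline as the loop above.
        if checkpoint_future is not None:
            checkpoint_future.result()
        return dcp.async_save(
            state_dict,
            storage_writer=writer,
            checkpoint_id=f"{CHECKPOINT_DIR}_step{step}",
        )

Because the writer is reused, this is a deliberate memory-for-speed trade: the staging buffer stays allocated between steps, which is exactly the drawback the note above describes.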
