Commit f4ecb4c

adds async save and stateful info
1 parent cad4839 commit f4ecb4c

2 files changed (+326, -22 lines)
Lines changed: 257 additions & 0 deletions
@@ -0,0 +1,257 @@
Asynchronous Saving with Distributed Checkpoint (DCP)
=====================================================

Checkpointing is often a bottleneck in critical distributed training workloads, incurring larger and larger costs as both model and world sizes grow.
One excellent strategy for offsetting this cost is to checkpoint in parallel, asynchronously. Below, we expand the save example
from the `Getting Started with Distributed Checkpoint Tutorial <https://github.com/pytorch/tutorials/blob/main/recipes_source/distributed_checkpoint_recipe.rst>`__
to show how this can be integrated quite easily with `torch.distributed.checkpoint.async_save`.


Notes on Asynchronous Checkpointing
------------------------------------
Before getting started with asynchronous checkpointing, it's important that we discuss some differences and limitations as compared to synchronous checkpointing.
Specifically:

* Memory requirements - Asynchronous checkpointing works by first copying models into internal CPU buffers.
  This is helpful since it ensures model and optimizer weights are not changing while the model is still checkpointing,
  but does raise CPU memory by a factor of checkpoint size times the number of processes on the host.

* Checkpoint management - Since checkpointing is asynchronous, it is up to the user to manage concurrently running checkpoints. In general, users can
  employ their own management strategies by handling the future object returned from `async_save` (one such strategy is sketched below). For most users, we recommend limiting
  checkpoints to one asynchronous request at a time, avoiding additional memory pressure per request.
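As a concrete illustration of the second point, here is a minimal sketch of one such management strategy. The ``checkpoint`` helper and ``MAX_IN_FLIGHT`` constant are illustrative only, not part of the DCP API; the full runnable example that follows keeps things simpler by holding a single future.

.. code-block:: python

    from collections import deque

    import torch.distributed.checkpoint as dcp

    # Allow at most MAX_IN_FLIGHT concurrent asynchronous saves by blocking on the
    # oldest outstanding future before issuing a new request. MAX_IN_FLIGHT = 1
    # matches the recommendation above; raising it trades CPU memory for overlap.
    MAX_IN_FLIGHT = 1
    pending = deque()

    def checkpoint(state_dict, checkpoint_id):
        if len(pending) >= MAX_IN_FLIGHT:
            pending.popleft().result()  # wait for the oldest save to finish
        pending.append(dcp.async_save(state_dict, checkpoint_id=checkpoint_id))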
.. code-block:: python

    import os

    import torch
    import torch.distributed as dist
    import torch.distributed.checkpoint as dcp
    import torch.multiprocessing as mp
    import torch.nn as nn

    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
    from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict
    from torch.distributed.checkpoint.stateful import Stateful
    from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType

    CHECKPOINT_DIR = "checkpoint"


    class AppState(Stateful):
        """This is a useful wrapper for checkpointing the Application State. Since this object is compliant
        with the Stateful protocol, DCP will automatically call state_dict/load_state_dict as needed in the
        dcp.save/load APIs.

        Note: We take advantage of this wrapper to handle calling distributed state dict methods on the model
        and optimizer.
        """

        def __init__(self, model, optimizer=None):
            self.model = model
            self.optimizer = optimizer

        def state_dict(self):
            # this line automatically manages FSDP FQNs, as well as sets the default state dict type to FSDP.SHARDED_STATE_DICT
            model_state_dict, optimizer_state_dict = get_state_dict(self.model, self.optimizer)
            return {
                "model": model_state_dict,
                "optim": optimizer_state_dict
            }

        def load_state_dict(self, state_dict):
            # sets our state dicts on the model and optimizer, now that we've loaded
            set_state_dict(
                self.model,
                self.optimizer,
                model_state_dict=state_dict["model"],
                optim_state_dict=state_dict["optim"]
            )


    class ToyModel(nn.Module):
        def __init__(self):
            super(ToyModel, self).__init__()
            self.net1 = nn.Linear(16, 16)
            self.relu = nn.ReLU()
            self.net2 = nn.Linear(16, 8)

        def forward(self, x):
            return self.net2(self.relu(self.net1(x)))


    def setup(rank, world_size):
        os.environ["MASTER_ADDR"] = "localhost"
        os.environ["MASTER_PORT"] = "12355"

        # initialize the process group
        dist.init_process_group("nccl", rank=rank, world_size=world_size)
        torch.cuda.set_device(rank)


    def cleanup():
        dist.destroy_process_group()


    def run_fsdp_checkpoint_save_example(rank, world_size):
        print(f"Running basic FSDP checkpoint saving example on rank {rank}.")
        setup(rank, world_size)

        # create a model and move it to GPU with id rank
        model = ToyModel().to(rank)
        model = FSDP(model)

        loss_fn = nn.MSELoss()
        optimizer = torch.optim.Adam(model.parameters(), lr=0.1)

        checkpoint_future = None
        for step in range(10):
            optimizer.zero_grad()
            model(torch.rand(8, 16, device="cuda")).sum().backward()
            optimizer.step()

            state_dict = { "app": AppState(model, optimizer) }
            if checkpoint_future is not None:
                # waits for checkpointing to finish, avoiding queuing more than one checkpoint request at a time
                checkpoint_future.result()
            checkpoint_future = dcp.async_save(state_dict, checkpoint_id=f"{CHECKPOINT_DIR}_step{step}")

        cleanup()


    if __name__ == "__main__":
        world_size = torch.cuda.device_count()
        print(f"Running fsdp checkpoint example on {world_size} devices.")
        mp.spawn(
            run_fsdp_checkpoint_save_example,
            args=(world_size,),
            nprocs=world_size,
            join=True,
        )

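One detail to be aware of: in the loop above, the request issued on the final step is never waited on before ``cleanup()`` tears down the process group. A minimal sketch of one way to handle this is shown below; the ``finalize_checkpoint`` helper is illustrative, not part of the DCP API.

.. code-block:: python

    def finalize_checkpoint(checkpoint_future):
        # Block until the last asynchronous save has finished persisting to storage.
        if checkpoint_future is not None:
            checkpoint_future.result()

    # In run_fsdp_checkpoint_save_example, call this after the training loop:
    #     finalize_checkpoint(checkpoint_future)
    #     cleanup()
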
Even more performance with Pinned Memory
-----------------------------------------
If the above optimization is still not performant enough for your use case, PyTorch offers an additional optimization for GPU models by utilizing a pinned memory buffer.
This optimization attacks the main overhead of asynchronous checkpointing, which is the in-memory copying to checkpointing buffers.

Note: The main drawback of this optimization is the persistence of the buffer in between checkpointing steps. Without the pinned memory optimization (as demonstrated above),
any checkpointing buffers are released as soon as checkpointing is finished. With the pinned memory implementation, this buffer is maintained between steps, leading to the same
peak memory pressure being sustained throughout the application's lifetime.

.. code-block:: python

    import os

    import torch
    import torch.distributed as dist
    import torch.distributed.checkpoint as dcp
    import torch.multiprocessing as mp
    import torch.nn as nn

    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
    from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict
    from torch.distributed.checkpoint.stateful import Stateful
    from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType
    from torch.distributed.checkpoint import StorageWriter

    CHECKPOINT_DIR = "checkpoint"


    class AppState(Stateful):
        """This is a useful wrapper for checkpointing the Application State. Since this object is compliant
        with the Stateful protocol, DCP will automatically call state_dict/load_state_dict as needed in the
        dcp.save/load APIs.

        Note: We take advantage of this wrapper to handle calling distributed state dict methods on the model
        and optimizer.
        """

        def __init__(self, model, optimizer=None):
            self.model = model
            self.optimizer = optimizer

        def state_dict(self):
            # this line automatically manages FSDP FQNs, as well as sets the default state dict type to FSDP.SHARDED_STATE_DICT
            model_state_dict, optimizer_state_dict = get_state_dict(self.model, self.optimizer)
            return {
                "model": model_state_dict,
                "optim": optimizer_state_dict
            }

        def load_state_dict(self, state_dict):
            # sets our state dicts on the model and optimizer, now that we've loaded
            set_state_dict(
                self.model,
                self.optimizer,
                model_state_dict=state_dict["model"],
                optim_state_dict=state_dict["optim"]
            )


    class ToyModel(nn.Module):
        def __init__(self):
            super(ToyModel, self).__init__()
            self.net1 = nn.Linear(16, 16)
            self.relu = nn.ReLU()
            self.net2 = nn.Linear(16, 8)

        def forward(self, x):
            return self.net2(self.relu(self.net1(x)))


    def setup(rank, world_size):
        os.environ["MASTER_ADDR"] = "localhost"
        os.environ["MASTER_PORT"] = "12355"

        # initialize the process group
        dist.init_process_group("nccl", rank=rank, world_size=world_size)
        torch.cuda.set_device(rank)


    def cleanup():
        dist.destroy_process_group()


    def run_fsdp_checkpoint_save_example(rank, world_size):
        print(f"Running basic FSDP checkpoint saving example on rank {rank}.")
        setup(rank, world_size)

        # create a model and move it to GPU with id rank
        model = ToyModel().to(rank)
        model = FSDP(model)

        loss_fn = nn.MSELoss()
        optimizer = torch.optim.Adam(model.parameters(), lr=0.1)

        # The storage writer defines our 'staging' strategy, where staging is considered the process of copying
        # checkpoints to in-memory buffers. By setting `cached_state_dict=True`, we enable efficient memory copying
        # into a persistent buffer with pinned memory enabled.
        # Note: It's important that the writer persists in between checkpointing requests, since it maintains the
        # pinned memory buffer.
        writer = StorageWriter(cached_state_dict=True)
        checkpoint_future = None
        for step in range(10):
            optimizer.zero_grad()
            model(torch.rand(8, 16, device="cuda")).sum().backward()
            optimizer.step()

            state_dict = { "app": AppState(model, optimizer) }
            if checkpoint_future is not None:
                # waits for checkpointing to finish, avoiding queuing more than one checkpoint request at a time
                checkpoint_future.result()
            checkpoint_future = dcp.async_save(state_dict, storage_writer=writer, checkpoint_id=f"{CHECKPOINT_DIR}_step{step}")

        cleanup()


    if __name__ == "__main__":
        world_size = torch.cuda.device_count()
        print(f"Running fsdp checkpoint example on {world_size} devices.")
        mp.spawn(
            run_fsdp_checkpoint_save_example,
            args=(world_size,),
            nprocs=world_size,
            join=True,
        )
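
To reason about the memory cost described above, the sketch below estimates the per-process size of the staging buffer (and therefore of the persistent pinned buffer) from the model and optimizer state. The ``staging_buffer_bytes`` helper is illustrative, not part of the DCP API, and ignores any serialization overhead.

.. code-block:: python

    import torch

    def staging_buffer_bytes(model, optimizer):
        # Rough per-process estimate: every parameter plus every optimizer state
        # tensor is copied into the CPU staging buffer during an asynchronous save.
        total = sum(p.numel() * p.element_size() for p in model.parameters())
        for state in optimizer.state.values():
            for value in state.values():
                if torch.is_tensor(value):
                    total += value.numel() * value.element_size()
        return total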

recipes_source/distributed_checkpoint_recipe.rst

Lines changed: 69 additions & 22 deletions
@@ -33,6 +33,7 @@ DCP is different from :func:`torch.save` and :func:`torch.load` in a few signifi
 
 * It produces multiple files per checkpoint, with at least one per rank.
 * It operates in place, meaning that the model should allocate its data first and DCP uses that storage instead.
+* DCP offers special handling of Stateful objects (formally defined in `torch.distributed.checkpoint.stateful`), automatically calling both `state_dict` and `load_state_dict` methods if they are defined.
 
 .. note::
   The code in this tutorial runs on an 8-GPU server, but it can be easily
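
The added bullet above refers to DCP's handling of Stateful objects. As a brief illustration (not part of this commit), any object exposing ``state_dict``/``load_state_dict`` can be placed in the dictionary passed to ``dcp.save``/``dcp.load``; the ``TrainState`` class below is a hypothetical example for tracking non-tensor training progress.

.. code-block:: python

    from torch.distributed.checkpoint.stateful import Stateful

    class TrainState(Stateful):
        """Hypothetical Stateful wrapper for plain training metadata."""

        def __init__(self):
            self.step = 0
            self.epoch = 0

        def state_dict(self):
            # called automatically by dcp.save
            return {"step": self.step, "epoch": self.epoch}

        def load_state_dict(self, state_dict):
            # called automatically by dcp.load
            self.step = state_dict["step"]
            self.epoch = state_dict["epoch"]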
@@ -59,12 +60,43 @@ Now, let's create a toy module, wrap it with FSDP, feed it with some dummy input
 import torch.nn as nn
 
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
-from torch.distributed.checkpoint.state_dict import get_state_dict
+from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict
+from torch.distributed.checkpoint.stateful import Stateful
 from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType
 
 CHECKPOINT_DIR = "checkpoint"
 
 
+class AppState(Stateful):
+    """This is a useful wrapper for checkpointing the Application State. Since this object is compliant
+    with the Stateful protocol, DCP will automatically call state_dict/load_state_dict as needed in the
+    dcp.save/load APIs.
+
+    Note: We take advantage of this wrapper to handle calling distributed state dict methods on the model
+    and optimizer.
+    """
+
+    def __init__(self, model, optimizer=None):
+        self.model = model
+        self.optimizer = optimizer
+
+    def state_dict(self):
+        # this line automatically manages FSDP FQNs, as well as sets the default state dict type to FSDP.SHARDED_STATE_DICT
+        model_state_dict, optimizer_state_dict = get_state_dict(self.model, self.optimizer)
+        return {
+            "model": model_state_dict,
+            "optim": optimizer_state_dict
+        }
+
+    def load_state_dict(self, state_dict):
+        # sets our state dicts on the model and optimizer, now that we've loaded
+        set_state_dict(
+            self.model,
+            self.optimizer,
+            model_state_dict=state_dict["model"],
+            optim_state_dict=state_dict["optim"]
+        )
+
 class ToyModel(nn.Module):
     def __init__(self):
         super(ToyModel, self).__init__()
@@ -104,14 +136,8 @@ Now, let's create a toy module, wrap it with FSDP, feed it with some dummy input
     model(torch.rand(8, 16, device="cuda")).sum().backward()
     optimizer.step()
 
-    # this line automatically manages FSDP FQN's, as well as sets the default state dict type to FSDP.SHARDED_STATE_DICT
-    model_state_dict, optimizer_state_dict = get_state_dict(model, optimizer)
-    state_dict = {
-        "model": model_state_dict,
-        "optimizer": optimizer_state_dict
-    }
-    dcp.save(state_dict,checkpoint_id=CHECKPOINT_DIR)
-
+    state_dict = { "app": AppState(model, optimizer) }
+    dcp.save(state_dict, checkpoint_id=CHECKPOINT_DIR)
 
     cleanup()
 
@@ -161,6 +187,36 @@ The reason that we need the ``state_dict`` prior to loading is:
 CHECKPOINT_DIR = "checkpoint"
 
 
+class AppState(Stateful):
+    """This is a useful wrapper for checkpointing the Application State. Since this object is compliant
+    with the Stateful protocol, DCP will automatically call state_dict/load_state_dict as needed in the
+    dcp.save/load APIs.
+
+    Note: We take advantage of this wrapper to handle calling distributed state dict methods on the model
+    and optimizer.
+    """
+
+    def __init__(self, model, optimizer=None):
+        self.model = model
+        self.optimizer = optimizer
+
+    def state_dict(self):
+        # this line automatically manages FSDP FQNs, as well as sets the default state dict type to FSDP.SHARDED_STATE_DICT
+        model_state_dict, optimizer_state_dict = get_state_dict(self.model, self.optimizer)
+        return {
+            "model": model_state_dict,
+            "optim": optimizer_state_dict
+        }
+
+    def load_state_dict(self, state_dict):
+        # sets our state dicts on the model and optimizer, now that we've loaded
+        set_state_dict(
+            self.model,
+            self.optimizer,
+            model_state_dict=state_dict["model"],
+            optim_state_dict=state_dict["optim"]
+        )
+
 class ToyModel(nn.Module):
     def __init__(self):
         super(ToyModel, self).__init__()
@@ -193,23 +249,14 @@ The reason that we need the ``state_dict`` prior to loading is:
     model = ToyModel().to(rank)
     model = FSDP(model)
 
-    # generates the state dict we will load into
-    model_state_dict, optimizer_state_dict = get_state_dict(model, optimizer)
-    state_dict = {
-        "model": model_state_dict,
-        "optimizer": optimizer_state_dict
-    }
+    loss_fn = nn.MSELoss()
+    optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
+
+    state_dict = { "app": AppState(model, optimizer) }
     dcp.load(
         state_dict=state_dict,
         checkpoint_id=CHECKPOINT_DIR,
     )
-    # sets our state dicts on the model and optimizer, now that we've loaded
-    set_state_dict(
-        model,
-        optimizer,
-        model_state_dict=model_state_dict,
-        optim_state_dict=optimizer_state_dict
-    )
 
     cleanup()
