|
| 1 | +Asynchronous Saving with Distributed Checkpoint (DCP) |
| 2 | +===================================================== |
| 3 | + |
| 4 | +Checkpointing is often a bottle-neck in the critical distributed training workloads, incurring larger and larger costs as both model and world sizes grow. |
| 5 | +One excellent strategy to offsetting this cost is to checkpoint in parallel, asynchronously. Below, we expand the save example |
| 6 | +from the `Getting Started with Distributed Checkpoint Tutorial <https://github.com/pytorch/tutorials/blob/main/recipes_source/distributed_checkpoint_recipe.rst>`__ |
| 7 | +to show how this can be integrated quite easily with `torch.distributed.checkpoint.async_save`. |
| 8 | + |
| 9 | + |
| 10 | +Notes on Asynchronous Checkpointing |
| 11 | +------------------------------------ |
| 12 | +Before getting started with Asynchronous Checkpointing, it's important that we discuss some differences and limitations as compared to synchronous checkpointing. |
| 13 | +Speciically: |
| 14 | + |
| 15 | +* Memory requirements - Asynchronous checkpointing works by first copying models into internal CPU-buffers. |
| 16 | + This is helpful since it ensures model and optimizer weights are not changing while the model is still checkpointing, |
| 17 | + but does raise CPU memory by a factor of checkpoint size times the number of process on the host. |
| 18 | + |
| 19 | +* Checkpoint Management - Since checkpointing is Asynchronous, it is up to the user to manage concurrently run checkpoints. In general users can |
| 20 | + employ their own management strategies by handling the future object returned form `async_save`. For most users, we recommend limiting |
| 21 | + checkpoints to one asynchronous request at a time, avoiding additional memory pressure per request. |
| 22 | + |
| 23 | + |
| 24 | + |
| 25 | +.. code-block:: python |
| 26 | +
|
| 27 | + import os |
| 28 | +
|
| 29 | + import torch |
| 30 | + import torch.distributed as dist |
| 31 | + import torch.distributed.checkpoint as dcp |
| 32 | + import torch.multiprocessing as mp |
| 33 | + import torch.nn as nn |
| 34 | +
|
| 35 | + from torch.distributed.fsdp import FullyShardedDataParallel as FSDP |
| 36 | + from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict |
| 37 | + from torch.distributed.checkpoint.stateful import Stateful |
| 38 | + from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType |
| 39 | +
|
| 40 | + CHECKPOINT_DIR = "checkpoint" |
| 41 | +
|
| 42 | +
|
| 43 | + class AppState(Stateful): |
| 44 | + """This is a useful wrapper for checkpointing the Application State. Since this object is compliant |
| 45 | + with the Stateful protocol, DCP will automatically call state_dict/load_stat_dict as needed in the |
| 46 | + dcp.save/load APIs. |
| 47 | +
|
| 48 | + Note: We take advantage of this wrapper to hande calling distributed state dict methods on the model |
| 49 | + and optimizer. |
| 50 | + """ |
| 51 | +
|
| 52 | + def __init__(self, model, optimizer=None): |
| 53 | + self.model = model |
| 54 | + self.optimizer = optimizer |
| 55 | +
|
| 56 | + def state_dict(self): |
| 57 | + # this line automatically manages FSDP FQN's, as well as sets the default state dict type to FSDP.SHARDED_STATE_DICT |
| 58 | + model_state_dict, optimizer_state_dict = get_state_dict(model, optimizer) |
| 59 | + return { |
| 60 | + "model": model_state_dict, |
| 61 | + "optim": optimizer_state_dict |
| 62 | + } |
| 63 | +
|
| 64 | + def load_state_dict(self, state_dict): |
| 65 | + # sets our state dicts on the model and optimizer, now that we've loaded |
| 66 | + set_state_dict( |
| 67 | + self.model, |
| 68 | + self.optimizer, |
| 69 | + model_state_dict=state_dict["model"], |
| 70 | + optim_state_dict=state_dict["optim"] |
| 71 | + ) |
| 72 | +
|
| 73 | + class ToyModel(nn.Module): |
| 74 | + def __init__(self): |
| 75 | + super(ToyModel, self).__init__() |
| 76 | + self.net1 = nn.Linear(16, 16) |
| 77 | + self.relu = nn.ReLU() |
| 78 | + self.net2 = nn.Linear(16, 8) |
| 79 | +
|
| 80 | + def forward(self, x): |
| 81 | + return self.net2(self.relu(self.net1(x))) |
| 82 | +
|
| 83 | +
|
| 84 | + def setup(rank, world_size): |
| 85 | + os.environ["MASTER_ADDR"] = "localhost" |
| 86 | + os.environ["MASTER_PORT"] = "12355 " |
| 87 | +
|
| 88 | + # initialize the process group |
| 89 | + dist.init_process_group("nccl", rank=rank, world_size=world_size) |
| 90 | + torch.cuda.set_device(rank) |
| 91 | +
|
| 92 | +
|
| 93 | + def cleanup(): |
| 94 | + dist.destroy_process_group() |
| 95 | +
|
| 96 | +
|
| 97 | + def run_fsdp_checkpoint_save_example(rank, world_size): |
| 98 | + print(f"Running basic FSDP checkpoint saving example on rank {rank}.") |
| 99 | + setup(rank, world_size) |
| 100 | +
|
| 101 | + # create a model and move it to GPU with id rank |
| 102 | + model = ToyModel().to(rank) |
| 103 | + model = FSDP(model) |
| 104 | +
|
| 105 | + loss_fn = nn.MSELoss() |
| 106 | + optimizer = torch.optim.Adam(model.parameters(), lr=0.1) |
| 107 | +
|
| 108 | + checkpoint_future = None |
| 109 | + for step in range(10): |
| 110 | + optimizer.zero_grad() |
| 111 | + model(torch.rand(8, 16, device="cuda")).sum().backward() |
| 112 | + optimizer.step() |
| 113 | +
|
| 114 | + state_dict = { "app": AppState(model, optimizer) } |
| 115 | + if checkpoint_future is not None: |
| 116 | + # waits for checkpointing to finish, avoiding queuing more then one checkpoint request at a time |
| 117 | + checkpoint_future.result() |
| 118 | + dcp.async_save(state_dict, checkpoint_id=f"{CHECKPOINT_DIR}_step{step}") |
| 119 | +
|
| 120 | + cleanup() |
| 121 | +
|
| 122 | +
|
| 123 | + if __name__ == "__main__": |
| 124 | + world_size = torch.cuda.device_count() |
| 125 | + print(f"Running fsdp checkpoint example on {world_size} devices.") |
| 126 | + mp.spawn( |
| 127 | + run_fsdp_checkpoint_save_example, |
| 128 | + args=(world_size,), |
| 129 | + nprocs=world_size, |
| 130 | + join=True, |
| 131 | + ) |
| 132 | +
|
| 133 | +
|
| 134 | +Even more performance with Pinned Memory |
| 135 | +----------------------------------------- |
| 136 | +If the above optimization is still not performant enough for a use case, PyTorch offers an additional optimization for GPU models by utilizing a pinned memory buffer. |
| 137 | +This optimization attacks the main overhead of asynchronous checkpointing, which is the in-memory copying to checkpointing buffers. |
| 138 | + |
| 139 | +Note: The main drawback of this optimization is the persistence of the buffer in between checkpointing steps. Without the pinned memory optimization (as demonstrated above), |
| 140 | +any checkpointing buffers are released as soon as checkpointing is finished. With the pinned memory implementation, this buffer is maintained in between steps, leading to the same |
| 141 | +peak memory pressure being sustained through the application life. |
| 142 | + |
| 143 | + |
| 144 | +.. code-block:: python |
| 145 | +
|
| 146 | + import os |
| 147 | +
|
| 148 | + import torch |
| 149 | + import torch.distributed as dist |
| 150 | + import torch.distributed.checkpoint as dcp |
| 151 | + import torch.multiprocessing as mp |
| 152 | + import torch.nn as nn |
| 153 | +
|
| 154 | + from torch.distributed.fsdp import FullyShardedDataParallel as FSDP |
| 155 | + from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict |
| 156 | + from torch.distributed.checkpoint.stateful import Stateful |
| 157 | + from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType |
| 158 | + from torch.distributed.checkpoint import StorageWriter |
| 159 | +
|
| 160 | + CHECKPOINT_DIR = "checkpoint" |
| 161 | +
|
| 162 | +
|
| 163 | + class AppState(Stateful): |
| 164 | + """This is a useful wrapper for checkpointing the Application State. Since this object is compliant |
| 165 | + with the Stateful protocol, DCP will automatically call state_dict/load_stat_dict as needed in the |
| 166 | + dcp.save/load APIs. |
| 167 | +
|
| 168 | + Note: We take advantage of this wrapper to hande calling distributed state dict methods on the model |
| 169 | + and optimizer. |
| 170 | + """ |
| 171 | +
|
| 172 | + def __init__(self, model, optimizer=None): |
| 173 | + self.model = model |
| 174 | + self.optimizer = optimizer |
| 175 | +
|
| 176 | + def state_dict(self): |
| 177 | + # this line automatically manages FSDP FQN's, as well as sets the default state dict type to FSDP.SHARDED_STATE_DICT |
| 178 | + model_state_dict, optimizer_state_dict = get_state_dict(model, optimizer) |
| 179 | + return { |
| 180 | + "model": model_state_dict, |
| 181 | + "optim": optimizer_state_dict |
| 182 | + } |
| 183 | +
|
| 184 | + def load_state_dict(self, state_dict): |
| 185 | + # sets our state dicts on the model and optimizer, now that we've loaded |
| 186 | + set_state_dict( |
| 187 | + self.model, |
| 188 | + self.optimizer, |
| 189 | + model_state_dict=state_dict["model"], |
| 190 | + optim_state_dict=state_dict["optim"] |
| 191 | + ) |
| 192 | +
|
| 193 | + class ToyModel(nn.Module): |
| 194 | + def __init__(self): |
| 195 | + super(ToyModel, self).__init__() |
| 196 | + self.net1 = nn.Linear(16, 16) |
| 197 | + self.relu = nn.ReLU() |
| 198 | + self.net2 = nn.Linear(16, 8) |
| 199 | +
|
| 200 | + def forward(self, x): |
| 201 | + return self.net2(self.relu(self.net1(x))) |
| 202 | +
|
| 203 | +
|
| 204 | + def setup(rank, world_size): |
| 205 | + os.environ["MASTER_ADDR"] = "localhost" |
| 206 | + os.environ["MASTER_PORT"] = "12355 " |
| 207 | +
|
| 208 | + # initialize the process group |
| 209 | + dist.init_process_group("nccl", rank=rank, world_size=world_size) |
| 210 | + torch.cuda.set_device(rank) |
| 211 | +
|
| 212 | +
|
| 213 | + def cleanup(): |
| 214 | + dist.destroy_process_group() |
| 215 | +
|
| 216 | +
|
| 217 | + def run_fsdp_checkpoint_save_example(rank, world_size): |
| 218 | + print(f"Running basic FSDP checkpoint saving example on rank {rank}.") |
| 219 | + setup(rank, world_size) |
| 220 | +
|
| 221 | + # create a model and move it to GPU with id rank |
| 222 | + model = ToyModel().to(rank) |
| 223 | + model = FSDP(model) |
| 224 | +
|
| 225 | + loss_fn = nn.MSELoss() |
| 226 | + optimizer = torch.optim.Adam(model.parameters(), lr=0.1) |
| 227 | +
|
| 228 | + # The storage writer defines our 'staging' strategy, where staging is considered the process of copying |
| 229 | + # checkpoints to in-memory buffers. By setting `cached_state_dict=True`, we enable efficient memory copying |
| 230 | + # into a persistent buffer with pinned memory enabled. |
| 231 | + # Note: It's important that the writer persists in between checkpointing requests, since it maintains the |
| 232 | + # pinned memory buffer. |
| 233 | + writer = StorageWriter(cached_state_dict=True) |
| 234 | + checkpoint_future = None |
| 235 | + for step in range(10): |
| 236 | + optimizer.zero_grad() |
| 237 | + model(torch.rand(8, 16, device="cuda")).sum().backward() |
| 238 | + optimizer.step() |
| 239 | +
|
| 240 | + state_dict = { "app": AppState(model, optimizer) } |
| 241 | + if checkpoint_future is not None: |
| 242 | + # waits for checkpointing to finish, avoiding queuing more then one checkpoint request at a time |
| 243 | + checkpoint_future.result() |
| 244 | + dcp.async_save(state_dict, storage_writer=writer, checkpoint_id=f"{CHECKPOINT_DIR}_step{step}") |
| 245 | +
|
| 246 | + cleanup() |
| 247 | +
|
| 248 | +
|
| 249 | + if __name__ == "__main__": |
| 250 | + world_size = torch.cuda.device_count() |
| 251 | + print(f"Running fsdp checkpoint example on {world_size} devices.") |
| 252 | + mp.spawn( |
| 253 | + run_fsdp_checkpoint_save_example, |
| 254 | + args=(world_size,), |
| 255 | + nprocs=world_size, |
| 256 | + join=True, |
| 257 | + ) |
0 commit comments