diff --git a/recipes_source/distributed_checkpoint_recipe.rst b/recipes_source/distributed_checkpoint_recipe.rst index 8f93c2222d6..374b5af2b7b 100644 --- a/recipes_source/distributed_checkpoint_recipe.rst +++ b/recipes_source/distributed_checkpoint_recipe.rst @@ -82,7 +82,7 @@ Now, let's create a toy module, wrap it with FSDP, feed it with some dummy input def state_dict(self): # this line automatically manages FSDP FQN's, as well as sets the default state dict type to FSDP.SHARDED_STATE_DICT - model_state_dict, optimizer_state_dict = get_state_dict(model, optimizer) + model_state_dict, optimizer_state_dict = get_state_dict(self.model, self.optimizer) return { "model": model_state_dict, "optim": optimizer_state_dict @@ -178,6 +178,7 @@ The reason that we need the ``state_dict`` prior to loading is: import torch import torch.distributed as dist import torch.distributed.checkpoint as dcp + from torch.distributed.checkpoint.stateful import Stateful from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict import torch.multiprocessing as mp import torch.nn as nn @@ -202,7 +203,7 @@ The reason that we need the ``state_dict`` prior to loading is: def state_dict(self): # this line automatically manages FSDP FQN's, as well as sets the default state dict type to FSDP.SHARDED_STATE_DICT - model_state_dict, optimizer_state_dict = get_state_dict(model, optimizer) + model_state_dict, optimizer_state_dict = get_state_dict(self.model, self.optimizer) return { "model": model_state_dict, "optim": optimizer_state_dict @@ -252,13 +253,6 @@ The reason that we need the ``state_dict`` prior to loading is: optimizer = torch.optim.Adam(model.parameters(), lr=0.1) state_dict = { "app": AppState(model, optimizer)} - optimizer = torch.optim.Adam(model.parameters(), lr=0.1) - # generates the state dict we will load into - model_state_dict, optimizer_state_dict = get_state_dict(model, optimizer) - state_dict = { - "model": model_state_dict, - "optimizer": optimizer_state_dict - } dcp.load( state_dict=state_dict, checkpoint_id=CHECKPOINT_DIR,