diff --git a/intermediate_source/FSDP_tutorial.rst b/intermediate_source/FSDP_tutorial.rst
index ca54000218b..b0cdc19be23 100644
--- a/intermediate_source/FSDP_tutorial.rst
+++ b/intermediate_source/FSDP_tutorial.rst
@@ -240,7 +240,7 @@ We showcase how to convert a full state dict into a DTensor state dict for loadi
 * For the 1st time, it creates checkpoints for the model and optimizer
 * For the 2nd time, it loads from the previous checkpoint to resume training
 
-**Loading state dicts**: We initialize the model under meta device and call ``fully_shard`` to convert ``model.parameters()`` from plain ``torch.Tensor`` to DTensor. After reading the full state dict from torch.load, we can call `distributed_tensor `_ to convert plain ``torch.Tensor`` into DTensor, using the same placements and device mesh from ``model.state_dict()``. Finally we can call `model.load_state_dict `_ to load DTensor state dicts into the model.
+**Loading state dicts**: We initialize the model under meta device and call ``fully_shard`` to convert ``model.parameters()`` from plain ``torch.Tensor`` to DTensor. After reading the full state dict with ``torch.load``, we can call `distribute_tensor <https://pytorch.org/docs/stable/distributed.tensor.html#torch.distributed.tensor.distribute_tensor>`_ to convert each plain ``torch.Tensor`` into a DTensor, using the same placements and device mesh from ``model.state_dict()``. Finally, we can call `model.load_state_dict <https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.load_state_dict>`_ to load the DTensor state dict into the model.
 
 .. code-block:: python
 
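
Reviewer note: for anyone reading this hunk without the surrounding tutorial, the loading flow the amended paragraph describes looks roughly like the sketch below. It is a minimal illustration assuming a ``torchrun``-launched NCCL process group; the ``Net`` model and the ``checkpoint.pt`` path are hypothetical stand-ins, not names from the tutorial.

.. code-block:: python

    import os

    import torch
    import torch.distributed as dist
    import torch.nn as nn
    from torch.distributed.fsdp import fully_shard
    from torch.distributed.tensor import distribute_tensor


    class Net(nn.Module):  # hypothetical toy model for illustration
        def __init__(self):
            super().__init__()
            self.proj = nn.Linear(16, 16)

        def forward(self, x):
            return self.proj(x)


    # Assumes launch via torchrun, which sets LOCAL_RANK and the rendezvous env vars.
    dist.init_process_group(backend="nccl")
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

    # Initialize under meta device, then shard: parameters become DTensors whose
    # placements and device mesh are visible through model.state_dict().
    with torch.device("meta"):
        model = Net()
    fully_shard(model)

    # Read the full (unsharded) state dict previously saved with torch.save.
    full_sd = torch.load("checkpoint.pt", map_location="cpu", weights_only=True)

    # Convert each plain torch.Tensor into a DTensor that matches the sharded
    # parameter's placements and device mesh.
    meta_sd = model.state_dict()
    sharded_sd = {
        name: nn.Parameter(
            distribute_tensor(
                full_tensor,
                meta_sd[name].device_mesh,
                meta_sd[name].placements,
            )
        )
        for name, full_tensor in full_sd.items()
    }
    model.load_state_dict(sharded_sd, assign=True)

Passing ``assign=True`` is the key design point: the sharded parameters start on the meta device with no storage to copy into, so ``load_state_dict`` must adopt the new DTensors outright rather than copy values in place.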