Skip to content

Commit 8b42693

Browse files
authored
Update intermediate_source/FSDP_tutorial.rst
1 parent 4af6635 commit 8b42693

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

intermediate_source/FSDP_tutorial.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ Model Initialization
7373
# )
7474
7575
We can inspect the nested wrapping with ``print(model)``. ``FSDPTransformer`` is a joint class of `Transformer <https://github.com/pytorch/examples/blob/70922969e70218458d2a945bf86fd8cc967fc6ea/distributed/FSDP2/model.py#L100>`_ and `FSDPModule
76-
https://docs.pytorch.org/docs/main/distributed.fsdp.fully_shard.html#torch.distributed.fsdp.FSDPModule>`_. The same thing happens to `FSDPTransformerBlock <https://github.com/pytorch/examples/blob/70922969e70218458d2a945bf86fd8cc967fc6ea/distributed/FSDP2/model.py#L76C7-L76C18>`_. All FSDP2 public APIs are exposed through ``FSDPModule``. For example, users can call ``model.unshard()`` to manually control all-gather schedules. See "explicit prefetching" below for details.
76+
<https://docs.pytorch.org/docs/main/distributed.fsdp.fully_shard.html#torch.distributed.fsdp.FSDPModule>`_. The same thing happens to `FSDPTransformerBlock <https://github.com/pytorch/examples/blob/70922969e70218458d2a945bf86fd8cc967fc6ea/distributed/FSDP2/model.py#L76C7-L76C18>`_. All FSDP2 public APIs are exposed through ``FSDPModule``. For example, users can call ``model.unshard()`` to manually control all-gather schedules. See "explicit prefetching" below for details.
7777

7878
**model.parameters() as DTensor**: ``fully_shard`` shards parameters across ranks, and convert ``model.parameters()`` from plain ``torch.Tensor`` to DTensor to represent sharded parameters. FSDP2 shards on dim-0 by default so DTensor placements are `Shard(dim=0)`. Say we have N ranks and a parameter with N rows before sharding. After sharding, each rank will have 1 row of the parameter. We can inspect sharded parameters using ``param.to_local()``.
7979

0 commit comments

Comments
 (0)