intermediate_source/FSDP_tutorial.rst (2 changes: 1 addition & 1 deletion)
@@ -73,7 +73,7 @@ Model Initialization
# )

We can inspect the nested wrapping with ``print(model)``. ``FSDPTransformer`` is a joint class of `Transformer <https://github.com/pytorch/examples/blob/70922969e70218458d2a945bf86fd8cc967fc6ea/distributed/FSDP2/model.py#L100>`_ and `FSDPModule
-<​https://docs.pytorch.org/docs/main/distributed.fsdp.fully_shard.html#torch.distributed.fsdp.FSDPModule>`_. The same thing happens to `FSDPTransformerBlock <https://github.com/pytorch/examples/blob/70922969e70218458d2a945bf86fd8cc967fc6ea/distributed/FSDP2/model.py#L76C7-L76C18>`_. All FSDP2 public APIs are exposed through ``FSDPModule``. For example, users can call ``model.unshard()`` to manually control all-gather schedules. See "explicit prefetching" below for details.
+<https://docs.pytorch.org/docs/main/distributed.fsdp.fully_shard.html#torch.distributed.fsdp.FSDPModule>`_. The same thing happens to `FSDPTransformerBlock <https://github.com/pytorch/examples/blob/70922969e70218458d2a945bf86fd8cc967fc6ea/distributed/FSDP2/model.py#L76C7-L76C18>`_. All FSDP2 public APIs are exposed through ``FSDPModule``. For example, users can call ``model.unshard()`` to manually control all-gather schedules. See "explicit prefetching" below for details.
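
As an illustration of the ``FSDPModule`` surface described above, here is a minimal sketch (not part of this diff): it assumes a multi-GPU launch via ``torchrun`` with the NCCL backend, and substitutes a plain ``nn.Linear`` for the tutorial's Transformer.

.. code-block:: python

    # Sketch: run with ``torchrun --nproc_per_node=2``; assumes CUDA GPUs.
    import os
    import torch
    import torch.nn as nn
    from torch.distributed.fsdp import FSDPModule, fully_shard

    torch.distributed.init_process_group(backend="nccl")
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

    model = nn.Linear(16, 16, device="cuda")
    fully_shard(model)

    # ``fully_shard`` swaps the class in place, so ``model`` now also
    # inherits from ``FSDPModule`` and exposes the FSDP2 public APIs.
    assert isinstance(model, FSDPModule)

    model.unshard()  # manually all-gather the sharded parameters
    model.reshard()  # free the unsharded parameters again

    torch.distributed.destroy_process_group()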

**model.parameters() as DTensor**: ``fully_shard`` shards parameters across ranks and converts ``model.parameters()`` from plain ``torch.Tensor`` to ``DTensor`` to represent sharded parameters. FSDP2 shards on dim-0 by default, so the DTensor placements are ``Shard(dim=0)``. Say we have N ranks and a parameter with N rows before sharding; after sharding, each rank holds 1 row of the parameter. We can inspect sharded parameters using ``param.to_local()``.
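
To make the dim-0 sharding concrete, here is a small illustrative sketch (again not part of this diff) that prints each parameter's global and local shapes; it assumes 2 ranks launched via ``torchrun``, so an 8-row weight leaves 4 rows per rank.

.. code-block:: python

    # Sketch: run with ``torchrun --nproc_per_node=2``; assumes CUDA GPUs.
    import os
    import torch
    import torch.nn as nn
    from torch.distributed.fsdp import fully_shard
    from torch.distributed.tensor import DTensor, Shard

    torch.distributed.init_process_group(backend="nccl")
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

    model = nn.Linear(4, 8, device="cuda")  # weight: 8 rows before sharding
    fully_shard(model)

    for name, param in model.named_parameters():
        assert isinstance(param, DTensor)       # parameters are now DTensors
        assert param.placements == (Shard(0),)  # sharded on dim-0
        local = param.to_local()                # this rank's shard
        # e.g. weight: global (8, 4) -> local (4, 4) on each of 2 ranks
        print(name, tuple(param.shape), "->", tuple(local.shape))

    torch.distributed.destroy_process_group()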
