|
200 | 200 | # In order to train models with massive embedding tables, sharding these
|
201 | 201 | # tables across GPUs is required, which then introduces a whole new set of
|
202 | 202 | # problems and opportunities in parallelism and optimization. Luckily, we have
|
203 |
| -# the TorchRec library that has encountered, consolidated, and addressed |
| 203 | +# the `TorchRec library <https://docs.pytorch.org/torchrec/overview.html>`__ that has encountered, consolidated, and addressed |
204 | 204 | # many of these concerns. TorchRec serves as a **library that provides
|
205 | 205 | # primitives for large scale distributed embeddings**.
|
206 | 206 | #
|
|
496 | 496 | #
|
497 | 497 | # * **The module sharder**: This class exposes a ``shard`` API
|
498 | 498 | # that handles sharding a TorchRec Module, producing a sharded module.
|
499 |
| -# * For ``EmbeddingBagCollection``, the sharder is `EmbeddingBagCollectionSharder <https://pytorch.org/torchrec/torchrec.distributed.html#torchrec.distributed.embeddingbag.EmbeddingBagCollectionSharder>`__ |
| 499 | +# * For ``EmbeddingBagCollection``, the sharder is ``EmbeddingBagCollectionSharder`` |
500 | 500 | # * **Sharded module**: This class is a sharded variant of a TorchRec module.
|
501 | 501 | # It has the same input/output as the regular TorchRec module, but is much
|
502 | 502 | # more optimized and works in a distributed environment.
|
503 |
| -# * For ``EmbeddingBagCollection``, the sharded variant is `ShardedEmbeddingBagCollection <https://pytorch.org/torchrec/torchrec.distributed.html#torchrec.distributed.embeddingbag.ShardedEmbeddingBagCollection>`__ |
| 503 | +# * For ``EmbeddingBagCollection``, the sharded variant is ``ShardedEmbeddingBagCollection`` |
504 | 504 | #
|
505 | 505 | # Every TorchRec module has an unsharded and sharded variant.
|
506 | 506 | #
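# For orientation, a minimal sketch of driving the sharder by hand (this
# assumes a default process group is already initialized and that ``ebc`` is
# an unsharded ``EmbeddingBagCollection``; the planner options shown are
# illustrative, not prescriptive):

import torch
import torch.distributed as dist
from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder
from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology
from torchrec.distributed.types import ShardingEnv

pg = dist.GroupMember.WORLD
sharder = EmbeddingBagCollectionSharder()

# The planner decides how each embedding table is split across ranks.
planner = EmbeddingShardingPlanner(
    topology=Topology(world_size=dist.get_world_size(), compute_device="cuda"),
)
plan = planner.collective_plan(ebc, [sharder], pg)

# ``shard`` consumes the plan for this module and returns the sharded variant.
sharded_ebc = sharder.shard(
    ebc,
    plan.plan[""],  # sharding decisions for the top-level module
    ShardingEnv.from_process_group(pg),
    torch.device("cuda"),
)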
|
|
619 | 619 | # Remember that TorchRec is a highly optimized library for distributed
|
620 | 620 | # embeddings. A concept that TorchRec introduces to enable higher
|
621 | 621 | # performance for training on GPU is a
|
622 |
| -# `LazyAwaitable <https://pytorch.org/torchrec/torchrec.distributed.html#torchrec.distributed.types.LazyAwaitable>`__. |
| 622 | +# ``LazyAwaitable``. |
623 | 623 | # You will see ``LazyAwaitable`` types as outputs of various sharded
|
624 | 624 | # TorchRec modules. All a ``LazyAwaitable`` type does is delay calculating some
|
625 | 625 | # result as long as possible, which it does by acting like an async type.
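# To make the pattern concrete, a minimal sketch of a custom ``LazyAwaitable``
# (``DemoAwaitable`` and the tensor it returns are made up for illustration;
# only ``LazyAwaitable``, ``_wait_impl``, and ``wait`` come from TorchRec):

from typing import List

import torch
from torchrec.distributed.types import LazyAwaitable


class DemoAwaitable(LazyAwaitable[torch.Tensor]):
    def __init__(self, size: List[int]) -> None:
        super().__init__()
        self._size = size

    def _wait_impl(self) -> torch.Tensor:
        # Stand-in for an expensive result (for example, a collective) whose
        # materialization we want to delay until it is actually needed.
        return torch.ones(self._size)


awaitable = DemoAwaitable([3, 2])
result = awaitable.wait()  # the tensor is only computed here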
|
@@ -693,7 +693,7 @@ def _wait_impl(self) -> torch.Tensor:
|
693 | 693 | # order for distribution of gradients. ``input_dist``, ``lookup``, and
|
694 | 694 | # ``output_dist`` all depend on the sharding scheme. Since we sharded in a
|
695 | 695 | # table-wise fashion, these APIs are modules that are constructed by
|
696 |
| -# `TwPooledEmbeddingSharding <https://pytorch.org/torchrec/torchrec.distributed.sharding.html#torchrec.distributed.sharding.tw_sharding.TwPooledEmbeddingSharding>`__. |
| 696 | +# ``TwPooledEmbeddingSharding``. |
697 | 697 | #
|
698 | 698 |
|
699 | 699 | sharded_ebc
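# As a quick sanity check, one way to confirm the table-wise setup is to walk
# the sharded module tree with the standard ``nn.Module.named_modules`` API
# (the exact submodule names printed depend on the TorchRec version):

for name, submodule in sharded_ebc.named_modules():
    print(name, type(submodule).__name__)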
|
@@ -742,7 +742,7 @@ def _wait_impl(self) -> torch.Tensor:
|
742 | 742 | # ``EmbeddingBagCollection`` to generate a
|
743 | 743 | # ``ShardedEmbeddingBagCollection`` module. This workflow is fine, but
|
744 | 744 | # typically when implementing model parallel,
|
745 |
| -# `DistributedModelParallel <https://pytorch.org/torchrec/torchrec.distributed.html#torchrec.distributed.model_parallel.DistributedModelParallel>`__ |
| 745 | +# ``DistributedModelParallel`` |
746 | 746 | # (DMP) is used as the standard interface. When wrapping your model (in
|
747 | 747 | # our case ``ebc``) with DMP, the following will occur:
|
748 | 748 | #
|
|