###############################################
# Install Dependencies
- # ================
+ # ^^^^^^^^^^^^^^^^^^^^
#
# Before running this tutorial in Google Colab or other environment, install the
# following dependencies:
# Embeddings in PyTorch
# ---------------------
#
- # ``` torch.nn.Embedding`` <https://pytorch.org/docs/stable/generated/torch.nn.Embedding.html>`__:
+ # `torch.nn.Embedding <https://pytorch.org/docs/stable/generated/torch.nn.Embedding.html>`__:
# Embedding table where forward pass returns the embeddings themselves as
# is.
#
- # ``` torch.nn.EmbeddingBag`` <https://pytorch.org/docs/stable/generated/torch.nn.EmbeddingBag.html>`__:
+ # `torch.nn.EmbeddingBag <https://pytorch.org/docs/stable/generated/torch.nn.EmbeddingBag.html>`__:
# Embedding table where forward pass returns embeddings that are then
# pooled, for example, sum or mean, otherwise known as **Pooled Embeddings**
#
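# To make the distinction concrete, here is a minimal sketch (the table size,
# embedding dimension, and ``sum`` pooling mode are arbitrary choices for
# illustration):

import torch

num_embeddings, embedding_dim = 10, 4

# ``nn.Embedding``: one embedding row per index, returned as is
embedding = torch.nn.Embedding(num_embeddings, embedding_dim)
print(embedding(torch.tensor([1, 2, 3])).shape)        # torch.Size([3, 4])

# ``nn.EmbeddingBag``: looks up the same rows, then pools them per bag
embedding_bag = torch.nn.EmbeddingBag(num_embeddings, embedding_dim, mode="sum")
print(embedding_bag(torch.tensor([[1, 2, 3]])).shape)  # torch.Size([1, 4])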

######################################################################
# TorchRec
- # ========
+ # ^^^^^^^^
#
# Now you know how to use embedding tables, one of the foundations of
# modern recommendation systems! These tables represent entities and

######################################################################
# From ``EmbeddingBag`` to ``EmbeddingBagCollection``
- # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# We have already explored
- # ``` torch.nn.Embedding`` <https://pytorch.org/docs/stable/generated/torch.nn.Embedding.html>`__
+ # `torch.nn.Embedding <https://pytorch.org/docs/stable/generated/torch.nn.Embedding.html>`__
# and
- # ``` torch.nn.EmbeddingBag`` <https://pytorch.org/docs/stable/generated/torch.nn.EmbeddingBag.html>`__.
+ # `torch.nn.EmbeddingBag <https://pytorch.org/docs/stable/generated/torch.nn.EmbeddingBag.html>`__.
#
# TorchRec extends these modules by creating collections of embeddings, in
# other words modules that can have multiple embedding tables, with
- # ``` EmbeddingCollection`` <https://pytorch.org/torchrec/torchrec.modules.html#torchrec.modules.embedding_modules.EmbeddingCollection>`__
+ # `EmbeddingCollection <https://pytorch.org/torchrec/torchrec.modules.html#torchrec.modules.embedding_modules.EmbeddingCollection>`__
# and
- # ``` EmbeddingBagCollection`` <https://pytorch.org/torchrec/torchrec.modules.html#torchrec.modules.embedding_modules.EmbeddingBagCollection>`__.
+ # `EmbeddingBagCollection <https://pytorch.org/torchrec/torchrec.modules.html#torchrec.modules.embedding_modules.EmbeddingBagCollection>`__.
# We will use ``EmbeddingBagCollection`` to represent a group of
# embedding bags.
#
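# As a rough sketch of what constructing one looks like (a hedged example:
# the table sizes and dimensions below are arbitrary, and the feature names
# simply mirror the ``product`` and ``user`` features used later):
#
# .. code-block:: python
#
#    import torch
#    from torchrec import EmbeddingBagCollection, EmbeddingBagConfig
#
#    ebc_sketch = EmbeddingBagCollection(
#        tables=[
#            EmbeddingBagConfig(
#                name="product_table",
#                embedding_dim=16,
#                num_embeddings=4096,
#                feature_names=["product"],
#            ),
#            EmbeddingBagConfig(
#                name="user_table",
#                embedding_dim=16,
#                num_embeddings=4096,
#                feature_names=["user"],
#            ),
#        ],
#        device=torch.device("cpu"),
#    )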

)
user_jt = JaggedTensor(values=torch.tensor([2, 3, 4, 1]), lengths=torch.tensor([2, 2]))

- # Q1: How many batches are there, and which values are in the first batch for product_jt and user_jt?
+ # Q1: How many batches are there, and which values are in the first batch for ``product_jt`` and ``user_jt``?
kjt = KeyedJaggedTensor.from_jt_dict({"product": product_jt, "user": user_jt})

# Look at our feature keys for the ``KeyedJaggedTensor``

# Q2: What are the offsets for the ``KeyedJaggedTensor``?

- # Now we can run a forward pass on our ``EmbeddingBagCollection``` from before
+ # Now we can run a forward pass on our ``EmbeddingBagCollection`` from before
result = ebc(kjt)
result
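# The result is a ``KeyedTensor`` of pooled embeddings; as a quick,
# illustrative way to inspect it (the keys follow the feature names above):
print(result.keys())                      # e.g. ['product', 'user']
print(result.to_dict()["product"].shape)  # (batch_size, embedding_dim) for that feature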

# 1. **The module sharder**: This class exposes a ``shard`` API
# that handles sharding a TorchRec Module, producing a sharded module.
# * For ``EmbeddingBagCollection``, the sharder is
- # ``` EmbeddingBagCollectionSharder`` <https://pytorch.org/torchrec/torchrec.distributed.html#torchrec.distributed.embeddingbag.EmbeddingBagCollectionSharder>`__
+ # `EmbeddingBagCollectionSharder <https://pytorch.org/torchrec/torchrec.distributed.html#torchrec.distributed.embeddingbag.EmbeddingBagCollectionSharder>`__
# 2. **Sharded module**: This class is a sharded variant of a TorchRec module.
# It has the same input/output as the regular TorchRec module, but is much
# more optimized and works in a distributed environment.
# * For ``EmbeddingBagCollection``, the sharded variant is
- # ``` ShardedEmbeddingBagCollection`` <https://pytorch.org/torchrec/torchrec.distributed.html#torchrec.distributed.embeddingbag.ShardedEmbeddingBagCollection>`__
+ # `ShardedEmbeddingBagCollection <https://pytorch.org/torchrec/torchrec.distributed.html#torchrec.distributed.embeddingbag.ShardedEmbeddingBagCollection>`__
#
# Every TorchRec module has an unsharded and sharded variant. \* The
# unsharded version is meant to be prototyped and experimented with \* The
# The result of running the planner is a static plan, which can be reused
# for sharding! This allows sharding to be static for production models
# instead of determining a new sharding plan every time. Below, we use the
- # sharding plan to finally generate our ``ShardedEmbeddingBagCollection.``
+ # sharding plan to finally generate our ``ShardedEmbeddingBagCollection``.
#
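# For reference, a plan like the one above is typically produced along these
# lines (a hedged sketch, not the exact code from this tutorial; the topology
# values and the ``sharder`` variable are placeholders):
#
# .. code-block:: python
#
#    from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology
#
#    planner = EmbeddingShardingPlanner(
#        topology=Topology(world_size=2, compute_device="cuda"),
#    )
#    plan = planner.plan(ebc, [sharder])  # static ShardingPlan, reusable across runs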

# The static plan that was generated
# Remember that TorchRec is a highly optimized library for distributed
# embeddings. A concept that TorchRec introduces to enable higher
# performance for training on GPU is a
- # ``` LazyAwaitable`` <https://pytorch.org/torchrec/torchrec.distributed.html#torchrec.distributed.types.LazyAwaitable>`__.
+ # `LazyAwaitable <https://pytorch.org/torchrec/torchrec.distributed.html#torchrec.distributed.types.LazyAwaitable>`__.
# You will see ``LazyAwaitable`` types as outputs of various sharded
# TorchRec modules. All a ``LazyAwaitable`` does is delay calculating some
# result as long as possible, and it does it by acting like an async type.
@@ -741,7 +741,7 @@ def _wait_impl(self) -> torch.Tensor:
# ``EmbeddingBagCollection`` to generate a
# ``ShardedEmbeddingBagCollection`` module. This workflow is fine, but
# typically when doing model parallel,
- # ``` DistributedModelParallel`` <https://pytorch.org/torchrec/torchrec.distributed.html#torchrec.distributed.model_parallel.DistributedModelParallel>`__
+ # `DistributedModelParallel <https://pytorch.org/torchrec/torchrec.distributed.html#torchrec.distributed.model_parallel.DistributedModelParallel>`__
# (DMP) is used as the standard interface. When wrapping your model (in
# our case ``ebc``) with DMP, the following will occur:
#
@@ -826,22 +826,20 @@ def _wait_impl(self) -> torch.Tensor:
# still need to manage an optimizer for the other parameters not
# associated with TorchRec embedding modules. To find the other
# parameters,
- # use\ ``in_backward_optimizer_filter(model.named_parameters())``.
- #
+ # use ``in_backward_optimizer_filter(model.named_parameters())``.
# Apply an optimizer to those parameters as you would a normal Torch
# optimizer and combine this and the ``model.fused_optimizer`` into one
# ``CombinedOptimizer`` that you can use in your training loop to
# ``zero_grad`` and ``step`` through.
#
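# As a rough sketch of that pattern (assuming ``model`` is the DMP-wrapped
# module from above; the import paths, ``KeyedOptimizerWrapper``, and the
# ``Adam`` settings here are illustrative assumptions rather than this
# tutorial's exact code):
#
# .. code-block:: python
#
#    import torch
#    from torchrec.optim.keyed import CombinedOptimizer, KeyedOptimizerWrapper
#    from torchrec.optim.optimizers import in_backward_optimizer_filter
#
#    # a regular optimizer for the parameters not handled by the fused optimizer
#    dense_optimizer = KeyedOptimizerWrapper(
#        dict(in_backward_optimizer_filter(model.named_parameters())),
#        lambda params: torch.optim.Adam(params, lr=0.01),
#    )
#
#    # a single object whose zero_grad()/step() covers both sets of parameters
#    optimizer = CombinedOptimizer([model.fused_optimizer, dense_optimizer])
#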
# Let's add an optimizer to our ``EmbeddingBagCollection``
- # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+ # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
#
# We will do this in two ways, which are equivalent, but give you options
- # depending on your preferences: 1. Passing optimizer kwargs through fused
- # parameters (fused\_params) in sharder 2. Through
- # ``apply_optimizer_in_backward`` Note: ``apply_optimizer_in_backward``
- # converts the optimizer parameters to ``fused_params`` to pass to the
- # ``TBE`` in the ``EmbeddingBagCollection``/``EmbeddingCollection``.
+ # depending on your preferences:
+ # 1. Passing optimizer kwargs through ``fused_params`` in the sharder
+ # 2. Through ``apply_optimizer_in_backward``, which converts the optimizer
+ # parameters to ``fused_params`` to pass to the ``TBE`` in the ``EmbeddingBagCollection`` or ``EmbeddingCollection``.
#
# Approach 1: passing optimizer kwargs through fused parameters
@@ -856,7 +854,7 @@ def _wait_impl(self) -> torch.Tensor:
    "eps": 0.002,
}

- # Initialize sharder with fused_params
+ # Initialize sharder with ``fused_params``
sharder_with_fused_params = EmbeddingBagCollectionSharder(fused_params=fused_params)

# We'll use the same plan and unsharded EBC as before, but this time with our new sharder