
Commit c56b686

Address tutorial comments3
1 parent b55441a commit c56b686

File tree

2 files changed (+25, -26 lines)


en-wordlist.txt

Lines changed: 2 additions & 1 deletion
@@ -646,4 +646,5 @@ torchmetrics
 url
 colab
 sharders
-Criteo
+Criteo
+torchrec

intermediate_source/torchrec_interactive_tutorial.py

Lines changed: 23 additions & 25 deletions
@@ -27,7 +27,7 @@

 ###############################################
 # Install Dependencies
-# ================
+# ^^^^^^^^^^^^^^^^^^^^
 #
 # Before running this tutorial in Google Colab or other environment, install the
 # following dependencies:
@@ -106,11 +106,11 @@
 # Embeddings in PyTorch
 # ---------------------
 #
-# ```torch.nn.Embedding`` <https://pytorch.org/docs/stable/generated/torch.nn.Embedding.html>`__:
+# `torch.nn.Embedding <https://pytorch.org/docs/stable/generated/torch.nn.Embedding.html>`__:
 # Embedding table where forward pass returns the embeddings themselves as
 # is.
 #
-# ```torch.nn.EmbeddingBag`` <https://pytorch.org/docs/stable/generated/torch.nn.EmbeddingBag.html>`__:
+# `torch.nn.EmbeddingBag <https://pytorch.org/docs/stable/generated/torch.nn.EmbeddingBag.html>`__:
 # Embedding table where forward pass returns embeddings that are then
 # pooled, for example, sum or mean, otherwise known as **Pooled Embeddings**
 #
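A quick illustration (editorial aside, not part of this commit) of the distinction the hunk above describes: ``nn.Embedding`` returns one vector per id, while ``nn.EmbeddingBag`` pools each bag of ids into a single vector. The table sizes below are arbitrary.

import torch

ids = torch.tensor([[1, 2, 3]])  # one bag/batch containing three ids

embedding = torch.nn.Embedding(num_embeddings=10, embedding_dim=4)
embedding_bag = torch.nn.EmbeddingBag(num_embeddings=10, embedding_dim=4, mode="sum")

print(embedding(ids).shape)      # torch.Size([1, 3, 4]) -- one embedding per id
print(embedding_bag(ids).shape)  # torch.Size([1, 4])    -- pooled ("sum") per bag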
@@ -173,7 +173,7 @@

 ######################################################################
 # TorchRec
-# ========
+# ^^^^^^^^
 #
 # Now you know how to use embedding tables, one of the foundations of
 # modern recommendation systems! These tables represent entities and
@@ -218,18 +218,18 @@

 ######################################################################
 # From ``EmbeddingBag`` to ``EmbeddingBagCollection``
-# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 #
 # We have already explored
-# ```torch.nn.Embedding`` <https://pytorch.org/docs/stable/generated/torch.nn.Embedding.html>`__
+# `torch.nn.Embedding <https://pytorch.org/docs/stable/generated/torch.nn.Embedding.html>`__
 # and
-# ```torch.nn.EmbeddingBag`` <https://pytorch.org/docs/stable/generated/torch.nn.EmbeddingBag.html>`__.
+# `torch.nn.EmbeddingBag <https://pytorch.org/docs/stable/generated/torch.nn.EmbeddingBag.html>`__.
 #
 # TorchRec extends these modules by creating collections of embeddings, in
 # other words modules that can have multiple embedding tables, with
-# ```EmbeddingCollection`` <https://pytorch.org/torchrec/torchrec.modules.html#torchrec.modules.embedding_modules.EmbeddingCollection>`__
+# `EmbeddingCollection <https://pytorch.org/torchrec/torchrec.modules.html#torchrec.modules.embedding_modules.EmbeddingCollection>`__
 # and
-# ```EmbeddingBagCollection`` <https://pytorch.org/torchrec/torchrec.modules.html#torchrec.modules.embedding_modules.EmbeddingBagCollection>`__.
+# `EmbeddingBagCollection <https://pytorch.org/torchrec/torchrec.modules.html#torchrec.modules.embedding_modules.EmbeddingBagCollection>`__.
 # We will use ``EmbeddingBagCollection`` to represent a group of
 # embedding bags.
 #
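For context, a minimal sketch (not part of the diff) of how an ``EmbeddingBagCollection`` with one pooled table per feature can be configured. The table names, dimensions, and row counts here are made up for illustration; the tutorial's own ``ebc`` may differ.

import torch
from torchrec.modules.embedding_configs import EmbeddingBagConfig
from torchrec.modules.embedding_modules import EmbeddingBagCollection

ebc = EmbeddingBagCollection(
    tables=[
        EmbeddingBagConfig(
            name="product_table",
            embedding_dim=64,
            num_embeddings=4096,
            feature_names=["product"],
        ),
        EmbeddingBagConfig(
            name="user_table",
            embedding_dim=64,
            num_embeddings=4096,
            feature_names=["user"],
        ),
    ],
    device=torch.device("meta"),  # "meta" device defers allocating real memory
)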
@@ -346,7 +346,7 @@
 )
 user_jt = JaggedTensor(values=torch.tensor([2, 3, 4, 1]), lengths=torch.tensor([2, 2]))

-# Q1: How many batches are there, and which values are in the first batch for product_jt and user_jt?
+# Q1: How many batches are there, and which values are in the first batch for ``product_jt`` and ``user_jt``?
 kjt = KeyedJaggedTensor.from_jt_dict({"product": product_jt, "user": user_jt})

 # Look at our feature keys for the ``KeyedJaggedTensor``
@@ -366,7 +366,7 @@

 # Q2: What are the offsets for the ``KeyedJaggedTensor``?

-# Now we can run a forward pass on our ``EmbeddingBagCollection``` from before
+# Now we can run a forward pass on our ``EmbeddingBagCollection`` from before
 result = ebc(kjt)
 result

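A hedged note on Q2 in the hunk above: for a ``KeyedJaggedTensor``, the offsets are the cumulative sum of the lengths with a leading zero, so they can be checked by hand against ``kjt.offsets()``. A sketch, assuming ``kjt`` from the tutorial:

lengths = kjt.lengths()
manual_offsets = torch.cat([torch.zeros(1, dtype=lengths.dtype), lengths.cumsum(dim=0)])
print(manual_offsets)
print(kjt.offsets())  # should match: values for bag i live in values[offsets[i] : offsets[i + 1]]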
@@ -488,12 +488,12 @@
 # 1. **The module sharder**: This class exposes a ``shard`` API
 # that handles sharding a TorchRec Module, producing a sharded module.
 # * For ``EmbeddingBagCollection``, the sharder is
-# ```EmbeddingBagCollectionSharder`` <https://pytorch.org/torchrec/torchrec.distributed.html#torchrec.distributed.embeddingbag.EmbeddingBagCollectionSharder>`__
+# `EmbeddingBagCollectionSharder <https://pytorch.org/torchrec/torchrec.distributed.html#torchrec.distributed.embeddingbag.EmbeddingBagCollectionSharder>`__
 # 2. **Sharded module**: This class is a sharded variant of a TorchRec module.
 # It has the same input/output as the regular TorchRec module, but much
 # more optimized and works in a distributed environment.
 # * For ``EmbeddingBagCollection``, the sharded variant is
-# ```ShardedEmbeddingBagCollection`` <https://pytorch.org/torchrec/torchrec.distributed.html#torchrec.distributed.embeddingbag.ShardedEmbeddingBagCollection>`__
+# `ShardedEmbeddingBagCollection <https://pytorch.org/torchrec/torchrec.distributed.html#torchrec.distributed.embeddingbag.ShardedEmbeddingBagCollection>`__
 #
 # Every TorchRec module has an unsharded and sharded variant. \* The
 # unsharded version is meant to be prototyped and experimented with \* The
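An illustrative sketch (not part of this commit) of how the two classes relate. It assumes the tutorial's unsharded ``ebc``, a sharding ``plan`` produced by the planner discussed in the next hunk, and an initialized process group ``pg``; treat the exact ``shard`` call signature as an assumption that may vary across TorchRec versions.

import torch
from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder
from torchrec.distributed.types import ShardingEnv

sharder = EmbeddingBagCollectionSharder()
sharded_ebc = sharder.shard(
    ebc,                                 # the unsharded EmbeddingBagCollection
    plan.plan[""],                       # per-table sharding parameters for this module
    ShardingEnv.from_process_group(pg),  # wraps the distributed environment
    torch.device("cuda"),
)
print(type(sharded_ebc))                 # a ShardedEmbeddingBagCollection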
@@ -590,7 +590,7 @@
 # The result of running the planner is a static plan, which can be reused
 # for sharding! This allows sharding to be static for production models
 # instead of determining a new sharding plan every time. Below, we use the
-# sharding plan to finally generate our ``ShardedEmbeddingBagCollection.``
+# sharding plan to finally generate our ``ShardedEmbeddingBagCollection``.
 #

 # The static plan that was generated
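A minimal sketch of producing such a static plan (illustrative only; a single-GPU topology is assumed, and the exact constructor arguments should be checked against the TorchRec version in use):

from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder
from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology

planner = EmbeddingShardingPlanner(
    topology=Topology(world_size=1, compute_device="cuda"),
)
plan = planner.plan(ebc, [EmbeddingBagCollectionSharder()])  # static, reusable plan
print(plan)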
@@ -616,7 +616,7 @@
 # Remember that TorchRec is a highly optimized library for distributed
 # embeddings. A concept that TorchRec introduces to enable higher
 # performance for training on GPU is a
-# ```LazyAwaitable`` <https://pytorch.org/torchrec/torchrec.distributed.html#torchrec.distributed.types.LazyAwaitable>`__.
+# `LazyAwaitable <https://pytorch.org/torchrec/torchrec.distributed.html#torchrec.distributed.types.LazyAwaitable>`__.
 # You will see ``LazyAwaitable`` types as outputs of various sharded
 # TorchRec modules. All a ``LazyAwaitable`` does is delay calculating some
 # result as long as possible, and it does it by acting like an async type.
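To make the laziness concrete, here is a toy subclass in the spirit of the ``_wait_impl`` override shown later in the tutorial (a sketch, not the tutorial's exact code):

from typing import List

import torch
from torchrec.distributed.types import LazyAwaitable

class ExampleAwaitable(LazyAwaitable[torch.Tensor]):
    def __init__(self, size: List[int]) -> None:
        super().__init__()
        self._size = size

    def _wait_impl(self) -> torch.Tensor:
        # the "expensive" work happens only here, when the result is needed
        return torch.ones(self._size)

awaitable = ExampleAwaitable([3, 2])
print(awaitable.wait())  # computation is deferred until this call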
@@ -741,7 +741,7 @@ def _wait_impl(self) -> torch.Tensor:
 # ``EmbeddingBagCollection`` to generate a
 # ``ShardedEmbeddingBagCollection`` module. This workflow is fine, but
 # typically when doing model parallel,
-# ```DistributedModelParallel`` <https://pytorch.org/torchrec/torchrec.distributed.html#torchrec.distributed.model_parallel.DistributedModelParallel>`__
+# `DistributedModelParallel <https://pytorch.org/torchrec/torchrec.distributed.html#torchrec.distributed.model_parallel.DistributedModelParallel>`__
 # (DMP) is used as the standard interface. When wrapping your model (in
 # our case ``ebc``) with DMP, the following will occur:
 #
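For reference, the wrapping itself is a one-liner once a process group exists. The single-process setup below is purely illustrative (real training launches one process per GPU, for example with torchrun), and the ``MASTER_ADDR``/``MASTER_PORT`` values are placeholders.

import os

import torch
import torch.distributed as dist
from torchrec.distributed.model_parallel import DistributedModelParallel

os.environ.setdefault("MASTER_ADDR", "localhost")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group(backend="nccl", rank=0, world_size=1)

model = DistributedModelParallel(ebc, device=torch.device("cuda"))
# DMP annotates, plans, shards, and places ``ebc``; it can also be given an
# explicit plan= and sharders= when you want full control.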
@@ -826,22 +826,20 @@ def _wait_impl(self) -> torch.Tensor:
 # still need to manage an optimizer for the other parameters not
 # associated with TorchRec embedding modules. To find the other
 # parameters,
-# use\ ``in_backward_optimizer_filter(model.named_parameters())``.
-#
+# use ``in_backward_optimizer_filter(model.named_parameters())``.
 # Apply an optimizer to those parameters as you would a normal Torch
 # optimizer and combine this and the ``model.fused_optimizer`` into one
 # ``CombinedOptimizer`` that you can use in your training loop to
 # ``zero_grad`` and ``step`` through.
 #
 # Let's add an optimizer to our ``EmbeddingBagCollection``
-# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 #
 # We will do this in two ways, which are equivalent, but give you options
-# depending on your preferences: 1. Passing optimizer kwargs through fused
-# parameters (fused\_params) in sharder 2. Through
-# ``apply_optimizer_in_backward`` Note: ``apply_optimizer_in_backward``
-# converts the optimizer parameters to ``fused_params`` to pass to the
-# ``TBE`` in the ``EmbeddingBagCollection``/``EmbeddingCollection``.
+# depending on your preferences:
+# 1. Passing optimizer kwargs through ``fused_params`` in sharder
+# 2. Through ``apply_optimizer_in_backward``, which converts the optimizer
+# parameters to ``fused_params`` to pass to the ``TBE`` in the ``EmbeddingBagCollection`` or ``EmbeddingCollection``.
 #

 # Approach 1: passing optimizer kwargs through fused parameters
@@ -856,7 +854,7 @@ def _wait_impl(self) -> torch.Tensor:
 "eps": 0.002,
 }

-# Initialize sharder with fused_params
+# Initialize sharder with ``fused_params``
 sharder_with_fused_params = EmbeddingBagCollectionSharder(fused_params=fused_params)

 # We'll use the same plan and unsharded EBC as before, but this time with our new sharder
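For completeness, a hedged sketch of approach 2. ``apply_optimizer_in_backward`` lives in ``torch.distributed.optim``; the optimizer class and hyperparameters below are placeholders rather than the tutorial's choices.

import torch
from torch.distributed.optim import apply_optimizer_in_backward

# Fuses the optimizer step for these parameters into the backward pass, so no
# separate optimizer.step() is needed for them.
apply_optimizer_in_backward(torch.optim.SGD, ebc.parameters(), {"lr": 0.02})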

0 commit comments