Commit ed113b2

Address tutorial comments
1 parent 6ca1922 commit ed113b2

File tree: 2 files changed (+79, -70 lines)


en-wordlist.txt

Lines changed: 0 additions & 1 deletion
@@ -620,7 +620,6 @@ webp
 wsi
 wsis
 Meta's
-criteo
 RecSys
 TorchRec
 sharding

intermediate_source/torchrec_interactive_tutorial.py

Lines changed: 79 additions & 69 deletions
@@ -8,13 +8,16 @@
 and TorchRec, focusing on handling large embedding tables through distributed training and advanced optimizations.

 .. grid:: 2
+
 .. grid-item-card:: :octicon:`mortar-board;1em;` What you will learn
 :class-card: card-prerequisites
 * Fundamentals of embeddings and their role in recommendation systems
 * How to set up TorchRec to manage and implement embeddings in PyTorch environments
 * Explore advanced techniques for distributing large embedding tables across multiple GPUs
+
 .. grid-item-card:: :octicon:`list-unordered;1em;` Prerequisites
 :class-card: card-prerequisites
+
 * PyTorch v2.5 or later with CUDA 11.8 or later
 * Python 3.9 or later
 * FBGEMM <https://github.com/pytorch/fbgemm>
@@ -30,7 +33,6 @@
 # following dependencies:
 #
 # .. code-block:: sh
-
 # !pip3 install --pre torch --index-url https://download.pytorch.org/whl/cu121 -U
 # !pip3 install fbgemm_gpu --index-url https://download.pytorch.org/whl/cu121
 # !pip3 install torchmetrics==1.0.3
@@ -43,7 +45,7 @@
 # ~~~~~~~~~~
 #
 # When building recommendation systems, categorical features typically
-# have massive cardinality, posts, users, ads, etc.
+# have massive cardinality, posts, users, ads, and so on.
 #
 # In order to represent these entities and model these relationships,
 # **embeddings** are used. In machine learning, **embeddings are a vectors
@@ -67,21 +69,23 @@
 # The inputs to embedding tables represent embedding lookups to retrieve
 # the embedding for a specific index/row. In recommendation systems, such
 # as those used in Meta, unique IDs are not only used for specific users,
-# but also across entites like posts and ads to serve as lookup indices to
+# but also across entities like posts and ads to serve as lookup indices to
 # respective embedding tables!
 #
-# Embeddings are trained in RecSys through the following process: 1.
-# **Input/lookup indices are fed into the model, as unique IDs**. IDs are
+# Embeddings are trained in RecSys through the following process:
+# 1. **Input/lookup indices are fed into the model, as unique IDs**. IDs are
 # hashed to the total size of the embedding table to prevent issues when
-# the ID > # of rows 2. Embeddings are then retrieved and **pooled, such
-# as taking the sum or mean of the embeddings**. This is required as there
-# can be a variable # of embeddings per example while the model expects
-# consistent shapes. 3. The **embeddings are used in conjunction with the
-# rest of the model to produce a prediction**, such as `Click-Through Rate
+# the ID > # of rows
+# 2. Embeddings are then retrieved and **pooled, such as taking the sum or
+# mean of the embeddings**. This is required as there can be a variable # of
+# embeddings per example while the model expects consistent shapes.
+# 3. The **embeddings are used in conjunction with the rest of the model to
+# produce a prediction**, such as `Click-Through Rate
 # (CTR) <https://support.google.com/google-ads/answer/2615875?hl=en>`__
-# for an Ad. 4. The loss is calculated with the prediction and the label
+# for an Ad.
+# 4. The loss is calculated with the prediction and the label
 # for an example, and **all weights of the model are updated through
-# gradient descent and backpropogation, including the embedding weights**
+# gradient descent and backpropagation, including the embedding weights**
 # that were associated with the example.
 #
 # These embeddings are crucial for representing categorical features, such
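As an aside on the four training steps this hunk renumbers, the loop can be sketched in a few lines of plain PyTorch. The table size, sum pooling, and BCE loss below are illustrative choices, not part of the commit:

import torch

num_embeddings, embedding_dim = 1000, 16
table = torch.nn.EmbeddingBag(num_embeddings, embedding_dim, mode="sum")
predictor = torch.nn.Linear(embedding_dim, 1)
params = list(table.parameters()) + list(predictor.parameters())
optimizer = torch.optim.SGD(params, lr=0.01)

# 1. Raw IDs are hashed into the table's range so that ID > number of rows cannot occur
raw_ids = torch.tensor([123456789, 42, 999999999])
ids = raw_ids % num_embeddings

# 2. Lookup + pooling: one example with a variable number of IDs, pooled to a fixed shape
pooled = table(ids.unsqueeze(0))  # shape: (1, embedding_dim)

# 3. The pooled embedding feeds the rest of the model to produce a CTR-style prediction
prediction = torch.sigmoid(predictor(pooled))

# 4. Compute the loss against the label and backpropagate into all weights, embeddings included
label = torch.tensor([[1.0]])
loss = torch.nn.functional.binary_cross_entropy(prediction, label)
loss.backward()
optimizer.step()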
@@ -91,7 +95,7 @@
 # about the technical details of using embedding tables in RecSys.
 #
 # This tutorial will introduce the concept of embeddings, showcase
-# TorchRec specific modules/datatypes, and depict how distributed training
+# TorchRec specific modules and datatypes, and depict how distributed training
 # works with TorchRec.
 #

@@ -108,7 +112,7 @@
 #
 # ```torch.nn.EmbeddingBag`` <https://pytorch.org/docs/stable/generated/torch.nn.EmbeddingBag.html>`__:
 # Embedding table where forward pass returns embeddings that are then
-# pooled, i.e. sum or mean. Otherwise known as **Pooled Embeddings**
+# pooled, for example, sum or mean, otherwise known as **Pooled Embeddings**
 #
 # In this section, we will go over a very brief introduction with doing
 # embedding lookups through passing in indices into the table. Check out
@@ -121,7 +125,7 @@
 weights = torch.rand(num_embeddings, embedding_dim)
 print("Weights:", weights)

-# Pass in pre generated weights just for example, typically weights are randomly initialized
+# Pass in pregenerated weights just for example, typically weights are randomly initialized
 embedding_collection = torch.nn.Embedding(
     num_embeddings, embedding_dim, _weight=weights
 )
@@ -146,14 +150,14 @@
 print(embeddings)
 print("Shape: ", embeddings.shape)

-# nn.EmbeddingBag default pooling is mean, so should be mean of batch dimension of values above
+# ``nn.EmbeddingBag`` default pooling is mean, so should be mean of batch dimension of values above
 pooled_embeddings = embedding_bag_collection(ids)

 print("Embedding Bag Collection Results: ")
 print(pooled_embeddings)
 print("Shape: ", pooled_embeddings.shape)

-# nn.EmbeddingBag is the same as nn.Embedding but just with pooling (mean, sum, etc.)
+# ``nn.EmbeddingBag`` is the same as ``nn.Embedding`` but just with pooling (mean, sum, and so on)
 # We can see that the mean of the embeddings of embedding_collection is the same as the output of the embedding_bag_collection
 print("Mean: ", torch.mean(embedding_collection(ids), dim=1))

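As an aside on the equivalence these reworded comments describe, a minimal standalone check; the table size, IDs, and shared ``_weight`` below are illustrative:

import torch

num_embeddings, embedding_dim = 10, 4
weights = torch.rand(num_embeddings, embedding_dim)

# Share the same weights so the two modules are directly comparable
embedding_collection = torch.nn.Embedding(num_embeddings, embedding_dim, _weight=weights)
embedding_bag_collection = torch.nn.EmbeddingBag(num_embeddings, embedding_dim, _weight=weights)  # default mode is "mean"

ids = torch.tensor([[0, 1, 2], [3, 4, 5]])

per_id = embedding_collection(ids)       # one vector per ID: (2, 3, embedding_dim)
pooled = embedding_bag_collection(ids)   # mean-pooled over the ID dimension: (2, embedding_dim)

# The mean over dim=1 of the unpooled lookup matches the pooled output
assert torch.allclose(per_id.mean(dim=1), pooled)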
@@ -275,7 +279,7 @@
 #
 # TorchRec has distinct data types for input and output of its modules:
 # ``JaggedTensor``, ``KeyedJaggedTensor``, and ``KeyedTensor``. Now you
-# might ask, why create new datatypes to represent sparse features? To
+# might ask, why create new data types to represent sparse features? To
 # answer that question, we must understand how sparse features are
 # represented in code.
 #
@@ -287,7 +291,7 @@
 # that a user interacted with, and the embeddings retrieved would be a
 # semantic representation of those Ads. The tricky part of representing
 # these features in code is that in each input example, **the number of
-# IDs is variable**. 1 day a user might have interacted with only 1 ad
+# IDs is variable**. One day a user might have interacted with only 1 ad
 # while the next day they interact with 3.
 #
 # A simple representation is shown below, where we have a ``lengths``
@@ -304,7 +308,7 @@


 ######################################################################
-# Let’s look at the offsets as well as what is contained in each Batch
+# Next, let's look at the offsets as well as what is contained in each batch
 #

 # Lengths can be converted to offsets for easy indexing of values
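As an aside on the ``lengths``/offsets representation this hunk refers to, a minimal sketch in plain PyTorch; the IDs and lengths are made up for illustration:

import torch

# Flattened IDs for 3 examples with 1, 3, and 2 IDs respectively
values = torch.tensor([7, 1, 8, 2, 5, 3])
lengths = torch.tensor([1, 3, 2])

# Offsets are the cumulative sum of lengths with a leading zero,
# so example i spans values[offsets[i]:offsets[i + 1]]
offsets = torch.cat([torch.zeros(1, dtype=torch.long), torch.cumsum(lengths, dim=0)])

for i in range(len(lengths)):
    print(f"Example {i}: ", values[offsets[i]:offsets[i + 1]])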
@@ -394,24 +398,24 @@
 #
 # Remember, the main purpose of TorchRec is to provide primitives for
 # distributed embeddings. So far, we've only worked with embedding tables
-# on 1 device. This has been possible given how small the embedding tables
+# on a single device. This has been possible given how small the embedding tables
 # have been, but in a production setting this isn't generally the case.
-# Embedding tables often get massive, where 1 table can't fit on a single
+# Embedding tables often get massive, where one table can't fit on a single
 # GPU, creating the requirement for multiple devices and a distributed
 # environment
 #
 # In this section, we will explore setting up a distributed environment,
 # exactly how actual production training is done, and explore sharding
 # embedding tables, all with Torchrec.
 #
-# **This section will also only use 1 gpu, though it will be treated in a
+# **This section will also only use 1 GPU, though it will be treated in a
 # distributed fashion. This is only a limitation for training, as training
-# has a process per gpu. Inference does not run into this requirement**
+# has a process per GPU. Inference does not run into this requirement**
 #

-# Here we set up our torch distributed environment
-# WARNING: You can only call this cell once, calling it again will cause an error
-# as you can only initialize the process group once
+# Here we set up our PyTorch distributed environment.
+# .. warning:: In Colab, you can only call this cell once, calling it again will cause an error
+# as you can only initialize the process group once

 import os

@@ -420,14 +424,13 @@
 # Set up environment variables for distributed training
 # RANK is which GPU we are on, default 0
 os.environ["RANK"] = "0"
-# How many devices in our "world", notebook can only handle 1 process
+# How many devices in our "world", colab notebook can only handle 1 process
 os.environ["WORLD_SIZE"] = "1"
 # Localhost as we are training locally
 os.environ["MASTER_ADDR"] = "localhost"
 # Port for distributed training
 os.environ["MASTER_PORT"] = "29500"

-# Note - you will need a V100 or A100 to run tutorial as!
 # nccl backend is for GPUs, gloo is for CPUs
 dist.init_process_group(backend="gloo")

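Related to the warning added above, one way to keep the cell re-runnable in a notebook is to guard the initialization. This is a hedged sketch, not code from the commit:

import os

import torch.distributed as dist

os.environ.setdefault("RANK", "0")
os.environ.setdefault("WORLD_SIZE", "1")
os.environ.setdefault("MASTER_ADDR", "localhost")
os.environ.setdefault("MASTER_PORT", "29500")

# Only initialize the process group if it has not been initialized yet,
# so re-running the cell does not raise an "already initialized" error
if not dist.is_initialized():
    dist.init_process_group(backend="gloo")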
@@ -447,7 +450,7 @@
 # are able to do magnitudes more floating point operations/s
 # (`FLOPs <https://en.wikipedia.org/wiki/FLOPS>`__) than CPU. However,
 # GPUs come with the limitation of scarce fast memory (HBM which is
-# analogous to RAM for CPU), typically ~10s of GBs.
+# analogous to RAM for CPU), typically, ~10s of GBs.
 #
 # A RecSys model can contain embedding tables that far exceed the memory
 # limit for 1 GPU, hence the need for distribution of the embedding tables
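As a back-of-envelope illustration of the memory limit mentioned in this hunk; the table size below is invented purely for the arithmetic:

# Why a single table can exceed one GPU's HBM
num_embeddings = 1_000_000_000   # 1B hashed IDs (hypothetical)
embedding_dim = 128
bytes_per_param = 4              # fp32

table_gb = num_embeddings * embedding_dim * bytes_per_param / 1e9
print(f"Single table: {table_gb:.0f} GB")   # 512 GB, far beyond ~10s of GB of HBM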
@@ -469,25 +472,27 @@
 # known as “sharding”.
 #
 # There are many ways to shard embedding tables. The most common ways are:
-# \* Table-Wise: the table is placed entirely onto one device \*
-# Column-Wise: columns of embedding tables are sharded \* Row-Wise: rows
-# of embedding tables are sharded
+#
+# * Table-Wise: the table is placed entirely onto one device
+# * Column-Wise: columns of embedding tables are sharded
+# * Row-Wise: rows of embedding tables are sharded
 #
 # Sharded Modules
 # ~~~~~~~~~~~~~~~
 #
 # While all of this seems like a lot to deal with and implement, you're in
 # luck. **TorchRec provides all the primitives for easy distributed
-# training/inference**! In fact, TorchRec modules have two corresponding
+# training and inference**! In fact, TorchRec modules have two corresponding
 # classes for working with any TorchRec module in a distributed
-# environment: 1. The module sharder: This class exposes a ``shard`` API
-# that handles sharding a TorchRec Module, producing a sharded module. \*
-# For ``EmbeddingBagCollection``, the sharder is
+# environment:
+# 1. **The module sharder**: This class exposes a ``shard`` API
+# that handles sharding a TorchRec Module, producing a sharded module.
+# * For ``EmbeddingBagCollection``, the sharder is
 # ```EmbeddingBagCollectionSharder`` <https://pytorch.org/torchrec/torchrec.distributed.html#torchrec.distributed.embeddingbag.EmbeddingBagCollectionSharder>`__
-# 2. Sharded module: This class is a sharded variant of a TorchRec module.
+# 2. **Sharded module**: This class is a sharded variant of a TorchRec module.
 # It has the same input/output as a the regular TorchRec module, but much
-# more optimized and works in a distributed environment. \* For
-# ``EmbeddingBagCollection``, the sharded variant is
+# more optimized and works in a distributed environment.
+# * For ``EmbeddingBagCollection``, the sharded variant is
 # ```ShardedEmbeddingBagCollection`` <https://pytorch.org/torchrec/torchrec.distributed.html#torchrec.distributed.embeddingbag.ShardedEmbeddingBagCollection>`__
 #
 # Every TorchRec module has an unsharded and sharded variant. \* The
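As a toy illustration of the three sharding schemes listed in this hunk, splitting a single weight tensor with plain PyTorch; device placement is omitted and the shapes are arbitrary:

import torch

# A toy 8-row by 4-column embedding table to split across 2 "devices"
table = torch.arange(32, dtype=torch.float32).reshape(8, 4)

# Table-wise: the whole table is placed on one device
table_wise = [table, None]                  # device 0 holds everything, device 1 holds nothing

# Row-wise: rows of the table are split across devices
row_wise = torch.chunk(table, 2, dim=0)     # two shards of shape (4, 4)

# Column-wise: columns of the table are split across devices
column_wise = torch.chunk(table, 2, dim=1)  # two shards of shape (8, 2)

print([s.shape for s in row_wise], [s.shape for s in column_wise])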
@@ -500,7 +505,6 @@
 # Parallelism, such as communication between GPUs for distributing
 # embeddings to the correct GPUs.
 #
-
 # Refresher of our ``EmbeddingBagCollection`` module
 ebc

@@ -527,30 +531,35 @@
 #
 # Given a number of embedding tables and a number of ranks, there are many
 # different sharding configurations that are possible. For example, given
-# 2 embedding tables and 2 GPUs, you can: \* Place 1 table on each GPU \*
-# Place both tables on a single GPU and no tables on the other \* Place
-# certain rows/columns on each GPU
+# 2 embedding tables and 2 GPUs, you can:
+#
+# * Place 1 table on each GPU
+# * Place both tables on a single GPU and no tables on the other
+# * Place certain rows and columns on each GPU
 #
 # Given all of these possibilities, we typically want a sharding
 # configuration that is optimal for performance.
 #
 # That is where the planner comes in. The planner is able to determine
-# given the # of embedding tables and the # of GPUs, what is the optimal
+# given the number of embedding tables and the number of GPUs, what is the optimal
 # configuration. Turns out, this is incredibly difficult to do manually,
 # with tons of factors that engineers have to consider to ensure an
 # optimal sharding plan. Luckily, TorchRec provides an auto planner when
-# the planner is used. The TorchRec planner: \* assesses memory
-# constraints of hardware, \* estimates compute based on memory fetches as
-# embedding lookups, \* addresses data specific factors, \* considers
-# other hardware specifics like bandwidth to generate an optimal sharding
-# plan.
+# the planner is used.
+#
+# The TorchRec planner:
+#
+# * Assesses memory constraints of hardware
+# * Estimates compute based on memory fetches as embedding lookups
+# * Addresses data specific factors
+# * Considers other hardware specifics like bandwidth to generate an optimal sharding plan
 #
 # In order to take into consideration all these variables, The TorchRec
 # planner can take in `various amounts of data for embedding tables,
 # constraints, hardware information, and
 # topology <https://github.com/pytorch/torchrec/blob/main/torchrec/distributed/planner/planners.py#L147-L155>`__
 # to aid in generating the optimal sharding plan for a model, which is
-# routinely provided across stacks
+# routinely provided across stacks.
 #
 # To learn more about sharding, see our `sharding
 # tutorial <https://pytorch.org/tutorials/advanced/sharding.html>`__.
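For context on the planner discussed in this hunk, a rough sketch of how it is typically invoked; ``ebc`` refers to the ``EmbeddingBagCollection`` defined earlier in the tutorial, and the exact constructor arguments should be checked against the TorchRec documentation:

from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder
from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology

# Describe the hardware the plan should target (a single rank here)
planner = EmbeddingShardingPlanner(topology=Topology(world_size=1, compute_device="cuda"))

# `ebc` is the EmbeddingBagCollection from earlier in the tutorial
plan = planner.plan(ebc, [EmbeddingBagCollectionSharder()])
print(plan)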
@@ -574,8 +583,8 @@
 # Planner Result
 # ~~~~~~~~~~~~~~
 #
-# As you can see, when running the planner there is quite a bit of output
-# above. We can see a ton of stats being calculated along with where our
+# As you can see above, when running the planner there is quite a bit of output.
+# We can see a lot of stats being calculated along with where our
 # tables end up being placed.
 #
 # The result of running the planner is a static plan, which can be reused
@@ -602,7 +611,7 @@

 ######################################################################
 # ``Awaitable``
-# ^^^^^^^^^
+# ^^^^^^^^^^^^^^^^^^^^^^
 #
 # Remember that TorchRec is a highly optimized library for distributed
 # embeddings. A concept that TorchRec introduces to enable higher
@@ -618,7 +627,7 @@
 from torchrec.distributed.types import LazyAwaitable


-# Demonstrate a ``LazyAwaitable`` type
+# Demonstrate a ``LazyAwaitable`` type:
 class ExampleAwaitable(LazyAwaitable[torch.Tensor]):
     def __init__(self, size: List[int]) -> None:
         super().__init__()
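For context, a self-contained version of the kind of subclass this hunk touches might look as follows; the ``_wait_impl`` body is illustrative and not necessarily what the tutorial file contains:

from typing import List

import torch
from torchrec.distributed.types import LazyAwaitable


class ExampleAwaitable(LazyAwaitable[torch.Tensor]):
    def __init__(self, size: List[int]) -> None:
        super().__init__()
        self._size = size

    def _wait_impl(self) -> torch.Tensor:
        # Materialize the concrete tensor only when the awaitable is waited on
        return torch.ones(self._size)


awaitable = ExampleAwaitable([3, 2])
print(awaitable.wait())  # result is produced lazily at wait() time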
@@ -663,25 +672,25 @@ def _wait_impl(self) -> torch.Tensor:
 # in training and inference. **Below are the three common APIs for
 # distributed training/inference** that are provided by TorchRec:
 #
-# 1. **input\_dist**: Handles distributing inputs from GPU to GPU
+# * ``input_dist``: Handles distributing inputs from GPU to GPU
 #
-# 2. **lookups**: Does the actual embedding lookup in an optimized,
-# batched manner using FBGEMM TBE (more on this later)
+# * ``lookups``: Does the actual embedding lookup in an optimized,
+# batched manner using FBGEMM TBE (more on this later).
 #
-# 3. **output\_dist**: Handles distributing outputs from GPU to GPU
+# * ``output_dist``: Handles distributing outputs from GPU to GPU
 #
-# The distribution of inputs/outputs is done through `NCCL
+# The distribution of inputs and outputs is done through `NCCL
 # Collectives <https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/overview.html>`__,
 # namely
 # `All-to-Alls <https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/p2p.html#all-to-all>`__,
-# which is where all GPUs send/receive data to and from one another.
+# which is where all GPUs send and receive data to and from one another.
 # TorchRec interfaces with PyTorch distributed for collectives and
 # provides clean abstractions to the end users, removing the concern for
 # the lower level details.
 #
 # The backwards pass does all of these collectives but in the reverse
-# order for distribution of gradients. input\_dist, lookup, and
-# output\_dist all depend on the sharding scheme. Since we sharded in a
+# order for distribution of gradients. ``input_dist``, ``lookup``, and
+# ``output_dist`` all depend on the sharding scheme. Since we sharded in a
 # table-wise fashion, these APIs are modules that are constructed by
 # `TwPooledEmbeddingSharding <https://pytorch.org/torchrec/torchrec.distributed.sharding.html#torchrec.distributed.sharding.tw_sharding.TwPooledEmbeddingSharding>`__.
 #
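As a conceptual aside on the All-to-All collective referenced in this hunk, a pure-Python simulation of the data movement; no process group is involved, and ranks are modeled as list indices:

import torch

world_size = 3

# send[i][j] is the chunk that rank i wants to deliver to rank j
send = [
    [torch.full((2,), 10 * i + j) for j in range(world_size)]
    for i in range(world_size)
]

# After an All-to-All, rank j holds the j-th chunk from every sender:
# recv[j][i] == send[i][j]
recv = [[send[i][j] for i in range(world_size)] for j in range(world_size)]

print(recv[0])  # the chunks destined for rank 0, one from each rank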
@@ -1073,12 +1082,12 @@ def forward(self, kjt: KeyedJaggedTensor):


 ######################################################################
-# Congrats!
+# Conclusion
 # ---------
 #
-# You have now gone from training a distributed RecSys model all the way
-# to making it inference ready.
-# https://github.com/pytorch/torchrec/tree/main/torchrec/inference has a
+# In this tutorial, you have gone from training a distributed RecSys model all the way
+# to making it inference ready. The `TorchRec repo
+# <https://github.com/pytorch/torchrec/tree/main/torchrec/inference>`__ has a
 # full example of how to load a TorchRec TorchScript model into C++ for
 # inference.
 #
@@ -1090,6 +1099,7 @@ def forward(self, kjt: KeyedJaggedTensor):
 #
 # For more information, please see our
 # `dlrm <https://github.com/facebookresearch/dlrm/tree/main/torchrec_dlrm/>`__
-# example, which includes multinode training on the criteo terabyte
-# dataset, using Meta’s `DLRM <https://arxiv.org/abs/1906.00091>`__.
+# example, which includes multinode training on the Criteo 1TB
+# dataset using the methods described in `Deep Learning Recommendation Model
+# for Personalization and Recommendation Systems <https://arxiv.org/abs/1906.00091>`__.
 #
