and TorchRec, focusing on handling large embedding tables through distributed training and advanced optimizations.
.. grid:: 2

    .. grid-item-card:: :octicon:`mortar-board;1em;` What you will learn
       :class-card: card-prerequisites

       * Fundamentals of embeddings and their role in recommendation systems
       * How to set up TorchRec to manage and implement embeddings in PyTorch environments
       * Advanced techniques for distributing large embedding tables across multiple GPUs

    .. grid-item-card:: :octicon:`list-unordered;1em;` Prerequisites
       :class-card: card-prerequisites

       * PyTorch v2.5 or later with CUDA 11.8 or later
       * Python 3.9 or later
       * `FBGEMM <https://github.com/pytorch/fbgemm>`__
# following dependencies:
#
# .. code-block:: sh
#
#    !pip3 install --pre torch --index-url https://download.pytorch.org/whl/cu121 -U
#    !pip3 install fbgemm_gpu --index-url https://download.pytorch.org/whl/cu121
#    !pip3 install torchmetrics==1.0.3
# ~~~~~~~~~~
#
# When building recommendation systems, categorical features typically
# have massive cardinality: posts, users, ads, and so on.
#
# In order to represent these entities and model these relationships,
# **embeddings** are used. In machine learning, **embeddings are vectors
# The inputs to embedding tables represent embedding lookups to retrieve
# the embedding for a specific index/row. In recommendation systems, such
# as those used in Meta, unique IDs are not only used for specific users,
# but also across entities like posts and ads to serve as lookup indices to
# respective embedding tables!
#
# Embeddings are trained in RecSys through the following process, sketched
# in code right after the list:
#
# 1. **Input/lookup indices are fed into the model, as unique IDs**. IDs are
#    hashed to the total size of the embedding table to prevent issues when
#    the ID > # of rows.
# 2. Embeddings are then retrieved and **pooled, such as taking the sum or
#    mean of the embeddings**. This is required as there can be a variable # of
#    embeddings per example while the model expects consistent shapes.
# 3. The **embeddings are used in conjunction with the rest of the model to
#    produce a prediction**, such as `Click-Through Rate
#    (CTR) <https://support.google.com/google-ads/answer/2615875?hl=en>`__
#    for an Ad.
# 4. The loss is calculated with the prediction and the label
#    for an example, and **all weights of the model are updated through
#    gradient descent and backpropagation, including the embedding weights**
#    that were associated with the example.
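#
# As a concrete illustration of the four steps above, here is a minimal,
# self-contained sketch in plain PyTorch: the table size, the tiny "CTR head",
# and the example IDs below are made-up values for demonstration only.

import torch

num_embeddings, embedding_dim = 8, 4
table = torch.nn.EmbeddingBag(num_embeddings, embedding_dim, mode="sum")
ctr_head = torch.nn.Linear(embedding_dim, 1)
optimizer = torch.optim.SGD(
    list(table.parameters()) + list(ctr_head.parameters()), lr=0.1
)

raw_ids = torch.tensor([1001, 42, 7])  # 1. input/lookup indices as unique IDs
ids = raw_ids % num_embeddings  # hash IDs into the table's row range
offsets = torch.tensor([0, 2])  # example 0 has two IDs, example 1 has one
pooled = table(ids, offsets)  # 2. lookup + sum pooling -> consistent shape per example
logits = ctr_head(pooled)  # 3. produce a prediction, such as a CTR logit
labels = torch.tensor([[1.0], [0.0]])
loss = torch.nn.functional.binary_cross_entropy_with_logits(logits, labels)
loss.backward()  # 4. gradients flow back into the embedding rows that were used
optimizer.step()


######################################################################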
#
# These embeddings are crucial for representing categorical features, such
# about the technical details of using embedding tables in RecSys.
#
# This tutorial will introduce the concept of embeddings, showcase
# TorchRec specific modules and datatypes, and depict how distributed training
# works with TorchRec.
#
#
# ```torch.nn.EmbeddingBag`` <https://pytorch.org/docs/stable/generated/torch.nn.EmbeddingBag.html>`__:
# Embedding table where the forward pass returns embeddings that are then
# pooled, for example, by taking the sum or mean, otherwise known as **Pooled Embeddings**.
#
# In this section, we will go over a very brief introduction to doing
# embedding lookups by passing indices into the table. Check out
weights = torch.rand(num_embeddings, embedding_dim)
print("Weights:", weights)

# Pass in pregenerated weights just for example; typically, weights are randomly initialized
embedding_collection = torch.nn.Embedding(
    num_embeddings, embedding_dim, _weight=weights
)

print(embeddings)
print("Shape: ", embeddings.shape)

# ``nn.EmbeddingBag`` default pooling is mean, so this should be the mean of the batch dimension of the values above
pooled_embeddings = embedding_bag_collection(ids)

print("Embedding Bag Collection Results: ")
print(pooled_embeddings)
print("Shape: ", pooled_embeddings.shape)

# ``nn.EmbeddingBag`` is the same as ``nn.Embedding`` but with pooling (mean, sum, and so on)
# We can see that the mean of the embeddings of embedding_collection is the same as the output of the embedding_bag_collection
print("Mean: ", torch.mean(embedding_collection(ids), dim=1))
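# As a quick sanity check (an added illustration, assuming, as stated above, that
# both modules were constructed from the same ``weights`` with mean pooling):
assert torch.allclose(
    torch.mean(embedding_collection(ids), dim=1), embedding_bag_collection(ids)
)


######################################################################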
#
# TorchRec has distinct data types for input and output of its modules:
# ``JaggedTensor``, ``KeyedJaggedTensor``, and ``KeyedTensor``. Now you
# might ask, why create new data types to represent sparse features? To
# answer that question, we must understand how sparse features are
# represented in code.
#
# that a user interacted with, and the embeddings retrieved would be a
# semantic representation of those Ads. The tricky part of representing
# these features in code is that in each input example, **the number of
# IDs is variable**. One day a user might have interacted with only 1 ad
# while the next day they interact with 3.
#
# A simple representation is shown below, where we have a ``lengths``


######################################################################
# Next, let's look at the offsets as well as what is contained in each batch
#

# Lengths can be converted to offsets for easy indexing of values
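# As an illustrative sketch (the actual tensors used in this tutorial are
# constructed elsewhere), offsets are simply the prefix sum of lengths:
lengths = torch.tensor([2, 0, 1, 1, 1, 3])  # number of IDs per example
offsets = torch.cumsum(torch.cat((torch.tensor([0]), lengths)), dim=0)
print("Offsets: ", offsets)  # offsets[i] marks where example i's IDs start in the flat values tensor


######################################################################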
#
# Remember, the main purpose of TorchRec is to provide primitives for
# distributed embeddings. So far, we've only worked with embedding tables
# on a single device. This has been possible given how small the embedding tables
# have been, but in a production setting this isn't generally the case.
# Embedding tables often get massive, where one table can't fit on a single
# GPU, creating the requirement for multiple devices and a distributed
# environment.
#
# In this section, we will explore setting up a distributed environment,
# exactly how actual production training is done, and explore sharding
# embedding tables, all with TorchRec.
#
# **This section will also only use 1 GPU, though it will be treated in a
# distributed fashion. This is only a limitation for training, as training
# has a process per GPU. Inference does not run into this requirement.**
#

# Here we set up our PyTorch distributed environment.
#
# .. warning:: In Colab, you can only call this cell once; calling it again will cause an error,
#    as you can only initialize the process group once.

import os

import torch.distributed as dist

# Set up environment variables for distributed training
# RANK is which GPU we are on, default 0
os.environ["RANK"] = "0"
# How many devices in our "world"; a Colab notebook can only handle 1 process
os.environ["WORLD_SIZE"] = "1"
# Localhost as we are training locally
os.environ["MASTER_ADDR"] = "localhost"
# Port for distributed training
os.environ["MASTER_PORT"] = "29500"

# nccl backend is for GPUs, gloo is for CPUs
dist.init_process_group(backend="gloo")
# are able to do magnitudes more floating point operations/s
# (`FLOPs <https://en.wikipedia.org/wiki/FLOPS>`__) than CPU. However,
# GPUs come with the limitation of scarce fast memory (HBM, which is
# analogous to RAM for the CPU), typically on the order of tens of GBs.
#
# A RecSys model can contain embedding tables that far exceed the memory
# limit for 1 GPU, hence the need for distribution of the embedding tables
# known as “sharding”.
#
# There are many ways to shard embedding tables. The most common ways are
# listed below (the short snippet after this list shows how TorchRec names them):
#
# * Table-Wise: the table is placed entirely onto one device
# * Column-Wise: columns of embedding tables are sharded
# * Row-Wise: rows of embedding tables are sharded
#
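# A quick way to see how TorchRec names these strategies is the
# ``ShardingType`` enum (this short check assumes TorchRec is installed as
# described at the start of the tutorial):

from torchrec.distributed.types import ShardingType

# Each member's string value is what sharding plans and planner constraints use
print([sharding_type.value for sharding_type in ShardingType])


######################################################################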
# Sharded Modules
# ~~~~~~~~~~~~~~~
#
# While all of this seems like a lot to deal with and implement, you're in
# luck. **TorchRec provides all the primitives for easy distributed
# training and inference**! In fact, TorchRec modules have two corresponding
# classes for working with any TorchRec module in a distributed
# environment:
#
# 1. **The module sharder**: This class exposes a ``shard`` API
#    that handles sharding a TorchRec Module, producing a sharded module.
#
#    * For ``EmbeddingBagCollection``, the sharder is
#      ```EmbeddingBagCollectionSharder`` <https://pytorch.org/torchrec/torchrec.distributed.html#torchrec.distributed.embeddingbag.EmbeddingBagCollectionSharder>`__
#
# 2. **Sharded module**: This class is a sharded variant of a TorchRec module.
#    It has the same input/output as the regular TorchRec module, but much
#    more optimized and works in a distributed environment.
#
#    * For ``EmbeddingBagCollection``, the sharded variant is
#      ```ShardedEmbeddingBagCollection`` <https://pytorch.org/torchrec/torchrec.distributed.html#torchrec.distributed.embeddingbag.ShardedEmbeddingBagCollection>`__
#
# Every TorchRec module has an unsharded and sharded variant.
#
# * The
# Parallelism, such as communication between GPUs for distributing
# embeddings to the correct GPUs.
#
# Refresher of our ``EmbeddingBagCollection`` module
ebc
#
# Given a number of embedding tables and a number of ranks, there are many
# different sharding configurations that are possible. For example, given
# 2 embedding tables and 2 GPUs, you can:
#
# * Place 1 table on each GPU
# * Place both tables on a single GPU and no tables on the other
# * Place certain rows and columns on each GPU
#
# Given all of these possibilities, we typically want a sharding
# configuration that is optimal for performance.
#
# That is where the planner comes in. The planner is able to determine,
# given the number of embedding tables and the number of GPUs, what the
# optimal configuration is. It turns out this is incredibly difficult to do manually,
# with tons of factors that engineers have to consider to ensure an
# optimal sharding plan. Luckily, TorchRec provides an auto planner that
# does this for you.
#
# The TorchRec planner:
#
# * Assesses memory constraints of hardware
# * Estimates compute based on memory fetches as embedding lookups
# * Addresses data specific factors
# * Considers other hardware specifics like bandwidth to generate an optimal sharding plan
#
# In order to take into consideration all these variables, the TorchRec
# planner can take in `various amounts of data for embedding tables,
# constraints, hardware information, and
# topology <https://github.com/pytorch/torchrec/blob/main/torchrec/distributed/planner/planners.py#L147-L155>`__
# to aid in generating the optimal sharding plan for a model, which is
# routinely provided across stacks.
#
# To learn more about sharding, see our `sharding
# tutorial <https://pytorch.org/tutorials/advanced/sharding.html>`__.
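#
# As a rough sketch of how a planner can be constructed and nudged toward a
# particular strategy: the imports below follow TorchRec's planner API, but the
# table name ``"product_table"`` and the table-wise constraint are illustrative
# assumptions, and ``ebc`` is the ``EmbeddingBagCollection`` from earlier in
# this tutorial.

from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder
from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology
from torchrec.distributed.planner.types import ParameterConstraints
from torchrec.distributed.types import ShardingType

# Describe the hardware being planned for (a single rank in this tutorial)
topology = Topology(world_size=1, compute_device="cuda")

# Optionally constrain a (hypothetical) table to table-wise sharding only
constraints = {
    "product_table": ParameterConstraints(
        sharding_types=[ShardingType.TABLE_WISE.value]
    )
}

planner = EmbeddingShardingPlanner(topology=topology, constraints=constraints)

# ``plan`` inspects the module and the sharder's capabilities and returns a
# static sharding plan describing where each shard of each table should live
plan = planner.plan(ebc, [EmbeddingBagCollectionSharder()])
print(plan)


######################################################################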
# Planner Result
# ~~~~~~~~~~~~~~
#
# As you can see above, when running the planner there is quite a bit of output.
# We can see a lot of stats being calculated along with where our
# tables end up being placed.
#
# The result of running the planner is a static plan, which can be reused


######################################################################
# ``Awaitable``
# ^^^^^^^^^^^^^
#
# Remember that TorchRec is a highly optimized library for distributed
# embeddings. A concept that TorchRec introduces to enable higher

from torchrec.distributed.types import LazyAwaitable


# Demonstrate a ``LazyAwaitable`` type:
class ExampleAwaitable(LazyAwaitable[torch.Tensor]):
    def __init__(self, size: List[int]) -> None:
        super().__init__()
# in training and inference. **Below are the three common APIs for
# distributed training/inference** that are provided by TorchRec:
#
# * ``input_dist``: Handles distributing inputs from GPU to GPU.
#
# * ``lookups``: Does the actual embedding lookup in an optimized,
#   batched manner using FBGEMM TBE (more on this later).
#
# * ``output_dist``: Handles distributing outputs from GPU to GPU.
#
# The distribution of inputs and outputs is done through `NCCL
# Collectives <https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/overview.html>`__,
# namely
# `All-to-Alls <https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/p2p.html#all-to-all>`__,
# which is where all GPUs send and receive data to and from one another.
# TorchRec interfaces with PyTorch distributed for collectives and
# provides clean abstractions to the end users, removing the concern for
# the lower level details.
#
# The backwards pass does all of these collectives but in the reverse
# order for distribution of gradients. ``input_dist``, ``lookups``, and
# ``output_dist`` all depend on the sharding scheme. Since we sharded in a
# table-wise fashion, these APIs are modules that are constructed by
# `TwPooledEmbeddingSharding <https://pytorch.org/torchrec/torchrec.distributed.sharding.html#torchrec.distributed.sharding.tw_sharding.TwPooledEmbeddingSharding>`__.
#


######################################################################
# Conclusion
# ----------
#
# In this tutorial, you have gone from training a distributed RecSys model all the way
# to making it inference ready. The `TorchRec repo
# <https://github.com/pytorch/torchrec/tree/main/torchrec/inference>`__ has a
# full example of how to load a TorchRec TorchScript model into C++ for
# inference.
#
#
# For more information, please see our
# `dlrm <https://github.com/facebookresearch/dlrm/tree/main/torchrec_dlrm/>`__
# example, which includes multinode training on the Criteo 1TB
# dataset using the methods described in `Deep Learning Recommendation Model
# for Personalization and Recommendation Systems <https://arxiv.org/abs/1906.00091>`__.
#