@@ -2,8 +2,8 @@ Introduction to Torchrec
====================================================

.. tip::
-    To get the most of this tutorial, we suggest using this
-    `Colab Version <https://colab.research.google.com/github/pytorch/torchrec/blob/main/Torchrec_Introduction.ipynb>`__.
+    To get the most out of this tutorial, we suggest using this
+    `Colab Version <https://colab.research.google.com/github/pytorch/torchrec/blob/main/Torchrec_Introduction.ipynb>`__.
     This will allow you to experiment with the information presented below.

Frequently, when building recommendation systems, we want to represent
@@ -14,7 +14,7 @@ entities grow, the size of the embedding tables can exceed a single
GPU’s memory. A common practice is to shard the embedding table across
devices, a type of model parallelism. To that end, TorchRec introduces
its primary API
- called ```DistributedModelParallel`` <https://pytorch.org/torchrec/torchrec.distributed.html#torchrec.distributed.model_parallel.DistributedModelParallel>`__,
+ called |DistributedModelParallel|_,
or DMP. Like PyTorch’s DistributedDataParallel, DMP wraps a model to
enable distributed training.

@@ -28,21 +28,19 @@ We highly recommend CUDA when using torchRec. If using CUDA:
- CUDA >= 11.0

- .. code:: python
+ .. code:: shell

-    install pytorch with cudatoolkit 11.3
+    # install pytorch with cudatoolkit 11.3
     conda install pytorch cudatoolkit=11.3 -c pytorch-nightly -y
-    install torchrec
+    # install torchrec
     pip3 install torchrec-nightly

**Overview**
------------

- This tutorial will cover three pieces of torchRec - the ``nn.module``
- ```EmbeddingBagCollection`` <https://pytorch.org/torchrec/torchrec.modules.html#torchrec.modules.embedding_modules.EmbeddingBagCollection>`__, the
- ```DistributedModelParallel`` <https://pytorch.org/torchrec/torchrec.distributed.html#torchrec.distributed.model_parallel.DistributedModelParallel>`__ API, and
- the datastructure ```KeyedJaggedTensor`` <https://pytorch.org/torchrec/torchrec.sparse.html#torchrec.sparse.jagged_tensor.JaggedTensor>`__.
+ This tutorial will cover three pieces of TorchRec: the ``nn.Module`` |EmbeddingBagCollection|_, the |DistributedModelParallel|_ API, and
+ the data structure |KeyedJaggedTensor|_.

Distributed Setup
From EmbeddingBag to EmbeddingBagCollection
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

- Pytorch represents embeddings through
- ```torch.nn.Embedding`` <https://pytorch.org/docs/stable/generated/torch.nn.Embedding.html>`__
- and
- ```torch.nn.EmbeddingBag`` <https://pytorch.org/docs/stable/generated/torch.nn.EmbeddingBag.html>`__.
+ PyTorch represents embeddings through |torch.nn.Embedding|_ and |torch.nn.EmbeddingBag|_.
EmbeddingBag is a pooled version of Embedding.

TorchRec extends these modules by creating collections of embeddings. We
- will use
- ```EmbeddingBagCollection`` <https://pytorch.org/torchrec/torchrec.modules.html#torchrec.modules.embedding_modules.EmbeddingBagCollection>`__
- to represent a group of EmbeddingBags.
+ will use |EmbeddingBagCollection|_ to represent a group of EmbeddingBags.

Here, we create an EmbeddingBagCollection (EBC) with two embedding bags.
Each table, ``product_table`` and ``user_table``, is represented by
64-dimensional embeddings.
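
A minimal sketch of constructing such an EBC, assuming ``torchrec`` is
installed; the 4096-row table size below is illustrative, not taken from
this page:

.. code:: python

   import torchrec

   # Two embedding tables, one per feature, allocated on the "meta" device
   # so that no memory is committed before sharding.
   ebc = torchrec.EmbeddingBagCollection(
       device="meta",
       tables=[
           torchrec.EmbeddingBagConfig(
               name="product_table",
               embedding_dim=64,
               num_embeddings=4096,  # illustrative row count
               feature_names=["product"],
               pooling=torchrec.PoolingType.SUM,
           ),
           torchrec.EmbeddingBagConfig(
               name="user_table",
               embedding_dim=64,
               num_embeddings=4096,  # illustrative row count
               feature_names=["user"],
               pooling=torchrec.PoolingType.SUM,
           ),
       ],
   )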
@@ -119,9 +112,7 @@ on device “meta”. This will tell EBC to not allocate memory yet.
DistributedModelParallel
~~~~~~~~~~~~~~~~~~~~~~~~

- Now, we’re ready to wrap our model with
- ```DistributedModelParallel`` <https://pytorch.org/torchrec/torchrec.distributed.html#torchrec.distributed.model_parallel.DistributedModelParallel>`__
- (DMP). Instantiating DMP will:
+ Now, we’re ready to wrap our model with |DistributedModelParallel|_ (DMP). Instantiating DMP will:

1. Decide how to shard the model. DMP will collect the available
   ‘sharders’ and come up with a ‘plan’ of the optimal way to shard the
@@ -142,10 +133,7 @@ torchRec will place both on the single GPU.
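
A minimal sketch of this wrapping step, assuming the ``ebc`` from the
previous section and an already-initialized distributed process group:

.. code:: python

   import torch
   import torchrec

   # DMP collects sharders, plans the sharding, and allocates the
   # (until now unallocated) embedding tables on the available devices.
   model = torchrec.distributed.DistributedModelParallel(
       ebc,
       device=torch.device("cuda"),
   )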
Query vanilla nn.EmbeddingBag with input and offsets
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

- We query
- ```nn.Embedding`` <https://pytorch.org/docs/stable/generated/torch.nn.Embedding.html>`__
- and
- ```nn.EmbeddingBag`` <https://pytorch.org/docs/stable/generated/torch.nn.EmbeddingBag.html>`__
+ We query |nn.Embedding|_ and |nn.EmbeddingBag|_
with ``input`` and ``offsets``. Input is a 1-D tensor containing the
lookup values. Offsets is a 1-D tensor where the sequence is a
cumulative sum of the number of values to pool per example.
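
For example, the following sketch (IDs and table size are invented for
illustration) pools IDs ``[1, 2]`` for the first example and ``[3]`` for
the second:

.. code:: python

   import torch
   import torch.nn as nn

   product_eb = nn.EmbeddingBag(4096, 64)  # 4096 rows of 64-dim embeddings

   input = torch.tensor([1, 2, 3])  # lookup values for both examples
   offsets = torch.tensor([0, 2])   # example i starts at input[offsets[i]]
   pooled = product_eb(input, offsets)  # shape: (2, 64)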
@@ -174,8 +162,7 @@ Representing minibatches with KeyedJaggedTensor
We need an efficient representation of multiple examples of an arbitrary
number of entity IDs per feature per example. In order to enable this
“jagged” representation, we use the TorchRec data structure
- ```KeyedJaggedTensor`` <https://pytorch.org/torchrec/torchrec.sparse.html#torchrec.sparse.jagged_tensor.JaggedTensor>`__
- (KJT).
+ |KeyedJaggedTensor|_ (KJT).

Let’s take a look at **how to look up a collection of two embedding
bags**, “product” and “user”. Assume the minibatch is made up of three
examples.
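
A sketch of building such a KJT; the ID values are invented for
illustration, and the final line assumes the DMP-wrapped ``model`` from
above:

.. code:: python

   import torch
   import torchrec

   # "product": example 0 -> [1, 2], example 1 -> [],     example 2 -> [3]
   # "user":    example 0 -> [2],    example 1 -> [3, 4], example 2 -> [1]
   mb = torchrec.KeyedJaggedTensor(
       keys=["product", "user"],
       values=torch.tensor([1, 2, 3, 2, 3, 4, 1]),
       lengths=torch.tensor([2, 0, 1, 1, 2, 1]),
   )

   pooled_embeddings = model(mb)  # one pooled embedding per feature per example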
@@ -235,3 +222,17 @@ example, which includes multinode training on the criteo terabyte
dataset, using Meta’s `DLRM <https://arxiv.org/abs/1906.00091>`__.

+ .. |DistributedModelParallel| replace:: ``DistributedModelParallel``
+ .. _DistributedModelParallel: https://pytorch.org/torchrec/torchrec.distributed.html#torchrec.distributed.model_parallel.DistributedModelParallel
+ .. |EmbeddingBagCollection| replace:: ``EmbeddingBagCollection``
+ .. _EmbeddingBagCollection: https://pytorch.org/torchrec/torchrec.modules.html#torchrec.modules.embedding_modules.EmbeddingBagCollection
+ .. |KeyedJaggedTensor| replace:: ``KeyedJaggedTensor``
+ .. _KeyedJaggedTensor: https://pytorch.org/torchrec/torchrec.sparse.html#torchrec.sparse.jagged_tensor.JaggedTensor
+ .. |torch.nn.Embedding| replace:: ``torch.nn.Embedding``
+ .. _torch.nn.Embedding: https://pytorch.org/docs/stable/generated/torch.nn.Embedding.html
+ .. |torch.nn.EmbeddingBag| replace:: ``torch.nn.EmbeddingBag``
+ .. _torch.nn.EmbeddingBag: https://pytorch.org/docs/stable/generated/torch.nn.EmbeddingBag.html
+ .. |nn.Embedding| replace:: ``nn.Embedding``
+ .. _nn.Embedding: https://pytorch.org/docs/stable/generated/torch.nn.Embedding.html
+ .. |nn.EmbeddingBag| replace:: ``nn.EmbeddingBag``
+ .. _nn.EmbeddingBag: https://pytorch.org/docs/stable/generated/torch.nn.EmbeddingBag.html