|
95 | 95 | # about the technical details of using embedding tables in RecSys.
|
96 | 96 | #
|
97 | 97 | # This tutorial will introduce the concept of embeddings, showcase
|
98 |
| -# TorchRec specific modules and datatypes, and depict how distributed training |
| 98 | +# TorchRec specific modules and data types, and depict how distributed training |
99 | 99 | # works with TorchRec.
|
100 | 100 | #
|
101 | 101 |
|
|
125 | 125 | weights = torch.rand(num_embeddings, embedding_dim)
|
126 | 126 | print("Weights:", weights)
|
127 | 127 |
|
128 |
| -# Pass in pregenerated weights just for example, typically weights are randomly initialized |
| 128 | +# Pass in pre-generated weights just for this example; typically, weights are randomly initialized |
129 | 129 | embedding_collection = torch.nn.Embedding(
|
130 | 130 | num_embeddings, embedding_dim, _weight=weights
|
131 | 131 | )
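# As a quick sketch of what this module gives us: a lookup is just row
# indexing into the weight table above. The ``ids`` below are made up for
# illustration and assume ``num_embeddings`` is at least 3.
ids = torch.tensor([0, 2, 1])
looked_up = embedding_collection(ids)
# Each ID selects one row of the table, so the result has shape (3, embedding_dim)
print("Lookup shape:", looked_up.shape)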
|
|
193 | 193 | # primitives for large scale distributed embeddings**.
|
194 | 194 | #
|
195 | 195 | # From here on out, we will explore the major features of the TorchRec
|
196 |
| -# library. We will start with torch.nn.Embedding and will extend that to |
| 196 | +# library. We will start with ``torch.nn.Embedding`` and will extend that to |
197 | 197 | # custom TorchRec modules, explore distributed training environment with
|
198 | 198 | # generating a sharding plan for embeddings, look at inherent TorchRec
|
199 | 199 | # optimizations, and extend the model to be ready for inference in C++.
|
200 | 200 | # Below is a quick outline of what the journey will consist of - buckle
|
201 | 201 | # in!
|
202 | 202 | #
|
203 |
| -# 1. TorchRec Modules and DataTypes |
| 203 | +# 1. TorchRec Modules and Data Types |
204 | 204 | # 2. Distributed Training, Sharding, and Optimizations
|
205 | 205 | # 3. Inference
|
206 | 206 | #
|
|
210 | 210 |
|
211 | 211 |
|
212 | 212 | ######################################################################
|
213 |
| -# TorchRec Modules and Datatypes |
| 213 | +# TorchRec Modules and Data Types |
214 | 214 | # -------------------------------
|
215 | 215 | #
|
216 | 216 | #
|
|
231 | 231 | # and
|
232 | 232 | # ``EmbeddingBagCollection`` <https://pytorch.org/torchrec/torchrec.modules.html#torchrec.modules.embedding_modules.EmbeddingBagCollection>`__.
|
233 | 233 | # We will use ``EmbeddingBagCollection`` to represent a group of
|
234 |
| -# EmbeddingBags. |
| 234 | +# embedding bags. |
235 | 235 | #
|
236 | 236 | # Here, we create an ``EmbeddingBagCollection`` (EBC) with two embedding bags,
|
237 | 237 | # 1 representing **products** and 1 representing **users**. Each table,
|
|
262 | 262 |
|
263 | 263 |
|
264 | 264 | ######################################################################
|
265 |
| -# Let’s inspect the forward method for EmbeddingBagcollection and the |
| 265 | +# Let’s inspect the forward method for ``EmbeddingBagCollection`` and the |
266 | 266 | # module’s inputs and outputs.
|
267 | 267 | #
|
268 | 268 |
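# One quick way to peek at it (a sketch; this assumes the
# ``EmbeddingBagCollection`` instance created above is named ``ebc``, as the
# forward-pass example later in this tutorial suggests):
import inspect

print(inspect.getsource(ebc.forward))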
|
|
323 | 323 |
|
324 | 324 | from torchrec import JaggedTensor
|
325 | 325 |
|
326 |
| -# JaggedTensor is just a wrapper around lengths/offsets and values tensors! |
| 326 | +# ``JaggedTensor`` is just a wrapper around lengths/offsets and values tensors! |
327 | 327 | jt = JaggedTensor(values=id_list_feature_values, lengths=id_list_feature_lengths)
|
328 | 328 |
|
329 | 329 | # Automatically compute offsets from lengths
|
|
332 | 332 | # Convert to list of values
|
333 | 333 | print("List of Values: ", jt.to_dense())
|
334 | 334 |
|
335 |
| -# __str__ representation |
| 335 | +# ``__str__`` representation |
336 | 336 | print(jt)
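# The same jagged data can be expressed from the other direction: assuming the
# ``JaggedTensor`` constructor also accepts ``offsets`` (the cumulative form of
# lengths), we can rebuild an equivalent tensor from the offsets computed above.
jt_from_offsets = JaggedTensor(values=id_list_feature_values, offsets=jt.offsets())
# Lengths are recovered from consecutive offset differences
print("Recovered lengths: ", jt_from_offsets.lengths())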
|
337 | 337 |
|
338 | 338 | from torchrec import KeyedJaggedTensor
|
|
358 | 358 | # Look at all values for ``KeyedJaggedTensor``
|
359 | 359 | print("Values: ", kjt.values())
|
360 | 360 |
|
361 |
| -# Can convert KJT to dictionary representation |
| 361 | +# Can convert ``KeyedJaggedTensor`` to dictionary representation |
362 | 362 | print("to_dict: ", kjt.to_dict())
|
363 | 363 |
|
364 |
| -# ``KeyedJaggedTensor`` (KJT) string representation |
| 364 | +# ``KeyedJaggedTensor`` string representation |
365 | 365 | print(kjt)
|
366 | 366 |
|
367 | 367 | # Q2: What are the offsets for the ``KeyedJaggedTensor``?
|
368 | 368 |
|
369 |
| -# Now we can run a forward pass on our ebc from before |
| 369 | +# Now we can run a forward pass on our ``EmbeddingBagCollection`` from before |
370 | 370 | result = ebc(kjt)
|
371 | 371 | result
|
372 | 372 |
|
|
375 | 375 |
|
376 | 376 | # The result shape is [2, 128]: 2 is the batch size (reread the previous section if you need a refresher on how the batch size is determined)
|
377 | 377 | # and 128 is the embedding dimension. If you look at where we initialized the ``EmbeddingBagCollection``, we have two tables, "product" and "user", each of dimension 64,
|
378 |
| -# meaning emebddings for both features are of size 64. 64 + 64 = 128 |
| 378 | +# meaning embeddings for both features are of size 64. 64 + 64 = 128 |
379 | 379 | print(result.values().shape)
|
380 | 380 |
|
381 | 381 | # Nice to_dict method to determine the embeddings that belong to each feature
|
|
406 | 406 | #
|
407 | 407 | # In this section, we will explore setting up a distributed environment,
|
408 | 408 | # exactly as it is done in actual production training, and explore sharding
|
409 |
| -# embedding tables, all with Torchrec. |
| 409 | +# embedding tables, all with TorchRec. |
410 | 410 | #
|
411 | 411 | # **This section will also only use 1 GPU, though it will be treated in a
|
412 | 412 | # distributed fashion. This is only a limitation for training, as training
|
@@ -646,8 +646,8 @@ def _wait_impl(self) -> torch.Tensor:
|
646 | 646 | print(output)
|
647 | 647 |
|
648 | 648 | kt = output.wait()
|
649 |
| -# Now we have out KeyedTensor after calling .wait() |
650 |
| -# If you are confused as to why we have a KeyedTensor output, |
| 649 | +# Now we have our ``KeyedTensor`` after calling ``.wait()`` |
| 650 | +# If you are confused as to why we have a ``KeyedTensor`` output, |
651 | 651 | # give yourself a refresher on the unsharded ``EmbeddingBagCollection`` module
|
652 | 652 | print(type(kt))
|
653 | 653 |
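# Since ``kt`` is a ``KeyedTensor``, just like the output of the unsharded
# ``EmbeddingBagCollection`` earlier, we can inspect it with the same accessors
# used previously in this tutorial (a quick sketch):
print("Keys: ", kt.keys())
print("Pooled embeddings shape: ", kt.values().shape)
# Per-feature embeddings, keyed by feature name
for feature_name, embedding in kt.to_dict().items():
    print(feature_name, embedding.shape)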