Commit b55441a

Address tutorial comments2
1 parent ed113b2 commit b55441a

2 files changed: +20 additions, −17 deletions


en-wordlist.txt

Lines changed: 4 additions & 1 deletion
@@ -643,4 +643,7 @@ nccl
 Localhost
 gpu
 torchmetrics
-url
+url
+colab
+sharders
+Criteo

intermediate_source/torchrec_interactive_tutorial.py

Lines changed: 16 additions & 16 deletions
@@ -95,7 +95,7 @@
 # about the technical details of using embedding tables in RecSys.
 #
 # This tutorial will introduce the concept of embeddings, showcase
-# TorchRec specific modules and datatypes, and depict how distributed training
+# TorchRec specific modules and data types, and depict how distributed training
 # works with TorchRec.
 #
 
@@ -125,7 +125,7 @@
 weights = torch.rand(num_embeddings, embedding_dim)
 print("Weights:", weights)
 
-# Pass in pregenerated weights just for example, typically weights are randomly initialized
+# Pass in pre-generated weights just for example, typically weights are randomly initialized
 embedding_collection = torch.nn.Embedding(
     num_embeddings, embedding_dim, _weight=weights
 )
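
For context, a minimal sketch (not part of this diff) of how such an embedding table is looked up; the sizes and ids are illustrative:

    import torch

    num_embeddings, embedding_dim = 10, 4  # illustrative sizes
    weights = torch.rand(num_embeddings, embedding_dim)

    # Table initialized from pre-generated weights, as in the hunk above
    embedding_collection = torch.nn.Embedding(
        num_embeddings, embedding_dim, _weight=weights
    )

    # Looking up ids returns the corresponding rows of the weight matrix
    ids = torch.tensor([1, 3])
    print(embedding_collection(ids).shape)  # torch.Size([2, 4])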
@@ -193,14 +193,14 @@
 # primitives for large scale distributed embeddings**.
 #
 # From here on out, we will explore the major features of the TorchRec
-# library. We will start with torch.nn.Embedding and will extend that to
+# library. We will start with ``torch.nn.Embedding`` and will extend that to
 # custom TorchRec modules, explore distributed training environment with
 # generating a sharding plan for embeddings, look at inherent TorchRec
 # optimizations, and extend the model to be ready for inference in C++.
 # Below is a quick outline of what the journey will consist of - buckle
 # in!
 #
-# 1. TorchRec Modules and DataTypes
+# 1. TorchRec Modules and Data Types
 # 2. Distributed Training, Sharding, and Optimizations
 # 3. Inference
 #
@@ -210,7 +210,7 @@
 
 
 ######################################################################
-# TorchRec Modules and Datatypes
+# TorchRec Modules and Data Types
 # ------------------------------
 #
 #
@@ -231,7 +231,7 @@
 # and
 # ```EmbeddingBagCollection`` <https://pytorch.org/torchrec/torchrec.modules.html#torchrec.modules.embedding_modules.EmbeddingBagCollection>`__.
 # We will use ``EmbeddingBagCollection`` to represent a group of
-# EmbeddingBags.
+# embedding bags.
 #
 # Here, we create an ``EmbeddingBagCollection`` (EBC) with two embedding bags,
 # 1 representing **products** and 1 representing **users**. Each table,
@@ -262,7 +262,7 @@
 
 
 ######################################################################
-# Let’s inspect the forward method for EmbeddingBagcollection and the
+# Let’s inspect the forward method for ``EmbeddingBagCollection`` and the
 # module’s inputs and outputs.
 #
 
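
For context, a minimal sketch (not part of this diff) of how an ``EmbeddingBagCollection`` like the one referenced above can be built with the top-level ``torchrec`` API; the table names, row counts, and pooling choice are illustrative, only the two features and dimension 64 come from the surrounding text:

    import torchrec

    # Two embedding bags, one per feature, each of dimension 64
    ebc = torchrec.EmbeddingBagCollection(
        device="meta",  # materialize weights later, e.g. when sharding
        tables=[
            torchrec.EmbeddingBagConfig(
                name="product_table",
                embedding_dim=64,
                num_embeddings=4096,
                feature_names=["product"],
                pooling=torchrec.PoolingType.SUM,
            ),
            torchrec.EmbeddingBagConfig(
                name="user_table",
                embedding_dim=64,
                num_embeddings=4096,
                feature_names=["user"],
                pooling=torchrec.PoolingType.SUM,
            ),
        ],
    )
    print(ebc)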
@@ -323,7 +323,7 @@
 
 from torchrec import JaggedTensor
 
-# JaggedTensor is just a wrapper around lengths/offsets and values tensors!
+# ``JaggedTensor`` is just a wrapper around lengths/offsets and values tensors!
 jt = JaggedTensor(values=id_list_feature_values, lengths=id_list_feature_lengths)
 
 # Automatically compute offsets from lengths
@@ -332,7 +332,7 @@
 # Convert to list of values
 print("List of Values: ", jt.to_dense())
 
-# __str__ representation
+# ``__str__`` representation
 print(jt)
 
 from torchrec import KeyedJaggedTensor
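
For context, a self-contained sketch (not part of this diff) of the ``JaggedTensor`` usage the hunks above refer to; the sample ids and lengths are made up for illustration:

    import torch
    from torchrec import JaggedTensor

    # 3 samples with 2, 0, and 1 ids respectively, hence "jagged"
    id_list_feature_values = torch.tensor([5, 7, 9])
    id_list_feature_lengths = torch.tensor([2, 0, 1])

    jt = JaggedTensor(values=id_list_feature_values, lengths=id_list_feature_lengths)

    # Offsets are the cumulative sum of lengths
    print("Offsets: ", jt.offsets())
    # One (possibly empty) tensor of ids per sample
    print("List of Values: ", jt.to_dense())
    print(jt)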
@@ -358,15 +358,15 @@
 # Look at all values for ``KeyedJaggedTensor``
 print("Values: ", kjt.values())
 
-# Can convert KJT to dictionary representation
+# Can convert ``KeyedJaggedTensor`` to dictionary representation
 print("to_dict: ", kjt.to_dict())
 
-# ``KeyedJaggedTensor`` (KJT) string representation
+# ``KeyedJaggedTensor`` string representation
 print(kjt)
 
 # Q2: What are the offsets for the ``KeyedJaggedTensor``?
 
-# Now we can run a forward pass on our ebc from before
+# Now we can run a forward pass on our ``EmbeddingBagCollection`` from before
 result = ebc(kjt)
 result
 
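
For context, a sketch (not part of this diff) of the kind of ``KeyedJaggedTensor`` that the ``result = ebc(kjt)`` call above consumes; the ids and lengths are illustrative, with a batch size of 2 per feature to match the [2, 128] output discussed below:

    import torch
    from torchrec import KeyedJaggedTensor

    # Two features over a batch of 2 samples each, so lengths has 2 * 2 = 4 entries
    # and values holds sum(lengths) = 7 ids
    kjt = KeyedJaggedTensor(
        keys=["product", "user"],
        values=torch.tensor([1, 2, 1, 5, 6, 7, 8]),
        lengths=torch.tensor([3, 1, 2, 1]),
    )

    print("Keys: ", kjt.keys())
    print("Lengths: ", kjt.lengths())
    # Offsets (the answer to Q2) are the cumulative sum of lengths
    print("Offsets: ", kjt.offsets())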
@@ -375,7 +375,7 @@
 
 # The results shape is [2, 128], as batch size of 2. Reread previous section if you need a refresher on how the batch size is determined
 # 128 for dimension of embedding. If you look at where we initialized the ``EmbeddingBagCollection``, we have two tables "product" and "user" of dimension 64 each
-# meaning emebddings for both features are of size 64. 64 + 64 = 128
+# meaning embeddings for both features are of size 64. 64 + 64 = 128
 print(result.values().shape)
 
 # Nice to_dict method to determine the embeddings that belong to each feature
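
A short illustrative check (not part of this diff) of that arithmetic using ``to_dict``, which splits the result back into per-feature embeddings:

    # Each feature's pooled embedding has shape [batch_size, table_dim]
    for feature_name, embedding in result.to_dict().items():
        print(feature_name, embedding.shape)  # expected: [2, 64] per feature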
@@ -406,7 +406,7 @@
 #
 # In this section, we will explore setting up a distributed environment,
 # exactly how actual production training is done, and explore sharding
-# embedding tables, all with Torchrec.
+# embedding tables, all with TorchRec.
 #
 # **This section will also only use 1 GPU, though it will be treated in a
 # distributed fashion. This is only a limitation for training, as training
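
For context, a minimal single-process sketch (not part of this diff) of the kind of environment setup this section relies on; the port number is arbitrary and the backend fallback is an assumption:

    import os

    import torch
    import torch.distributed as dist

    # Pretend to be rank 0 of a world of size 1 so distributed APIs work on 1 GPU
    os.environ["RANK"] = "0"
    os.environ["WORLD_SIZE"] = "1"
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "29500"

    # nccl for GPU collectives, gloo as a CPU fallback
    backend = "nccl" if torch.cuda.is_available() else "gloo"
    dist.init_process_group(backend=backend)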
@@ -646,8 +646,8 @@ def _wait_impl(self) -> torch.Tensor:
 print(output)
 
 kt = output.wait()
-# Now we have out KeyedTensor after calling .wait()
-# If you are confused as to why we have a KeyedTensor output,
+# Now we have our ``KeyedTensor`` after calling ``.wait()``
+# If you are confused as to why we have a ``KeyedTensor`` output,
 # give yourself a refresher on the unsharded ``EmbeddingBagCollection`` module
 print(type(kt))
 
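
For context, a short sketch (not part of this diff) of inspecting the ``KeyedTensor`` returned by ``.wait()``; it assumes the ``kt`` from the hunk above:

    # One pooled embedding per feature, concatenated along the last dimension
    print("Keys: ", kt.keys())
    print("Values shape: ", kt.values().shape)

    # Split back into per-feature tensors, just like the unsharded module's output
    for feature_name, embedding in kt.to_dict().items():
        print(feature_name, embedding.shape)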