# ~~~~~~~~~~
#
# When building recommendation systems, categorical features typically
- # have massive cardinalities, posts, users, ads, etc.
+ # have massive cardinality, posts, users, ads, etc.
#
# In order to represent these entities and model these relationships,
# **embeddings** are used. In machine learning, **embeddings are vectors
######################################################################
- # From EmbeddingBag to EmbeddingBagCollection
+ # From ``EmbeddingBag`` to ``EmbeddingBagCollection``
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# We have already explored
# We will use ``EmbeddingBagCollection`` to represent a group of
# EmbeddingBags.
#
- # Here, we create an EmbeddingBagCollection (EBC) with two embedding bags,
+ # Here, we create an ``EmbeddingBagCollection`` (EBC) with two embedding bags,
# 1 representing **products** and 1 representing **users**. Each table,
# ``product_table`` and ``user_table``, is represented by a 64 dimension
# embedding of size 4096.
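
# For illustration only: a minimal sketch of how an EBC like the one described
# above is typically constructed. The elided cell holds the tutorial's actual
# code; the device choice and argument values here are assumptions.
import torch
from torchrec import EmbeddingBagCollection, EmbeddingBagConfig

ebc = EmbeddingBagCollection(
    device=torch.device("cpu"),  # assumed; kept on CPU so the forward pass below runs locally
    tables=[
        EmbeddingBagConfig(
            name="product_table",
            embedding_dim=64,      # 64-dimensional embeddings, as described above
            num_embeddings=4096,   # table of size 4096
            feature_names=["product"],
        ),
        EmbeddingBagConfig(
            name="user_table",
            embedding_dim=64,
            num_embeddings=4096,
            feature_names=["user"],
        ),
    ],
)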
import inspect

- # Let's look at the EmbeddingBagCollection forward method
- # What is a KeyedJaggedTensor and KeyedTensor?
+ # Let's look at the ``EmbeddingBagCollection`` forward method
+ # What is a ``KeyedJaggedTensor`` and ``KeyedTensor``?
print(inspect.getsource(ebc.forward))
from torchrec import KeyedJaggedTensor

- # JaggedTensor represents IDs for 1 feature, but we have multiple features in an EmbeddingBagCollection
- # That's where KeyedJaggedTensor comes in! KeyedJaggedTensor is just multiple JaggedTensors for multiple id_list_feature_offsets
- # From before, we have our two features "product" and "user". Let's create JaggedTensors for both!
+ # ``JaggedTensor`` represents IDs for 1 feature, but we have multiple features in an ``EmbeddingBagCollection``
+ # That's where ``KeyedJaggedTensor`` comes in! ``KeyedJaggedTensor`` is just multiple ``JaggedTensors`` for multiple id_list_feature_offsets
+ # From before, we have our two features "product" and "user". Let's create ``JaggedTensors`` for both!
product_jt = JaggedTensor(
    values=torch.tensor([1, 2, 1, 5]), lengths=torch.tensor([3, 1])
)

# Q1: How many batches are there, and which values are in the first batch for product_jt and user_jt?
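
# For illustration only: the cell defining user_jt is elided above, so here is a
# stand-in with made-up values so that the following lines have both features
# available. The real tutorial values may differ.
user_jt = JaggedTensor(
    values=torch.tensor([2, 3, 4, 1]), lengths=torch.tensor([2, 2])
)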
kjt = KeyedJaggedTensor.from_jt_dict({"product": product_jt, "user": user_jt})
- # Look at our feature keys for the KeyedJaggedTensor
+ # Look at our feature keys for the ``KeyedJaggedTensor``
print("Keys: ", kjt.keys())

- # Look at the overall lengths for the KeyedJaggedTensor
+ # Look at the overall lengths for the ``KeyedJaggedTensor``
print("Lengths: ", kjt.lengths())

- # Look at all values for KeyedJaggedTensor
+ # Look at all values for ``KeyedJaggedTensor``
print("Values: ", kjt.values())

# Can convert KJT to dictionary representation
print("to_dict: ", kjt.to_dict())

- # KeyedJaggedTensor(KJT) string representation
+ # ``KeyedJaggedTensor`` (KJT) string representation
print(kjt)
- # Q2: What are the offsets for the KeyedJaggedTensor?
+ # Q2: What are the offsets for the ``KeyedJaggedTensor``?
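
# A hint for Q2, for illustration: offsets are the cumulative sum of the lengths
# with a leading zero, so per-feature lengths of [3, 1] and [2, 2] give offsets
# [0, 3, 4, 6, 8]. The accessor below assumes the standard KJT API.
print("Offsets: ", kjt.offsets())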
# Now we can run a forward pass on our ebc from before
result = ebc(kjt)
result

- # Result is a KeyedTensor, which contains a list of the feature names and the embedding results
+ # Result is a ``KeyedTensor``, which contains a list of the feature names and the embedding results
print(result.keys())

# The result's shape is [2, 128], as the batch size is 2. Reread the previous section if you need a refresher on how the batch size is determined
- # 128 for dimension of embedding. If you look at where we initialized the EmbeddingBagCollection, we have two tables "product" and "user" of dimension 64 each
+ # 128 for dimension of embedding. If you look at where we initialized the ``EmbeddingBagCollection``, we have two tables "product" and "user" of dimension 64 each
# meaning embeddings for both features are of size 64. 64 + 64 = 128
print(result.values().shape)
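
# For illustration: a quick sanity check of the arithmetic above, a batch of 2
# samples and 64 + 64 = 128 concatenated embedding dimensions.
assert result.values().shape == (2, 64 + 64)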
# Now that we have a grasp on TorchRec modules and data types, it's time
# to take it to the next level.
#
- # Remember, TorchRec's main purpose is to provide primitives for
+ # Remember, the main purpose of TorchRec is to provide primitives for
# distributed embeddings. So far, we've only worked with embedding tables
# on 1 device. This has been possible given how small the embedding tables
# have been, but in a production setting this isn't generally the case.
# Set up environment variables for distributed training
# RANK is which GPU we are on, default 0
os.environ["RANK"] = "0"
- # How many devices in our "world", since Bento can only handle 1 process, 1 GPU
+ # How many devices in our "world", since the notebook can only handle 1 process
os.environ["WORLD_SIZE"] = "1"
# Localhost as we are training locally
os.environ["MASTER_ADDR"] = "localhost"
# are able to do orders of magnitude more floating point operations/s
# (`FLOPs <https://en.wikipedia.org/wiki/FLOPS>`__) than CPU. However,
# GPUs come with the limitation of scarce fast memory (HBM, which is
- # analgous to RAM for CPU), typically ~10s of GBs.
+ # analogous to RAM for CPU), typically ~10s of GBs.
#
# A RecSys model can contain embedding tables that far exceed the memory
# limit for 1 GPU, hence the need for distribution of the embedding tables
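
# For illustration: back-of-the-envelope arithmetic for the claim above. A
# single (hypothetical) table with 100 million rows of dimension 128 in fp32
# already needs 100e6 * 128 * 4 bytes, roughly 51 GB, more than a typical
# single GPU's HBM.
hypothetical_table_bytes = 100_000_000 * 128 * 4
print(f"Hypothetical table size: {hypothetical_table_bytes / 1e9:.1f} GB")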
# distributed training/inference.
#
# The sharded versions of TorchRec modules, for example
- # EmbeddingBagCollection, will handle everything that is needed for Model
+ # ``EmbeddingBagCollection``, will handle everything that is needed for Model
# Parallelism, such as communication between GPUs for distributing
# embeddings to the correct GPUs.
#

- # Refresher of our EmbeddingBagCollection module
+ # Refresher of our ``EmbeddingBagCollection`` module
ebc
from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder
from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology
from torchrec.distributed.types import ShardingEnv

- # Corresponding sharder for EmbeddingBagCollection module
+ # Corresponding sharder for ``EmbeddingBagCollection`` module
sharder = EmbeddingBagCollectionSharder()

- # ProcessGroup from torch.distributed initialized 2 cells above
+ # ``ProcessGroup`` from torch.distributed initialized 2 cells above
pg = dist.GroupMember.WORLD
assert pg is not None, "Process group is not initialized"
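
# For illustration only: the planner cells are elided here, but the ``plan``
# used below is typically produced roughly like this. The world size and
# compute device are assumptions for this single-GPU setup.
planner = EmbeddingShardingPlanner(
    topology=Topology(world_size=1, compute_device="cuda"),
)
plan = planner.collective_plan(ebc, [sharder], pg)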
env = ShardingEnv.from_process_group(pg)

- # Shard the EmbeddingBagCollection module using the EmbeddingBagCollectionSharder
+ # Shard the ``EmbeddingBagCollection`` module using the ``EmbeddingBagCollectionSharder``
sharded_ebc = sharder.shard(ebc, plan.plan[""], env, torch.device("cuda"))

print(f"Sharded EBC Module: {sharded_ebc}")
######################################################################
- # Awaitable
+ # ``Awaitable``
# ^^^^^^^^^^^^^
#
# Remember that TorchRec is a highly optimized library for distributed
from torchrec.distributed.types import LazyAwaitable


- # Demonstrate a LazyAwaitable type
+ # Demonstrate a ``LazyAwaitable`` type
class ExampleAwaitable(LazyAwaitable[torch.Tensor]):
    def __init__(self, size: List[int]) -> None:
        super().__init__()
@@ -633,20 +633,20 @@ def _wait_impl(self) -> torch.Tensor:
kjt = kjt.to("cuda")
output = sharded_ebc(kjt)
- # The output of our sharded EmbeddingBagCollection module is a an Awaitable?
+ # The output of our sharded ``EmbeddingBagCollection`` module is an ``Awaitable``?
print(output)
kt = output.wait()
# Now we have our KeyedTensor after calling .wait()
# If you are confused as to why we have a KeyedTensor output,
- # give yourself a refresher on the unsharded EmbeddingBagCollection module
+ # give yourself a refresher on the unsharded ``EmbeddingBagCollection`` module
print(type(kt))

print(kt.keys())

print(kt.values().shape)

- # Same output format as unsharded EmbeddingBagCollection
+ # Same output format as unsharded ``EmbeddingBagCollection``
result_dict = kt.to_dict()
for key, embedding in result_dict.items():
    print(key, embedding.shape)
@@ -656,7 +656,7 @@ def _wait_impl(self) -> torch.Tensor:
# Anatomy of Sharded TorchRec modules
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
- # We have now successfully sharded an EmbeddingBagCollection given a
+ # We have now successfully sharded an ``EmbeddingBagCollection`` given a
# sharding plan that we generated! The sharded module has common APIs from
# TorchRec which abstract away distributed communication/compute amongst
# multiple GPUs. In fact, these APIs are highly optimized for performance
@@ -691,7 +691,7 @@ def _wait_impl(self) -> torch.Tensor:
# Distribute input KJTs to all other GPUs and receive KJTs
sharded_ebc._input_dists

- # Distribute output embeddingts to all other GPUs and receive embeddings
+ # Distribute output embeddings to all other GPUs and receive embeddings
sharded_ebc._output_dists
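
# For illustration: the same anatomy also exposes the lookup modules that hold
# the actual embedding tables (the fused kernels discussed in the next
# section). ``_lookups`` is an internal attribute and its name may vary across
# TorchRec versions.
sharded_ebc._lookups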
@@ -702,11 +702,11 @@ def _wait_impl(self) -> torch.Tensor:
# In performing lookups for a collection of embedding tables, a trivial
# solution would be to iterate through all the ``nn.EmbeddingBags`` and do
# a lookup per table. This is exactly what the standard, unsharded
- # TorchRec's ``EmbeddingBagCollection`` does. However, while this solution
+ # ``EmbeddingBagCollection`` does. However, while this solution
# is simple, it is extremely slow.
#
# `FBGEMM <https://github.com/pytorch/FBGEMM/tree/main/fbgemm_gpu>`__ is a
- # library that provides GPU operators (otherewise known as kernels) that
+ # library that provides GPU operators (otherwise known as kernels) that
# are very optimized. One of these operators is known as **Table Batched
# Embedding** (TBE), which provides two major optimizations:
#
@@ -724,10 +724,10 @@ def _wait_impl(self) -> torch.Tensor:
######################################################################
- # DistributedModelParallel
+ # ``DistributedModelParallel``
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
- # We have now explored sharding a single EmbeddingBagCollection! We were
+ # We have now explored sharding a single ``EmbeddingBagCollection``! We were
# able to take the ``EmbeddingBagCollectionSharder`` and use the unsharded
# ``EmbeddingBagCollection`` to generate a
# ``ShardedEmbeddingBagCollection`` module. This workflow is fine, but
@@ -738,14 +738,14 @@ def _wait_impl(self) -> torch.Tensor:
#
# 1. Decide how to shard the model. DMP will collect the available
# ‘sharders’ and come up with a ‘plan’ of the optimal way to shard the
- # embedding table(s) (i.e, the EmbeddingBagCollection)
+ # embedding table(s) (i.e., the ``EmbeddingBagCollection``)
# 2. Actually shard the model. This includes allocating memory for each
# embedding table on the appropriate device(s).
#
# DMP takes in everything that we've just experimented with, like a static
# sharding plan, a list of sharders, etc. However, it also has some nice
# defaults to seamlessly shard a TorchRec model. In this toy example,
- # since we have two EmbeddingTables and one GPU, TorchRec will place both
+ # since we have two embedding tables and one GPU, TorchRec will place both
# on the single GPU.
#
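
# For illustration only: the DMP cell itself is elided above. Wrapping the
# unsharded module usually looks roughly like this; the argument choices are
# assumptions for the single-GPU setup.
from torchrec.distributed.model_parallel import DistributedModelParallel

model = DistributedModelParallel(ebc, device=torch.device("cuda"))
print(model)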
@@ -824,7 +824,7 @@ def _wait_impl(self) -> torch.Tensor:
# ``CombinedOptimizer`` that you can use in your training loop to
# ``zero_grad`` and ``step`` through.
#
- # Let's add an optimizer to our EmbeddingBagCollection
+ # Let's add an optimizer to our ``EmbeddingBagCollection``
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
#
# We will do this in two ways, which are equivalent, but give you options
@@ -847,13 +847,13 @@ def _wait_impl(self) -> torch.Tensor:
    "eps": 0.002,
}

- # Init sharder with fused_params
+ # Initialize sharder with fused_params
sharder_with_fused_params = EmbeddingBagCollectionSharder(fused_params=fused_params)

# We'll use the same plan and unsharded EBC as before, but this time with our new sharder
sharded_ebc_fused_params = sharder_with_fused_params.shard(ebc, plan.plan[""], env, torch.device("cuda"))

- # Looking at the optimizer of each, we can see that the learning rate changed, which indicates our optimizer has been applied correclty.
+ # Looking at the optimizer of each, we can see that the learning rate changed, which indicates our optimizer has been applied correctly.
# If seen, we can also look at the TBE logs of the cell to see that our new optimizer is indeed being applied
print(f"Original Sharded EBC fused optimizer: {sharded_ebc.fused_optimizer}")
print(f"Sharded EBC with fused parameters fused optimizer: {sharded_ebc_fused_params.fused_optimizer}")
@@ -880,7 +880,7 @@ def _wait_impl(self) -> torch.Tensor:
print(type(sharded_ebc_apply_opt.fused_optimizer))

# We can also use the filter to check for other parameters that aren't associated with the "fused" optimizer(s)
- # Pratically, just non TorchRec module parameters. Since our module is just a TorchRec EBC
+ # Practically, just non-TorchRec module parameters. Since our module is just a TorchRec EBC
# there are no other parameters that aren't associated with TorchRec
print("Non Fused Model Parameters:")
print(dict(in_backward_optimizer_filter(sharded_ebc_fused_params.named_parameters())).keys())
@@ -972,7 +972,7 @@ def forward(self, kjt: KeyedJaggedTensor):
qconfig = QuantConfig(
    # dtype of the result of the embedding lookup, post activation
- # torch.float generally for compatability with rest of the model
+ # torch.float generally for compatibility with rest of the model
    # as rest of the model here usually isn't quantized
    activation=quant.PlaceholderObserver.with_args(dtype=torch.float),
    # quantized type for embedding weights, aka parameters to actually quantize
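
# For illustration only: the call is cut off at this point in the hunk. The
# elided remainder typically supplies the weight observer, for example
# quantizing the embedding weights to int8; the dtype below is an assumption.
#
#     weight=quant.PlaceholderObserver.with_args(dtype=torch.qint8),
# )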