
Commit 67cbbc2

Raahul Kalyaan Jakka authored and facebook-github-bot committed
Enabling Optimizer checkpointing for KeyValueEmbeddingFusedOptimizer
Summary:

**Context:**

1. We introduced KeyValueEmbeddingFusedOptimizer for SSD optimizer offloading:
   https://www.internalfb.com/code/fbsource/[a43d796a4169]/fbcode/torchrec/distributed/batched_embedding_kernel.py?lines=341-376
   However, the optimizer weights for SSD use cases are currently not offloaded and still reside on HBM. See the optimizer state dict checkpointing path:
   https://www.internalfb.com/code/fbsource/[6303aefbae20]/fbcode/minimal_viable_ai/core/model_family_api/optimizer.py?lines=1019-1028
   Because of this, we want to initialize an optimizer class for SSD that lets us fetch the latest optimizer weight values during checkpointing (the get_optimizer_state call):
   https://www.internalfb.com/code/fbsource/[6303aefbae20]/fbcode/deeplearning/fbgemm/fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py?lines=2540-2545

**Hence, in this diff** we make the following changes:

1. Loop through every embedding table and:
   a. Change the table placement to CPU.
   b. Create a ShardedTensor for the embedding weight.
   c. Create a ShardedTensor for the optimizer state. There are three cases:
      - a single optimizer value per shard,
      - a row-wise optimizer value per shard,
      - a point-wise optimizer value per shard.
2. Initialize the optimizer class with the appropriate parameters (see the sketch below).

Differential Revision: D78131693
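The sketch below illustrates the per-table loop and the three optimizer-state cases described above. It is a minimal, self-contained approximation, not the actual TorchRec/FBGEMM implementation: `TableShardInfo`, `OptimStateCase`, and the helper functions are hypothetical, and the wrapping of local tensors into ShardedTensors over a process group is elided.

```python
# Minimal sketch (hypothetical names) of the per-table checkpointing loop.
from dataclasses import dataclass
from enum import Enum, auto

import torch


class OptimStateCase(Enum):
    SINGLE_VALUE_PER_SHARD = auto()  # one scalar state value per shard
    ROW_WISE_PER_SHARD = auto()      # one state value per embedding row
    POINT_WISE_PER_SHARD = auto()    # state has the same shape as the weight shard


@dataclass
class TableShardInfo:
    name: str
    local_rows: int
    embedding_dim: int
    case: OptimStateCase


def build_local_optimizer_state(shard: TableShardInfo) -> torch.Tensor:
    """Return the local optimizer-state tensor for one table shard,
    shaped according to which of the three cases applies."""
    if shard.case is OptimStateCase.SINGLE_VALUE_PER_SHARD:
        return torch.zeros(1)
    if shard.case is OptimStateCase.ROW_WISE_PER_SHARD:
        return torch.zeros(shard.local_rows)
    # Point-wise: one state value per weight element.
    return torch.zeros(shard.local_rows, shard.embedding_dim)


def build_checkpoint_tensors(
    shards: list[TableShardInfo],
) -> dict[str, dict[str, torch.Tensor]]:
    """Loop over every embedding table shard: place the weight on CPU and
    build the matching optimizer-state tensor. In the real diff these local
    tensors are wrapped into ShardedTensors before the optimizer class is
    initialized, so that get_optimizer_state returns the latest values."""
    out = {}
    for shard in shards:
        weight = torch.empty(shard.local_rows, shard.embedding_dim, device="cpu")
        out[shard.name] = {
            "weight": weight,
            "optimizer_state": build_local_optimizer_state(shard),
        }
    return out


if __name__ == "__main__":
    shards = [
        TableShardInfo("table_0", 1024, 64, OptimStateCase.ROW_WISE_PER_SHARD),
        TableShardInfo("table_1", 512, 128, OptimStateCase.POINT_WISE_PER_SHARD),
    ]
    for name, t in build_checkpoint_tensors(shards).items():
        print(name, t["weight"].shape, t["optimizer_state"].shape)
```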
1 parent e3d5e36 commit 67cbbc2

File tree

2 files changed: +530 −3 lines
