@@ -179,6 +179,10 @@ def create_embedding_bag_sharding(
) -> EmbeddingSharding[
    EmbeddingShardingContext, KeyedJaggedTensor, torch.Tensor, torch.Tensor
]:
+    """
+    This is the main function used to generate an `EmbeddingSharding` instance based on
+    sharding_type, so that tables with the same sharding_type in one EBC are fused together.
+    """
    sharding_type = sharding_infos[0].param_sharding.sharding_type

    if device is not None and device.type == "meta":
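Below is a minimal, self-contained sketch (simplified stand-in classes, not the actual torchrec implementation) of the pattern the new docstring describes: every sharding info in the group shares one sharding_type, so the function can dispatch on the first entry and build a single fused `EmbeddingSharding` for the whole group.

from typing import List


class FakeParamSharding:
    def __init__(self, sharding_type: str) -> None:
        self.sharding_type = sharding_type


class FakeShardingInfo:
    def __init__(self, sharding_type: str) -> None:
        self.param_sharding = FakeParamSharding(sharding_type)


def pick_sharding_impl(sharding_infos: List[FakeShardingInfo]) -> str:
    # Mirrors the first statement in create_embedding_bag_sharding: the group is
    # homogeneous, so the first info determines which implementation to construct.
    sharding_type = sharding_infos[0].param_sharding.sharding_type
    if sharding_type == "table_wise":
        return "table-wise pooled sharding"  # placeholder for the real class
    if sharding_type == "row_wise":
        return "row-wise pooled sharding"  # placeholder for the real class
    raise ValueError(f"unsupported sharding type: {sharding_type}")


print(pick_sharding_impl([FakeShardingInfo("table_wise"), FakeShardingInfo("table_wise")]))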
@@ -240,6 +244,10 @@ def create_sharding_infos_by_sharding(
    fused_params: Optional[Dict[str, Any]],
    suffix: Optional[str] = "weight",
) -> Dict[str, List[EmbeddingShardingInfo]]:
+    """
+    Converts ParameterSharding (table_name_to_parameter_sharding: Dict[str, ParameterSharding]) into
+    EmbeddingShardingInfo objects grouped by sharding_type, and propagates the configs/parameters.
+    """

    if fused_params is None:
        fused_params = {}
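A purely illustrative sketch (hypothetical data and simplified stand-in types, not the real ParameterSharding/EmbeddingShardingInfo objects) of the grouping this function performs: per-table sharding configs are bucketed by their sharding_type so that each bucket can later be fused into one `EmbeddingSharding`.

from collections import defaultdict
from typing import Dict, List, Tuple

# Stand-in for table_name_to_parameter_sharding: Dict[str, ParameterSharding];
# here each table just carries its sharding_type string.
table_name_to_sharding_type: Dict[str, str] = {
    "t1": "table_wise",
    "t2": "row_wise",
    "t3": "table_wise",
}

infos_by_sharding: Dict[str, List[Tuple[str, str]]] = defaultdict(list)
for table_name, sharding_type in table_name_to_sharding_type.items():
    # In the real function this appends an EmbeddingShardingInfo carrying the table
    # config, its ParameterSharding, and the propagated fused_params.
    infos_by_sharding[sharding_type].append((table_name, sharding_type))

print(dict(infos_by_sharding))
# {'table_wise': [('t1', 'table_wise'), ('t3', 'table_wise')], 'row_wise': [('t2', 'row_wise')]}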
@@ -1197,6 +1205,9 @@ def _create_output_dist(self) -> None:
        )

    def _update_output_dist(self) -> None:
+        """
+        This function is only used in update_shards.
+        """
        embedding_shard_metadata: List[Optional[ShardMetadata]] = []
        # TODO: Optimize to only go through embedding shardings with new ranks
        self._output_dists: List[nn.Module] = []
@@ -1252,6 +1263,10 @@ def input_dist(
        ctx: EmbeddingBagCollectionContext,
        features: Union[KeyedJaggedTensor, TensorDict],
    ) -> Awaitable[Awaitable[KJTList]]:
+        """
+        This is the main API called in train_pipeline, where we want to start the
+        input_dist in advance.
+        """
        if isinstance(features, TensorDict):
            feature_keys = list(features.keys())  # pyre-ignore[6]
            if len(self._features_order) > 0:
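A conceptual, self-contained sketch (mock classes, not the torchrec Awaitable or train pipeline API) of the nested-awaitable pattern behind the `Awaitable[Awaitable[KJTList]]` return type: the pipeline starts input_dist for the next batch early and waits twice, once per communication stage, right before that batch's forward pass needs the distributed features.

from typing import Generic, List, TypeVar

T = TypeVar("T")


class MockAwaitable(Generic[T]):
    def __init__(self, value: T) -> None:
        self._value = value

    def wait(self) -> T:
        return self._value


def mock_input_dist(features: List[str]) -> MockAwaitable[MockAwaitable[List[str]]]:
    # Stage 1 (splits exchange) wraps stage 2 (feature exchange) in this mock.
    return MockAwaitable(MockAwaitable(features))


splits_awaitable = mock_input_dist(["f1", "f2"])  # issued one batch in advance
dist_input = splits_awaitable.wait().wait()       # both communication stages done here
print(dist_input)  # ['f1', 'f2']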
@@ -1325,6 +1340,10 @@ def compute(
        ctx: EmbeddingBagCollectionContext,
        dist_input: KJTList,
    ) -> List[torch.Tensor]:
+        """
+        This function is not used in general practice; it is only called by the base class
+        ShardedModule.compute_and_output_dist to perform the basic compute step.
+        """
        return [lookup(features) for lookup, features in zip(self._lookups, dist_input)]

    def output_dist(
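A rough, hypothetical sketch (mock classes, not the verbatim torchrec source) of the base-class composition the docstring refers to, assuming the default compute_and_output_dist simply chains compute() and output_dist(); the sharded EBC in the next hunk overrides compute_and_output_dist rather than relying on this default.

class MockShardedModule:
    def compute(self, ctx, dist_input):
        # stand-in for: [lookup(features) for lookup, features in zip(self._lookups, dist_input)]
        return [f"pooled({features})" for features in dist_input]

    def output_dist(self, ctx, output):
        # stand-in for the awaitable returned by the pooled-embedding exchange
        return {"dist_output": output}

    def compute_and_output_dist(self, ctx, dist_input):
        # assumed default composition: compute first, then output_dist
        return self.output_dist(ctx, self.compute(ctx, dist_input))


print(MockShardedModule().compute_and_output_dist(ctx=None, dist_input=["kjt_rank0", "kjt_rank1"]))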
@@ -1377,6 +1396,10 @@ def output_dist(
    def compute_and_output_dist(
        self, ctx: EmbeddingBagCollectionContext, input: KJTList
    ) -> LazyAwaitable[KeyedTensor]:
+        """
+        The main API called in PipelineForward, where the sharded EBC's forward is swapped;
+        see _rewrite_model in train_pipeline for details.
+        """
        batch_size_per_feature_pre_a2a = []
        awaitables = []
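A simplified, hypothetical sketch (mock objects, not torchrec's PipelineForward or _rewrite_model) of the forward swap the docstring describes: the swapped-in callable picks up the already-distributed input for the current batch and calls compute_and_output_dist on it instead of running the module's original forward.

class MockShardedEBC:
    def compute_and_output_dist(self, ctx, dist_input):
        # would return a LazyAwaitable[KeyedTensor] in the real module
        return f"lazy_awaitable_for({dist_input})"


class MockPipelineForward:
    """Stand-in for the swapped forward: fetch the pre-distributed input for the
    current batch and call compute_and_output_dist on it."""

    def __init__(self, module, pending_batches):
        self._module = module
        self._pending_batches = pending_batches  # list of (ctx, dist_input) pairs

    def __call__(self, *unused_args, **unused_kwargs):
        ctx, dist_input = self._pending_batches.pop(0)
        return self._module.compute_and_output_dist(ctx, dist_input)


swapped_forward = MockPipelineForward(MockShardedEBC(), [("ctx0", "kjt_list_batch0")])
print(swapped_forward())  # lazy_awaitable_for(kjt_list_batch0)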
@@ -1447,6 +1470,8 @@ def update_shards(
        device: Optional[torch.device],
    ) -> None:
        """
+        This is the main API used in sharder.reshard; currently it only supports redistribution
+        of existing shards (across different ranks, ideally from hot ranks to cold ranks).
        Update shards for this module based on the changed_sharding_params. This will:
        1. Move current lookup tensors to CPU
        2. Purge lookups
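A purely illustrative sketch (hypothetical placement data, not the torchrec ParameterSharding or reshard API) of the kind of plan update_shards consumes: existing shards keep their sizes and offsets, only their placement ranks change, ideally moving load from hot ranks to cold ranks.

from typing import Dict, List

# Current placement: table -> rank of each existing shard.
current_placement: Dict[str, List[int]] = {"t1": [0, 0], "t2": [1]}

# Analogue of changed_sharding_params: same shards, new ranks (rank 0 is "hot").
changed_placement: Dict[str, List[int]] = {"t1": [0, 2]}  # move t1's second shard from rank 0 to rank 2

for table, new_ranks in changed_placement.items():
    old_ranks = current_placement[table]
    moves = [(i, old, new) for i, (old, new) in enumerate(zip(old_ranks, new_ranks)) if old != new]
    # Per the docstring steps above, update_shards would move the affected lookup
    # tensors to CPU, purge the lookups, and rebuild the distribution modules with
    # the new placements.
    print(f"{table}: relocate shards {moves}")  # t1: relocate shards [(1, 0, 2)]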