
Commit c5166ef

TroyGarden authored and facebook-github-bot committed
add comments in sharded EBC (#2855)
Summary:
Pull Request resolved: #2855

# context
* add comments for one of the main torchrec modules (shardedEBC)

Reviewed By: aporialiao

Differential Revision: D72062170

fbshipit-source-id: e43d7c8dc569fb14db1376cd53f62f3509b81c3d
1 parent d1a9515 commit c5166ef

File tree

1 file changed: +25 -0 lines changed

torchrec/distributed/embeddingbag.py

Lines changed: 25 additions & 0 deletions
@@ -179,6 +179,10 @@ def create_embedding_bag_sharding(
 ) -> EmbeddingSharding[
     EmbeddingShardingContext, KeyedJaggedTensor, torch.Tensor, torch.Tensor
 ]:
+    """
+    This is the main function for generating `EmbeddingSharding` instances based on sharding_type,
+    so that tables with the same sharding_type in one EBC can be fused.
+    """
     sharding_type = sharding_infos[0].param_sharding.sharding_type

     if device is not None and device.type == "meta":
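For context, "fused" here means that every table sharded with the same sharding_type is served by a single sharding implementation selected for that type, rather than one object per table. Below is a minimal, self-contained sketch of that dispatch-by-sharding_type idea; it is plain Python, not torchrec code, and create_sharding, _tw_sharding, and _rw_sharding are hypothetical stand-ins (the real function returns EmbeddingSharding subclasses).

from typing import Callable, Dict, List

# hypothetical stand-ins: the real factory returns EmbeddingSharding subclasses
# (table-wise, row-wise, column-wise, ...) keyed off the sharding type
def _tw_sharding(tables: List[str]) -> str:
    return f"table-wise sharding over {tables}"

def _rw_sharding(tables: List[str]) -> str:
    return f"row-wise sharding over {tables}"

_FACTORY: Dict[str, Callable[[List[str]], str]] = {
    "table_wise": _tw_sharding,
    "row_wise": _rw_sharding,
}

def create_sharding(sharding_type: str, tables: List[str]) -> str:
    # every table passed in shares one sharding_type, so they all live in a
    # single sharding object and can be fused downstream
    return _FACTORY[sharding_type](tables)

print(create_sharding("table_wise", ["t0", "t2"]))
print(create_sharding("row_wise", ["t1"]))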
@@ -240,6 +244,10 @@ def create_sharding_infos_by_sharding(
     fused_params: Optional[Dict[str, Any]],
     suffix: Optional[str] = "weight",
 ) -> Dict[str, List[EmbeddingShardingInfo]]:
+    """
+    Converts ParameterSharding (table_name_to_parameter_sharding: Dict[str, ParameterSharding]) into
+    EmbeddingShardingInfo objects grouped by sharding_type, and propagates the configs/parameters.
+    """

     if fused_params is None:
         fused_params = {}
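For illustration, a minimal sketch of the grouping step this docstring describes, assuming a toy ParamShardingStub in place of ParameterSharding; the real function also carries the embedding configs and fused_params along with each group.

from dataclasses import dataclass
from typing import Dict, List

@dataclass
class ParamShardingStub:
    # hypothetical stand-in for ParameterSharding; keeps only sharding_type
    sharding_type: str

# toy stand-in for table_name_to_parameter_sharding
plan: Dict[str, ParamShardingStub] = {
    "t0": ParamShardingStub("table_wise"),
    "t1": ParamShardingStub("row_wise"),
    "t2": ParamShardingStub("table_wise"),
}

# group table names by sharding_type, analogous to the
# Dict[str, List[EmbeddingShardingInfo]] return value
infos_by_sharding: Dict[str, List[str]] = {}
for table_name, param_sharding in plan.items():
    infos_by_sharding.setdefault(param_sharding.sharding_type, []).append(table_name)

print(infos_by_sharding)  # {'table_wise': ['t0', 't2'], 'row_wise': ['t1']}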
@@ -1197,6 +1205,9 @@ def _create_output_dist(self) -> None:
         )

     def _update_output_dist(self) -> None:
+        """
+        This function is only used in update_shards.
+        """
         embedding_shard_metadata: List[Optional[ShardMetadata]] = []
         # TODO: Optimize to only go through embedding shardings with new ranks
         self._output_dists: List[nn.Module] = []
@@ -1252,6 +1263,10 @@ def input_dist(
         ctx: EmbeddingBagCollectionContext,
         features: Union[KeyedJaggedTensor, TensorDict],
     ) -> Awaitable[Awaitable[KJTList]]:
+        """
+        This is the main API called in train_pipeline, where we want to start the
+        input_dist in advance.
+        """
         if isinstance(features, TensorDict):
             feature_keys = list(features.keys())  # pyre-ignore[6]
             if len(self._features_order) > 0:
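A rough sketch of why the return type is a nested awaitable and how a pipeline can call input_dist ahead of time. The Awaitable class below is a hypothetical stand-in for torchrec's awaitables; the two wait() stages loosely mirror the two communication phases of the sparse input all-to-all.

from typing import Callable, Generic, TypeVar

T = TypeVar("T")

class Awaitable(Generic[T]):
    # hypothetical stand-in: wait() blocks until the result is available
    def __init__(self, fn: Callable[[], T]) -> None:
        self._fn = fn

    def wait(self) -> T:
        return self._fn()

def input_dist(batch: str) -> Awaitable[Awaitable[str]]:
    # stage 1 could cover a splits all-to-all, stage 2 a values all-to-all
    return Awaitable(lambda: Awaitable(lambda: f"distributed({batch})"))

# a train pipeline can start input_dist for the next batch early ...
pending = input_dist("batch_1")
# ... overlap it with compute on the current batch, and only then wait
dist_input = pending.wait().wait()
print(dist_input)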
@@ -1325,6 +1340,10 @@ def compute(
         ctx: EmbeddingBagCollectionContext,
         dist_input: KJTList,
     ) -> List[torch.Tensor]:
+        """
+        This function is not used in general practice; it is only called by the base
+        class ShardedModule.compute_and_output_dist to perform the basic computation.
+        """
         return [lookup(features) for lookup, features in zip(self._lookups, dist_input)]

     def output_dist(
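To make the calling relationship concrete, here is a toy sketch of how a base class can drive compute() from compute_and_output_dist(); the class and method bodies below are illustrative stand-ins, not the actual ShardedModule implementation.

from typing import Generic, List, TypeVar

In = TypeVar("In")
Out = TypeVar("Out")

class ShardedModuleSketch(Generic[In, Out]):
    # illustrative sketch of the base-class contract: subclasses provide
    # compute() and output_dist(), and the base class chains them
    def compute(self, dist_input: In) -> List[Out]:
        raise NotImplementedError

    def output_dist(self, output: List[Out]) -> Out:
        raise NotImplementedError

    def compute_and_output_dist(self, dist_input: In) -> Out:
        # this chaining is why compute() is rarely invoked directly
        return self.output_dist(self.compute(dist_input))

class ToySharded(ShardedModuleSketch[List[int], int]):
    def compute(self, dist_input: List[int]) -> List[int]:
        return [x * 2 for x in dist_input]  # stand-in for the per-shard lookups

    def output_dist(self, output: List[int]) -> int:
        return sum(output)  # stand-in for the output all-to-all

print(ToySharded().compute_and_output_dist([1, 2, 3]))  # 12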
@@ -1377,6 +1396,10 @@ def output_dist(
     def compute_and_output_dist(
         self, ctx: EmbeddingBagCollectionContext, input: KJTList
     ) -> LazyAwaitable[KeyedTensor]:
+        """
+        This is the main API called in PipelineForward, where the shardedEBC's forward
+        is swapped; see _rewrite_model in train_pipeline for details.
+        """
        batch_size_per_feature_pre_a2a = []
        awaitables = []

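A hypothetical sketch of the forward swap the docstring mentions: the pipeline replaces the module's forward with a callable that feeds already-distributed input straight into compute_and_output_dist. All names below are stand-ins, not the real PipelineForward; see _rewrite_model in train_pipeline for the actual mechanism.

class ToyShardedModule:
    # hypothetical stand-in for a sharded module inside a pipelined model
    def compute_and_output_dist(self, dist_input: str) -> str:
        return f"pooled({dist_input})"

    def forward(self, features: str) -> str:
        # eager path: would run input_dist + compute + output_dist itself
        return self.compute_and_output_dist(f"distributed({features})")

class PipelinedForwardSketch:
    # stand-in for PipelineForward: consumes input that was distributed in
    # advance and goes straight to compute_and_output_dist
    def __init__(self, module: ToyShardedModule, prefetched: dict) -> None:
        self._module = module
        self._prefetched = prefetched

    def __call__(self, features: str) -> str:
        dist_input = self._prefetched.pop(features)
        return self._module.compute_and_output_dist(dist_input)

module = ToyShardedModule()
prefetched = {"batch_1": "distributed(batch_1)"}  # filled earlier by the pipeline's input_dist
module.forward = PipelinedForwardSketch(module, prefetched)  # the "swap"
print(module.forward("batch_1"))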
@@ -1447,6 +1470,8 @@ def update_shards(
         device: Optional[torch.device],
     ) -> None:
         """
+        This is the main API used in sharder.reshard; it currently only supports the
+        redistribution of existing shards (across different ranks, ideally from hot ranks to cold ranks).
         Update shards for this module based on the changed_sharding_params. This will:
         1. Move current lookup tensors to CPU
         2. Purge lookups
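For intuition only, a tiny sketch of "redistribution of existing shards from hot ranks to cold ranks" expressed as a placement delta; the names and the load heuristic here are made up and are not the reshard API.

from typing import Dict

# hypothetical current placement: shard name -> rank
current_placement: Dict[str, int] = {"t0_shard_0": 0, "t0_shard_1": 0, "t1_shard_0": 1}
load_per_rank: Dict[int, float] = {0: 0.9, 1: 0.2}  # rank 0 is "hot"

def plan_redistribution(placement: Dict[str, int], load: Dict[int, float]) -> Dict[str, int]:
    # move one shard from the most loaded rank to the least loaded rank;
    # only existing shards are moved, no resharding into new shard shapes
    hot = max(load, key=load.get)
    cold = min(load, key=load.get)
    new_placement = dict(placement)
    for shard, rank in placement.items():
        if rank == hot:
            new_placement[shard] = cold
            break
    return new_placement

print(plan_redistribution(current_placement, load_per_rank))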
