
Commit 8581ea1

TroyGarden authored and facebook-github-bot committed
Refactor EmbeddingSharding grouping logic (#2891)
Summary: Pull Request resolved: #2891

# context
* Previously we used a util function [`create_sharding_infos_by_sharding`](https://github.com/pytorch/torchrec/blob/main/torchrec/distributed/embedding.py#L150-L229) to group the `sharding_info`s so that a sharded module can create an [`EmbeddingSharding`](https://github.com/pytorch/torchrec/blob/main/torchrec/distributed/embedding.py#L601-L643) from each group of sharding_infos.
* After the recent refactoring in [#2887](#2887), `create_embedding_sharding` became a public API, so it is also reasonable to promote `create_sharding_infos_by_sharding` to a public API (classmethod) that users can subclass and override.
* Since "grouping" is more descriptive of what this function does, we rename it to `create_grouped_sharding_infos`.

Reviewed By: dstaay-fb

Differential Revision: D58221182

fbshipit-source-id: ef417dfa45728d701902ad7ea53c3dce81c9a95c
1 parent 1f44c75 commit 8581ea1
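
A minimal sketch (not from this PR) of the subclass-and-override use case the summary describes. The enclosing class name and import path are assumptions based on `torchrec/distributed/embedding.py`, and the post-processing step is purely illustrative:

```python
# Hedged sketch: overriding the now-public classmethod in a subclass.
# Assumes the enclosing class is ShardedEmbeddingCollection in
# torchrec/distributed/embedding.py; the filtering below is illustrative only.
from torchrec.distributed.embedding import ShardedEmbeddingCollection


class MyShardedEmbeddingCollection(ShardedEmbeddingCollection):
    @classmethod
    def create_grouped_sharding_infos(
        cls, module, table_name_to_parameter_sharding, fused_params
    ):
        # Reuse the default grouping, then post-process the result.
        grouped = super().create_grouped_sharding_infos(
            module, table_name_to_parameter_sharding, fused_params
        )
        # Hypothetical customization: drop sharding types with no tables.
        return {st: infos for st, infos in grouped.items() if infos}
```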

4 files changed: +265 -240 lines

torchrec/distributed/embedding.py

Lines changed: 87 additions & 83 deletions
@@ -147,88 +147,6 @@ def get_ec_index_dedup() -> bool:
     return EC_INDEX_DEDUP


-def create_sharding_infos_by_sharding(
-    module: EmbeddingCollectionInterface,
-    table_name_to_parameter_sharding: Dict[str, ParameterSharding],
-    fused_params: Optional[Dict[str, Any]],
-) -> Dict[str, List[EmbeddingShardingInfo]]:
-
-    if fused_params is None:
-        fused_params = {}
-
-    sharding_type_to_sharding_infos: Dict[str, List[EmbeddingShardingInfo]] = {}
-    # state_dict returns parameter.Tensor, which loses parameter level attributes
-    parameter_by_name = dict(module.named_parameters())
-    # QuantEBC registers weights as buffers (since they are INT8), and so we need to grab it there
-    state_dict = module.state_dict()
-
-    for (
-        config,
-        embedding_names,
-    ) in zip(module.embedding_configs(), module.embedding_names_by_table()):
-        table_name = config.name
-        assert table_name in table_name_to_parameter_sharding
-
-        parameter_sharding = table_name_to_parameter_sharding[table_name]
-        if parameter_sharding.compute_kernel not in [
-            kernel.value for kernel in EmbeddingComputeKernel
-        ]:
-            raise ValueError(
-                f"Compute kernel not supported {parameter_sharding.compute_kernel}"
-            )
-
-        param_name = "embeddings." + config.name + ".weight"
-        assert param_name in parameter_by_name or param_name in state_dict
-        param = parameter_by_name.get(param_name, state_dict[param_name])
-
-        if parameter_sharding.sharding_type not in sharding_type_to_sharding_infos:
-            sharding_type_to_sharding_infos[parameter_sharding.sharding_type] = []
-
-        optimizer_params = getattr(param, "_optimizer_kwargs", [{}])
-        optimizer_classes = getattr(param, "_optimizer_classes", [None])
-
-        assert (
-            len(optimizer_classes) == 1 and len(optimizer_params) == 1
-        ), f"Only support 1 optimizer, given {len(optimizer_classes)}"
-
-        optimizer_class = optimizer_classes[0]
-        optimizer_params = optimizer_params[0]
-        if optimizer_class:
-            optimizer_params["optimizer"] = optimizer_type_to_emb_opt_type(
-                optimizer_class
-            )
-
-        per_table_fused_params = merge_fused_params(fused_params, optimizer_params)
-        per_table_fused_params = add_params_from_parameter_sharding(
-            per_table_fused_params, parameter_sharding
-        )
-        per_table_fused_params = convert_to_fbgemm_types(per_table_fused_params)
-
-        sharding_type_to_sharding_infos[parameter_sharding.sharding_type].append(
-            (
-                EmbeddingShardingInfo(
-                    embedding_config=EmbeddingTableConfig(
-                        num_embeddings=config.num_embeddings,
-                        embedding_dim=config.embedding_dim,
-                        name=config.name,
-                        data_type=config.data_type,
-                        feature_names=copy.deepcopy(config.feature_names),
-                        pooling=PoolingType.NONE,
-                        is_weighted=False,
-                        has_feature_processor=False,
-                        embedding_names=embedding_names,
-                        weight_init_max=config.weight_init_max,
-                        weight_init_min=config.weight_init_min,
-                    ),
-                    param_sharding=parameter_sharding,
-                    param=param,
-                    fused_params=per_table_fused_params,
-                )
-            )
-        )
-    return sharding_type_to_sharding_infos
-
-
 def create_sharding_infos_by_sharding_device_group(
     module: EmbeddingCollectionInterface,
     table_name_to_parameter_sharding: Dict[str, ParameterSharding],
@@ -503,7 +421,7 @@ def __init__(
         self._output_dtensor: bool = env.output_dtensor
         # TODO get rid of get_ec_index_dedup global flag
         self._use_index_dedup: bool = use_index_dedup or get_ec_index_dedup()
-        sharding_type_to_sharding_infos = create_sharding_infos_by_sharding(
+        sharding_type_to_sharding_infos = self.create_grouped_sharding_infos(
             module,
             table_name_to_parameter_sharding,
             fused_params,
@@ -597,6 +515,92 @@ def __init__(
         if module.device != torch.device("meta"):
             self.load_state_dict(module.state_dict())

+    @classmethod
+    def create_grouped_sharding_infos(
+        cls,
+        module: EmbeddingCollectionInterface,
+        table_name_to_parameter_sharding: Dict[str, ParameterSharding],
+        fused_params: Optional[Dict[str, Any]],
+    ) -> Dict[str, List[EmbeddingShardingInfo]]:
+        """
+        convert ParameterSharding (table_name_to_parameter_sharding: Dict[str, ParameterSharding]) to
+        EmbeddingShardingInfo that are grouped by sharding_type, and propagate the configs/parameters
+        """
+        if fused_params is None:
+            fused_params = {}
+
+        sharding_type_to_sharding_infos: Dict[str, List[EmbeddingShardingInfo]] = {}
+        # state_dict returns parameter.Tensor, which loses parameter level attributes
+        parameter_by_name = dict(module.named_parameters())
+        # QuantEBC registers weights as buffers (since they are INT8), and so we need to grab it there
+        state_dict = module.state_dict()
+
+        for (
+            config,
+            embedding_names,
+        ) in zip(module.embedding_configs(), module.embedding_names_by_table()):
+            table_name = config.name
+            assert table_name in table_name_to_parameter_sharding
+
+            parameter_sharding = table_name_to_parameter_sharding[table_name]
+            if parameter_sharding.compute_kernel not in [
+                kernel.value for kernel in EmbeddingComputeKernel
+            ]:
+                raise ValueError(
+                    f"Compute kernel not supported {parameter_sharding.compute_kernel}"
+                )
+
+            param_name = "embeddings." + config.name + ".weight"
+            assert param_name in parameter_by_name or param_name in state_dict
+            param = parameter_by_name.get(param_name, state_dict[param_name])
+
+            if parameter_sharding.sharding_type not in sharding_type_to_sharding_infos:
+                sharding_type_to_sharding_infos[parameter_sharding.sharding_type] = []
+
+            optimizer_params = getattr(param, "_optimizer_kwargs", [{}])
+            optimizer_classes = getattr(param, "_optimizer_classes", [None])
+
+            assert (
+                len(optimizer_classes) == 1 and len(optimizer_params) == 1
+            ), f"Only support 1 optimizer, given {len(optimizer_classes)}"
+
+            optimizer_class = optimizer_classes[0]
+            optimizer_params = optimizer_params[0]
+            if optimizer_class:
+                optimizer_params["optimizer"] = optimizer_type_to_emb_opt_type(
+                    optimizer_class
+                )
+
+            per_table_fused_params = merge_fused_params(fused_params, optimizer_params)
+            per_table_fused_params = add_params_from_parameter_sharding(
+                per_table_fused_params, parameter_sharding
+            )
+            per_table_fused_params = convert_to_fbgemm_types(per_table_fused_params)
+
+            sharding_type_to_sharding_infos[parameter_sharding.sharding_type].append(
+                (
+                    EmbeddingShardingInfo(
+                        embedding_config=EmbeddingTableConfig(
+                            num_embeddings=config.num_embeddings,
+                            embedding_dim=config.embedding_dim,
+                            name=config.name,
+                            data_type=config.data_type,
+                            feature_names=copy.deepcopy(config.feature_names),
+                            pooling=PoolingType.NONE,
+                            is_weighted=False,
+                            has_feature_processor=False,
+                            embedding_names=embedding_names,
+                            weight_init_max=config.weight_init_max,
+                            weight_init_min=config.weight_init_min,
+                        ),
+                        param_sharding=parameter_sharding,
+                        param=param,
+                        fused_params=per_table_fused_params,
+                    )
+                )
+            )
+        return sharding_type_to_sharding_infos
+
     @classmethod
     def create_embedding_sharding(
         cls,

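A minimal sketch (not from this PR) of inspecting the grouping the new classmethod returns. `ec_module` (an EmbeddingCollection) and `table_name_to_parameter_sharding` (taken from a sharding plan) are assumed to already exist; the names are illustrative:

```python
# Hedged sketch: call the now-public classmethod and inspect its grouping.
# Assumes ShardedEmbeddingCollection is the enclosing class and that ec_module
# and table_name_to_parameter_sharding are already constructed elsewhere.
from torchrec.distributed.embedding import ShardedEmbeddingCollection

grouped = ShardedEmbeddingCollection.create_grouped_sharding_infos(
    module=ec_module,
    table_name_to_parameter_sharding=table_name_to_parameter_sharding,
    fused_params=None,
)
for sharding_type, infos in grouped.items():
    # Each key is a sharding type string; each value lists the tables grouped under it.
    print(sharding_type, [info.embedding_config.name for info in infos])
```

As described in the summary, `__init__` then hands each group to `create_embedding_sharding` to build one `EmbeddingSharding` per sharding type.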