@@ -147,88 +147,6 @@ def get_ec_index_dedup() -> bool:
     return EC_INDEX_DEDUP


-def create_sharding_infos_by_sharding(
-    module: EmbeddingCollectionInterface,
-    table_name_to_parameter_sharding: Dict[str, ParameterSharding],
-    fused_params: Optional[Dict[str, Any]],
-) -> Dict[str, List[EmbeddingShardingInfo]]:
-
-    if fused_params is None:
-        fused_params = {}
-
-    sharding_type_to_sharding_infos: Dict[str, List[EmbeddingShardingInfo]] = {}
-    # state_dict returns parameter.Tensor, which loses parameter level attributes
-    parameter_by_name = dict(module.named_parameters())
-    # QuantEBC registers weights as buffers (since they are INT8), and so we need to grab it there
-    state_dict = module.state_dict()
-
-    for (
-        config,
-        embedding_names,
-    ) in zip(module.embedding_configs(), module.embedding_names_by_table()):
-        table_name = config.name
-        assert table_name in table_name_to_parameter_sharding
-
-        parameter_sharding = table_name_to_parameter_sharding[table_name]
-        if parameter_sharding.compute_kernel not in [
-            kernel.value for kernel in EmbeddingComputeKernel
-        ]:
-            raise ValueError(
-                f"Compute kernel not supported {parameter_sharding.compute_kernel}"
-            )
-
-        param_name = "embeddings." + config.name + ".weight"
-        assert param_name in parameter_by_name or param_name in state_dict
-        param = parameter_by_name.get(param_name, state_dict[param_name])
-
-        if parameter_sharding.sharding_type not in sharding_type_to_sharding_infos:
-            sharding_type_to_sharding_infos[parameter_sharding.sharding_type] = []
-
-        optimizer_params = getattr(param, "_optimizer_kwargs", [{}])
-        optimizer_classes = getattr(param, "_optimizer_classes", [None])
-
-        assert (
-            len(optimizer_classes) == 1 and len(optimizer_params) == 1
-        ), f"Only support 1 optimizer, given {len(optimizer_classes)}"
-
-        optimizer_class = optimizer_classes[0]
-        optimizer_params = optimizer_params[0]
-        if optimizer_class:
-            optimizer_params["optimizer"] = optimizer_type_to_emb_opt_type(
-                optimizer_class
-            )
-
-        per_table_fused_params = merge_fused_params(fused_params, optimizer_params)
-        per_table_fused_params = add_params_from_parameter_sharding(
-            per_table_fused_params, parameter_sharding
-        )
-        per_table_fused_params = convert_to_fbgemm_types(per_table_fused_params)
-
-        sharding_type_to_sharding_infos[parameter_sharding.sharding_type].append(
-            (
-                EmbeddingShardingInfo(
-                    embedding_config=EmbeddingTableConfig(
-                        num_embeddings=config.num_embeddings,
-                        embedding_dim=config.embedding_dim,
-                        name=config.name,
-                        data_type=config.data_type,
-                        feature_names=copy.deepcopy(config.feature_names),
-                        pooling=PoolingType.NONE,
-                        is_weighted=False,
-                        has_feature_processor=False,
-                        embedding_names=embedding_names,
-                        weight_init_max=config.weight_init_max,
-                        weight_init_min=config.weight_init_min,
-                    ),
-                    param_sharding=parameter_sharding,
-                    param=param,
-                    fused_params=per_table_fused_params,
-                )
-            )
-        )
-    return sharding_type_to_sharding_infos
-
-
 def create_sharding_infos_by_sharding_device_group(
     module: EmbeddingCollectionInterface,
     table_name_to_parameter_sharding: Dict[str, ParameterSharding],
@@ -503,7 +421,7 @@ def __init__(
         self._output_dtensor: bool = env.output_dtensor
         # TODO get rid of get_ec_index_dedup global flag
         self._use_index_dedup: bool = use_index_dedup or get_ec_index_dedup()
-        sharding_type_to_sharding_infos = create_sharding_infos_by_sharding(
+        sharding_type_to_sharding_infos = self.create_grouped_sharding_infos(
             module,
             table_name_to_parameter_sharding,
             fused_params,
@@ -597,6 +515,92 @@ def __init__(
         if module.device != torch.device("meta"):
             self.load_state_dict(module.state_dict())

+    @classmethod
+    def create_grouped_sharding_infos(
+        cls,
+        module: EmbeddingCollectionInterface,
+        table_name_to_parameter_sharding: Dict[str, ParameterSharding],
+        fused_params: Optional[Dict[str, Any]],
+    ) -> Dict[str, List[EmbeddingShardingInfo]]:
+        """
+        Convert the ParameterSharding mapping (table_name_to_parameter_sharding:
+        Dict[str, ParameterSharding]) into EmbeddingShardingInfo objects grouped
+        by sharding_type, propagating the table configs and parameters.
+        """
+        if fused_params is None:
+            fused_params = {}
+
+        sharding_type_to_sharding_infos: Dict[str, List[EmbeddingShardingInfo]] = {}
+        # state_dict returns parameter.Tensor, which loses parameter level attributes
+        parameter_by_name = dict(module.named_parameters())
+        # QuantEBC registers weights as buffers (since they are INT8), and so we need to grab it there
+        state_dict = module.state_dict()
+
+        for (
+            config,
+            embedding_names,
+        ) in zip(module.embedding_configs(), module.embedding_names_by_table()):
+            table_name = config.name
+            assert table_name in table_name_to_parameter_sharding
+
+            parameter_sharding = table_name_to_parameter_sharding[table_name]
+            if parameter_sharding.compute_kernel not in [
+                kernel.value for kernel in EmbeddingComputeKernel
+            ]:
+                raise ValueError(
+                    f"Compute kernel not supported {parameter_sharding.compute_kernel}"
+                )
+
+            param_name = "embeddings." + config.name + ".weight"
+            assert param_name in parameter_by_name or param_name in state_dict
+            param = parameter_by_name.get(param_name, state_dict[param_name])
+
+            if parameter_sharding.sharding_type not in sharding_type_to_sharding_infos:
+                sharding_type_to_sharding_infos[parameter_sharding.sharding_type] = []
+
+            optimizer_params = getattr(param, "_optimizer_kwargs", [{}])
+            optimizer_classes = getattr(param, "_optimizer_classes", [None])
+
+            assert (
+                len(optimizer_classes) == 1 and len(optimizer_params) == 1
+            ), f"Only support 1 optimizer, given {len(optimizer_classes)}"
+
+            optimizer_class = optimizer_classes[0]
+            optimizer_params = optimizer_params[0]
+            if optimizer_class:
+                optimizer_params["optimizer"] = optimizer_type_to_emb_opt_type(
+                    optimizer_class
+                )
+
+            per_table_fused_params = merge_fused_params(fused_params, optimizer_params)
+            per_table_fused_params = add_params_from_parameter_sharding(
+                per_table_fused_params, parameter_sharding
+            )
+            per_table_fused_params = convert_to_fbgemm_types(per_table_fused_params)
+
+            sharding_type_to_sharding_infos[parameter_sharding.sharding_type].append(
+                (
+                    EmbeddingShardingInfo(
+                        embedding_config=EmbeddingTableConfig(
+                            num_embeddings=config.num_embeddings,
+                            embedding_dim=config.embedding_dim,
+                            name=config.name,
+                            data_type=config.data_type,
+                            feature_names=copy.deepcopy(config.feature_names),
+                            pooling=PoolingType.NONE,
+                            is_weighted=False,
+                            has_feature_processor=False,
+                            embedding_names=embedding_names,
+                            weight_init_max=config.weight_init_max,
+                            weight_init_min=config.weight_init_min,
+                        ),
+                        param_sharding=parameter_sharding,
+                        param=param,
+                        fused_params=per_table_fused_params,
+                    )
+                )
+            )
+        return sharding_type_to_sharding_infos
+
     @classmethod
     def create_embedding_sharding(
         cls,
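For context, a minimal sketch of how the relocated helper could be exercised directly. This is not part of the diff: the table name, feature name, and sharding values below are made up, and in normal use the classmethod is invoked internally from ShardedEmbeddingCollection.__init__ as shown in the second hunk.

# Hypothetical usage sketch: illustrates only the inputs/outputs of the new classmethod.
import torch

from torchrec.distributed.embedding import ShardedEmbeddingCollection
from torchrec.distributed.types import ParameterSharding
from torchrec.modules.embedding_configs import EmbeddingConfig
from torchrec.modules.embedding_modules import EmbeddingCollection

# A single-table EmbeddingCollection on the meta device (names are illustrative).
ec = EmbeddingCollection(
    tables=[
        EmbeddingConfig(
            name="t1",
            num_embeddings=1000,
            embedding_dim=16,
            feature_names=["f1"],
        )
    ],
    device=torch.device("meta"),
)

# Assumed per-table sharding decision; "table_wise" and "fused" are standard
# ShardingType / EmbeddingComputeKernel values.
plan = {
    "t1": ParameterSharding(
        sharding_type="table_wise",
        compute_kernel="fused",
        ranks=[0],
    )
}

# Returns Dict[sharding_type, List[EmbeddingShardingInfo]], e.g. {"table_wise": [...]}.
infos = ShardedEmbeddingCollection.create_grouped_sharding_infos(
    module=ec,
    table_name_to_parameter_sharding=plan,
    fused_params=None,
)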