@@ -764,37 +764,32 @@ def __init__(
764
764
self ._weight_init_mins : List [float ] = []
765
765
self ._weight_init_maxs : List [float ] = []
766
766
self ._num_embeddings : List [int ] = []
767
- self ._embedding_dims : List [int ] = []
768
767
self ._local_cols : List [int ] = []
769
- self ._row_offset : List [int ] = []
770
- self ._col_offset : List [int ] = []
771
768
self ._feature_table_map : List [int ] = []
772
769
self .table_name_to_count : Dict [str , int ] = {}
773
770
self ._param_per_table : Dict [str , TableBatchedEmbeddingSlice ] = {}
774
771
775
- for idx , table_config in enumerate (self ._config .embedding_tables ):
776
- self ._local_rows .append (table_config .local_rows )
777
- self ._weight_init_mins .append (table_config .get_weight_init_min ())
778
- self ._weight_init_maxs .append (table_config .get_weight_init_max ())
779
- self ._num_embeddings .append (table_config .num_embeddings )
780
- self ._embedding_dims .append (table_config .embedding_dim )
781
- self ._row_offset .append (
782
- table_config .local_metadata .shard_offsets [0 ]
783
- if table_config .local_metadata
784
- and len (table_config .local_metadata .shard_offsets ) > 0
785
- else 0
786
- )
787
- self ._col_offset .append (
788
- table_config .local_metadata .shard_offsets [1 ]
789
- if table_config .local_metadata
790
- and len (table_config .local_metadata .shard_offsets ) > 1
791
- else 0
792
- )
793
- self ._local_cols .append (table_config .local_cols )
794
- self ._feature_table_map .extend ([idx ] * table_config .num_features ())
795
- if table_config .name not in self .table_name_to_count :
796
- self .table_name_to_count [table_config .name ] = 0
797
- self .table_name_to_count [table_config .name ] += 1
772
+ # pyre-fixme[9]: config has type `GroupedEmbeddingConfig`; used as
773
+ # `ShardedEmbeddingTable`.
774
+ for idx , config in enumerate (self ._config .embedding_tables ):
775
+ # pyre-fixme[16]: `GroupedEmbeddingConfig` has no attribute `local_rows`.
776
+ self ._local_rows .append (config .local_rows )
777
+ # pyre-fixme[16]: `GroupedEmbeddingConfig` has no attribute
778
+ # `get_weight_init_min`.
779
+ self ._weight_init_mins .append (config .get_weight_init_min ())
780
+ # pyre-fixme[16]: `GroupedEmbeddingConfig` has no attribute
781
+ # `get_weight_init_max`.
782
+ self ._weight_init_maxs .append (config .get_weight_init_max ())
783
+ # pyre-fixme[16]: `GroupedEmbeddingConfig` has no attribute
784
+ # `num_embeddings`.
785
+ self ._num_embeddings .append (config .num_embeddings )
786
+ # pyre-fixme[16]: `GroupedEmbeddingConfig` has no attribute `local_cols`.
787
+ self ._local_cols .append (config .local_cols )
788
+ self ._feature_table_map .extend ([idx ] * config .num_features ())
789
+ # pyre-fixme[16]: `GroupedEmbeddingConfig` has no attribute `name`.
790
+ if config .name not in self .table_name_to_count :
791
+ self .table_name_to_count [config .name ] = 0
792
+ self .table_name_to_count [config .name ] += 1
798
793
799
794
def init_parameters (self ) -> None :
800
795
# initialize embedding weights
@@ -1085,14 +1080,6 @@ def __init__(
1085
1080
weights_precision = weights_precision ,
1086
1081
device = device ,
1087
1082
table_names = [t .name for t in config .embedding_tables ],
1088
- embedding_shard_info = list (
1089
- zip (
1090
- self ._num_embeddings ,
1091
- self ._embedding_dims ,
1092
- self ._row_offset ,
1093
- self ._col_offset ,
1094
- )
1095
- ),
1096
1083
** fused_params ,
1097
1084
)
1098
1085
)
@@ -1229,39 +1216,34 @@ def __init__(
1229
1216
self ._weight_init_mins : List [float ] = []
1230
1217
self ._weight_init_maxs : List [float ] = []
1231
1218
self ._num_embeddings : List [int ] = []
1232
- self ._embedding_dims : List [int ] = []
1233
1219
self ._local_cols : List [int ] = []
1234
- self ._row_offset : List [int ] = []
1235
- self ._col_offset : List [int ] = []
1236
1220
self ._feature_table_map : List [int ] = []
1237
1221
self ._emb_names : List [str ] = []
1238
1222
self ._lengths_per_emb : List [int ] = []
1239
1223
self .table_name_to_count : Dict [str , int ] = {}
1240
1224
self ._param_per_table : Dict [str , TableBatchedEmbeddingSlice ] = {}
1241
1225
1242
- for idx , table_config in enumerate (self ._config .embedding_tables ):
1243
- self ._local_rows .append (table_config .local_rows )
1244
- self ._weight_init_mins .append (table_config .get_weight_init_min ())
1245
- self ._weight_init_maxs .append (table_config .get_weight_init_max ())
1246
- self ._num_embeddings .append (table_config .num_embeddings )
1247
- self ._embedding_dims .append (table_config .embedding_dim )
1248
- self ._row_offset .append (
1249
- table_config .local_metadata .shard_offsets [0 ]
1250
- if table_config .local_metadata
1251
- and len (table_config .local_metadata .shard_offsets ) > 0
1252
- else 0
1253
- )
1254
- self ._col_offset .append (
1255
- table_config .local_metadata .shard_offsets [1 ]
1256
- if table_config .local_metadata
1257
- and len (table_config .local_metadata .shard_offsets ) > 1
1258
- else 0
1259
- )
1260
- self ._local_cols .append (table_config .local_cols )
1261
- self ._feature_table_map .extend ([idx ] * table_config .num_features ())
1262
- if table_config .name not in self .table_name_to_count :
1263
- self .table_name_to_count [table_config .name ] = 0
1264
- self .table_name_to_count [table_config .name ] += 1
1226
+ # pyre-fixme[9]: config has type `GroupedEmbeddingConfig`; used as
1227
+ # `ShardedEmbeddingTable`.
1228
+ for idx , config in enumerate (self ._config .embedding_tables ):
1229
+ # pyre-fixme[16]: `GroupedEmbeddingConfig` has no attribute `local_rows`.
1230
+ self ._local_rows .append (config .local_rows )
1231
+ # pyre-fixme[16]: `GroupedEmbeddingConfig` has no attribute
1232
+ # `get_weight_init_min`.
1233
+ self ._weight_init_mins .append (config .get_weight_init_min ())
1234
+ # pyre-fixme[16]: `GroupedEmbeddingConfig` has no attribute
1235
+ # `get_weight_init_max`.
1236
+ self ._weight_init_maxs .append (config .get_weight_init_max ())
1237
+ # pyre-fixme[16]: `GroupedEmbeddingConfig` has no attribute
1238
+ # `num_embeddings`.
1239
+ self ._num_embeddings .append (config .num_embeddings )
1240
+ # pyre-fixme[16]: `GroupedEmbeddingConfig` has no attribute `local_cols`.
1241
+ self ._local_cols .append (config .local_cols )
1242
+ self ._feature_table_map .extend ([idx ] * config .num_features ())
1243
+ # pyre-fixme[16]: `GroupedEmbeddingConfig` has no attribute `name`.
1244
+ if config .name not in self .table_name_to_count :
1245
+ self .table_name_to_count [config .name ] = 0
1246
+ self .table_name_to_count [config .name ] += 1
1265
1247
1266
1248
def init_parameters (self ) -> None :
1267
1249
# initialize embedding weights
@@ -1582,14 +1564,6 @@ def __init__(
1582
1564
weights_precision = weights_precision ,
1583
1565
device = device ,
1584
1566
table_names = [t .name for t in config .embedding_tables ],
1585
- embedding_shard_info = list (
1586
- zip (
1587
- self ._num_embeddings ,
1588
- self ._embedding_dims ,
1589
- self ._row_offset ,
1590
- self ._col_offset ,
1591
- )
1592
- ),
1593
1567
** fused_params ,
1594
1568
)
1595
1569
)
0 commit comments