
Commit 89d6ae0

Revert "Propagate proper embedding_shard_info when constructing TBE (#2876)"
This reverts commit 75f1f1c.
1 parent d93b0c7 commit 89d6ae0
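
For context, the change being reverted collected per-table shard info, (num_embeddings, embedding_dim, row_offset, col_offset), from each table config's local_metadata.shard_offsets and passed it to the TBE constructor as embedding_shard_info. Below is a minimal, standalone sketch of that bookkeeping pattern; TableShard and LocalMetadata are hypothetical stand-ins for the torchrec sharded-table config, not the real types.

# Sketch only: mimics the shard-offset bookkeeping removed by this revert.
# `TableShard` and `LocalMetadata` are made-up stand-ins, not torchrec classes.
from dataclasses import dataclass, field
from typing import List, Optional, Tuple


@dataclass
class LocalMetadata:
    # [row_offset, col_offset] of this shard within the full table
    shard_offsets: List[int] = field(default_factory=list)


@dataclass
class TableShard:
    num_embeddings: int
    embedding_dim: int
    local_metadata: Optional[LocalMetadata] = None


def shard_info(tables: List[TableShard]) -> List[Tuple[int, int, int, int]]:
    """Build (num_embeddings, embedding_dim, row_offset, col_offset) per table,
    defaulting offsets to 0 when local metadata is missing or incomplete."""
    num_embeddings: List[int] = []
    embedding_dims: List[int] = []
    row_offsets: List[int] = []
    col_offsets: List[int] = []
    for t in tables:
        num_embeddings.append(t.num_embeddings)
        embedding_dims.append(t.embedding_dim)
        row_offsets.append(
            t.local_metadata.shard_offsets[0]
            if t.local_metadata and len(t.local_metadata.shard_offsets) > 0
            else 0
        )
        col_offsets.append(
            t.local_metadata.shard_offsets[1]
            if t.local_metadata and len(t.local_metadata.shard_offsets) > 1
            else 0
        )
    # Zipped the same way the reverted diff built `embedding_shard_info`.
    return list(zip(num_embeddings, embedding_dims, row_offsets, col_offsets))


if __name__ == "__main__":
    tables = [
        TableShard(1000, 64, LocalMetadata([500, 0])),  # row-wise shard
        TableShard(2000, 128),                          # no metadata -> offsets (0, 0)
    ]
    print(shard_info(tables))  # [(1000, 64, 500, 0), (2000, 128, 0, 0)]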

File tree: 1 file changed, +42 −68 lines changed

torchrec/distributed/batched_embedding_kernel.py

Lines changed: 42 additions & 68 deletions
@@ -764,37 +764,32 @@ def __init__(
         self._weight_init_mins: List[float] = []
         self._weight_init_maxs: List[float] = []
         self._num_embeddings: List[int] = []
-        self._embedding_dims: List[int] = []
         self._local_cols: List[int] = []
-        self._row_offset: List[int] = []
-        self._col_offset: List[int] = []
         self._feature_table_map: List[int] = []
         self.table_name_to_count: Dict[str, int] = {}
         self._param_per_table: Dict[str, TableBatchedEmbeddingSlice] = {}
 
-        for idx, table_config in enumerate(self._config.embedding_tables):
-            self._local_rows.append(table_config.local_rows)
-            self._weight_init_mins.append(table_config.get_weight_init_min())
-            self._weight_init_maxs.append(table_config.get_weight_init_max())
-            self._num_embeddings.append(table_config.num_embeddings)
-            self._embedding_dims.append(table_config.embedding_dim)
-            self._row_offset.append(
-                table_config.local_metadata.shard_offsets[0]
-                if table_config.local_metadata
-                and len(table_config.local_metadata.shard_offsets) > 0
-                else 0
-            )
-            self._col_offset.append(
-                table_config.local_metadata.shard_offsets[1]
-                if table_config.local_metadata
-                and len(table_config.local_metadata.shard_offsets) > 1
-                else 0
-            )
-            self._local_cols.append(table_config.local_cols)
-            self._feature_table_map.extend([idx] * table_config.num_features())
-            if table_config.name not in self.table_name_to_count:
-                self.table_name_to_count[table_config.name] = 0
-            self.table_name_to_count[table_config.name] += 1
+        # pyre-fixme[9]: config has type `GroupedEmbeddingConfig`; used as
+        #  `ShardedEmbeddingTable`.
+        for idx, config in enumerate(self._config.embedding_tables):
+            # pyre-fixme[16]: `GroupedEmbeddingConfig` has no attribute `local_rows`.
+            self._local_rows.append(config.local_rows)
+            # pyre-fixme[16]: `GroupedEmbeddingConfig` has no attribute
+            #  `get_weight_init_min`.
+            self._weight_init_mins.append(config.get_weight_init_min())
+            # pyre-fixme[16]: `GroupedEmbeddingConfig` has no attribute
+            #  `get_weight_init_max`.
+            self._weight_init_maxs.append(config.get_weight_init_max())
+            # pyre-fixme[16]: `GroupedEmbeddingConfig` has no attribute
+            #  `num_embeddings`.
+            self._num_embeddings.append(config.num_embeddings)
+            # pyre-fixme[16]: `GroupedEmbeddingConfig` has no attribute `local_cols`.
+            self._local_cols.append(config.local_cols)
+            self._feature_table_map.extend([idx] * config.num_features())
+            # pyre-fixme[16]: `GroupedEmbeddingConfig` has no attribute `name`.
+            if config.name not in self.table_name_to_count:
+                self.table_name_to_count[config.name] = 0
+            self.table_name_to_count[config.name] += 1
 
     def init_parameters(self) -> None:
         # initialize embedding weights
@@ -1085,14 +1080,6 @@ def __init__(
                 weights_precision=weights_precision,
                 device=device,
                 table_names=[t.name for t in config.embedding_tables],
-                embedding_shard_info=list(
-                    zip(
-                        self._num_embeddings,
-                        self._embedding_dims,
-                        self._row_offset,
-                        self._col_offset,
-                    )
-                ),
                 **fused_params,
             )
         )
@@ -1229,39 +1216,34 @@ def __init__(
         self._weight_init_mins: List[float] = []
         self._weight_init_maxs: List[float] = []
         self._num_embeddings: List[int] = []
-        self._embedding_dims: List[int] = []
         self._local_cols: List[int] = []
-        self._row_offset: List[int] = []
-        self._col_offset: List[int] = []
         self._feature_table_map: List[int] = []
         self._emb_names: List[str] = []
         self._lengths_per_emb: List[int] = []
         self.table_name_to_count: Dict[str, int] = {}
         self._param_per_table: Dict[str, TableBatchedEmbeddingSlice] = {}
 
-        for idx, table_config in enumerate(self._config.embedding_tables):
-            self._local_rows.append(table_config.local_rows)
-            self._weight_init_mins.append(table_config.get_weight_init_min())
-            self._weight_init_maxs.append(table_config.get_weight_init_max())
-            self._num_embeddings.append(table_config.num_embeddings)
-            self._embedding_dims.append(table_config.embedding_dim)
-            self._row_offset.append(
-                table_config.local_metadata.shard_offsets[0]
-                if table_config.local_metadata
-                and len(table_config.local_metadata.shard_offsets) > 0
-                else 0
-            )
-            self._col_offset.append(
-                table_config.local_metadata.shard_offsets[1]
-                if table_config.local_metadata
-                and len(table_config.local_metadata.shard_offsets) > 1
-                else 0
-            )
-            self._local_cols.append(table_config.local_cols)
-            self._feature_table_map.extend([idx] * table_config.num_features())
-            if table_config.name not in self.table_name_to_count:
-                self.table_name_to_count[table_config.name] = 0
-            self.table_name_to_count[table_config.name] += 1
+        # pyre-fixme[9]: config has type `GroupedEmbeddingConfig`; used as
+        #  `ShardedEmbeddingTable`.
+        for idx, config in enumerate(self._config.embedding_tables):
+            # pyre-fixme[16]: `GroupedEmbeddingConfig` has no attribute `local_rows`.
+            self._local_rows.append(config.local_rows)
+            # pyre-fixme[16]: `GroupedEmbeddingConfig` has no attribute
+            #  `get_weight_init_min`.
+            self._weight_init_mins.append(config.get_weight_init_min())
+            # pyre-fixme[16]: `GroupedEmbeddingConfig` has no attribute
+            #  `get_weight_init_max`.
+            self._weight_init_maxs.append(config.get_weight_init_max())
+            # pyre-fixme[16]: `GroupedEmbeddingConfig` has no attribute
+            #  `num_embeddings`.
+            self._num_embeddings.append(config.num_embeddings)
+            # pyre-fixme[16]: `GroupedEmbeddingConfig` has no attribute `local_cols`.
+            self._local_cols.append(config.local_cols)
+            self._feature_table_map.extend([idx] * config.num_features())
+            # pyre-fixme[16]: `GroupedEmbeddingConfig` has no attribute `name`.
+            if config.name not in self.table_name_to_count:
+                self.table_name_to_count[config.name] = 0
+            self.table_name_to_count[config.name] += 1
 
     def init_parameters(self) -> None:
         # initialize embedding weights
@@ -1582,14 +1564,6 @@ def __init__(
                 weights_precision=weights_precision,
                 device=device,
                 table_names=[t.name for t in config.embedding_tables],
-                embedding_shard_info=list(
-                    zip(
-                        self._num_embeddings,
-                        self._embedding_dims,
-                        self._row_offset,
-                        self._col_offset,
-                    )
-                ),
                 **fused_params,
             )
         )
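
One detail the revert leaves untouched: _feature_table_map is built by repeating each table's index num_features() times, so every feature maps back to its owning table. A tiny standalone illustration with made-up feature counts:

# Illustration only: feature-to-table index mapping as built in the loops above.
feature_counts = [2, 1, 3]  # hypothetical num_features() per table
feature_table_map: list = []
for idx, n in enumerate(feature_counts):
    feature_table_map.extend([idx] * n)
print(feature_table_map)  # -> [0, 0, 1, 2, 2, 2]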
