Skip to content

Commit 7f239d8

Browse files
No public description
PiperOrigin-RevId: 663810871
1 parent bdd3a9b commit 7f239d8

File tree

3 files changed

+14
-5
lines changed

3 files changed

+14
-5
lines changed

official/recommendation/ranking/configs/config.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,8 @@ class ModelConfig(hyperparams.Config):
138138
max_ids_per_chip_per_sample: int | None = None
139139
max_ids_per_table: Union[int, List[int]] | None = None
140140
max_unique_ids_per_table: Union[int, List[int]] | None = None
141+
allow_id_dropping: bool = False
142+
initialize_tables_on_host: bool = False
141143

142144

143145
@dataclasses.dataclass

official/recommendation/ranking/data/data_pipeline_multi_hot.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -45,15 +45,13 @@ def __init__(self,
4545
num_dense_features: int,
4646
vocab_sizes: List[int],
4747
multi_hot_sizes: List[int],
48-
use_synthetic_data: bool = False,
49-
use_cached_data: bool = False):
48+
use_synthetic_data: bool = False):
5049
self._file_pattern = file_pattern
5150
self._params = params
5251
self._num_dense_features = num_dense_features
5352
self._vocab_sizes = vocab_sizes
5453
self._use_synthetic_data = use_synthetic_data
5554
self._multi_hot_sizes = multi_hot_sizes
56-
self._use_cached_data = use_cached_data
5755

5856
def __call__(self, ctx: tf.distribute.InputContext) -> tf.data.Dataset:
5957
params = self._params
@@ -146,7 +144,7 @@ def make_dataset(shard_index):
146144
num_parallel_calls=tf.data.experimental.AUTOTUNE)
147145

148146
dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
149-
if self._use_cached_data:
147+
if self._params.use_cached_data:
150148
dataset = dataset.take(1).cache().repeat()
151149

152150
return dataset

official/recommendation/ranking/task.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,8 @@ def _get_tpu_embedding_feature_config(
3939
max_ids_per_chip_per_sample: Optional[int] = None,
4040
max_ids_per_table: Optional[Union[int, List[int]]] = None,
4141
max_unique_ids_per_table: Optional[Union[int, List[int]]] = None,
42+
allow_id_dropping: bool = False,
43+
initialize_tables_on_host: bool = False,
4244
) -> Tuple[
4345
Dict[str, tf.tpu.experimental.embedding.FeatureConfig],
4446
Optional[tf.tpu.experimental.embedding.SparseCoreEmbeddingConfig],
@@ -57,6 +59,10 @@ def _get_tpu_embedding_feature_config(
5759
sample.
5860
max_ids_per_table: Maximum number of embedding ids per table.
5961
max_unique_ids_per_table: Maximum number of unique embedding ids per table.
62+
allow_id_dropping: Whether to allow dropping of embedding ids that exceed the configured per-table maximums.
63+
initialize_tables_on_host: If the embedding table size is more than
64+
what HBM can handle, this flag will help initialize the full embedding
65+
tables on host and then copy shards to HBM.
6066
6167
Returns:
6268
A dictionary of feature_name, FeatureConfig pairs.
@@ -140,7 +146,8 @@ def _get_tpu_embedding_feature_config(
140146
max_ids_per_chip_per_sample=max_ids_per_chip_per_sample,
141147
max_ids_per_table=max_ids_per_table_dict,
142148
max_unique_ids_per_table=max_unique_ids_per_table_dict,
143-
allow_id_dropping=False,
149+
allow_id_dropping=allow_id_dropping,
150+
initialize_tables_on_host=initialize_tables_on_host,
144151
)
145152

146153
return feature_config, sparsecore_config
@@ -248,6 +255,8 @@ def build_model(self) -> tf_keras.Model:
248255
max_ids_per_chip_per_sample=self.task_config.model.max_ids_per_chip_per_sample,
249256
max_ids_per_table=self.task_config.model.max_ids_per_table,
250257
max_unique_ids_per_table=self.task_config.model.max_unique_ids_per_table,
258+
allow_id_dropping=self.task_config.model.allow_id_dropping,
259+
initialize_tables_on_host=self.task_config.model.initialize_tables_on_host,
251260
)
252261
)
253262

0 commit comments

Comments
 (0)