# Copyright 2025 The TensorFlow Recommenders Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

| 15 | +"""Unified Embedding Module. |
| 16 | +
|
| 17 | +This module implements the feature multiplexing framework introduced in |
| 18 | +"Unified Embedding: Battle-Tested Feature Representations for Web-Scale ML |
| 19 | +Systems" by Coleman et al. using a regular embeding layer as the backend |
| 20 | +embedding table format. The intended usage is to construct a |
| 21 | +UnifiedEmbeddingConfig object that describes the number and size of the features |
| 22 | +to be embedded together in a shared table, and then construct a UnifiedEmbedding |
| 23 | +layer from the config. |
| 24 | +
|
Example:
```python
buckets_per_table = 100  # Number of embedding buckets in each backend table.
num_tables = 3  # Number of backend tables to split the features across.
table_dimension = 16  # Dimension of each unified embedding chunk.

# Construct the configuration object.
embed_config = UnifiedEmbeddingConfig(
    buckets_per_table=buckets_per_table,
    dim_per_table=table_dimension,
    num_tables=num_tables,
    name="unified_table",
)

# Add some features to the config, with different embedding widths.
embed_config.add_feature("movie_genre", 2)  # 2 chunks = 32 dimensions.
embed_config.add_feature("movie_id", 3)  # 3 chunks = 48 dimensions.
embed_config.add_feature("user_zip_code", 1)  # 1 chunk = 16 dimensions.

# Construct the embedding layer, which takes a feature dict as input and
# returns a list of embeddings, one for each feature in the config. Any
# optimizer from ValidTPUOptimizer works here; Adagrad is one option.
embed_optimizer = tf.tpu.experimental.embedding.Adagrad(learning_rate=0.1)
embed_layer = UnifiedEmbedding(embed_config, embed_optimizer)
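
# A hypothetical lookup (sketch): `genre_ids`, `movie_ids`, and `zip_codes`
# stand in for raw categorical input tensors. Each feature is hashed once
# per chunk and the resulting chunk embeddings are concatenated.
embeddings = embed_layer({
    "movie_genre": genre_ids,
    "movie_id": movie_ids,
    "user_zip_code": zip_codes,
})
# `embeddings` is a list of three tensors with last dimensions 32, 48, and 16.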
```
"""

from typing import Any, Dict, Union

import tensorflow as tf
from tensorflow_recommenders.layers.embedding import tpu_embedding_layer


FeatureConfig = tf.tpu.experimental.embedding.FeatureConfig
Hashing = tf.keras.layers.Hashing
TableConfig = tf.tpu.experimental.embedding.TableConfig
TPUEmbedding = tpu_embedding_layer.TPUEmbedding
ValidTPUOptimizer = Union[
    tf.tpu.experimental.embedding.SGD,
    tf.tpu.experimental.embedding.Adagrad,
    tf.tpu.experimental.embedding.Adam,
    tf.tpu.experimental.embedding.FTRL,
]


class UnifiedEmbeddingConfig:
  """Unified Embedding Config."""

  def __init__(
      self,
      buckets_per_table: int,
      dim_per_table: int,
      num_tables: int,
      name: str,
      **kwargs,
  ):
    self._buckets_per_table = buckets_per_table
    self._dim_per_table = dim_per_table
    self._num_tables = num_tables
    self._current_table = 0
    self._num_features = 0
    self._name = name
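    # Create `num_tables` identically-shaped backend tables; add_feature()
    # assigns feature chunks to these tables in round-robin order.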
    self._table_configs = [
        TableConfig(
            vocabulary_size=self._buckets_per_table,
            dim=self._dim_per_table,
            name=f"{self._name}_{i}",
            **kwargs,
        )
        for i in range(self._num_tables)
    ]
    # Store TPU embedding configs for each feature component (sub-feature).
    self._embed_configs = {}
    self._hashing_configs = {}

  def add_feature(self, name: str, num_chunks: int, **kwargs):
    """Adds a categorical feature to the unified embedding config.

    Args:
      name: Feature name, used to feed inputs from a feature dict and to track
        the sub-components of the embedding.
      num_chunks: Integer number of chunks to use for the embedding. The final
        dimension will be num_chunks * dim_per_table.
      **kwargs: Arguments to pass through to the underlying FeatureConfig.
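
    For example, with dim_per_table=16, add_feature("movie_id", 3) yields a
    48-dimensional embedding whose three chunks are looked up in three
    consecutive backend tables (wrapping around after num_tables).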
    """
    chunk_embed_configs = {}
    chunk_hashing_configs = {}
    for chunk_id in range(num_chunks):
      chunk_name = f"{self._name}_{name}_lookup_{chunk_id}"
      chunk_embed_config = FeatureConfig(
          table=self._table_configs[self._current_table],
          name=chunk_name,
          **kwargs,
      )
      chunk_embed_configs[chunk_name] = chunk_embed_config
      chunk_hashing_configs[chunk_name] = {
          "num_bins": self._buckets_per_table,
          "salt": [self._num_features, chunk_id],
      }
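      # Round-robin: the next chunk (or the next feature) continues on the
      # next backend table, spreading load evenly across tables.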
      self._current_table += 1
      self._current_table %= self._num_tables
    self._num_features += 1
    self._embed_configs[name] = chunk_embed_configs
    self._hashing_configs[name] = chunk_hashing_configs

  @property
  def embedding_config(self):
    return self._embed_configs

  @property
  def hashing_config(self):
    return self._hashing_configs

  @hashing_config.setter
  def hashing_config(self, value):
    self._hashing_configs = value


@tf.keras.utils.register_keras_serializable()
class UnifiedEmbedding(tf.keras.layers.Layer):
  """Post-processing layer to concatenate unified embedding components."""

  def __init__(
      self,
      config: UnifiedEmbeddingConfig,
      optimizer: ValidTPUOptimizer,
      **kwargs,
  ):
    super().__init__(**kwargs)
    if config.embedding_config:
      # __init__ is called with a blank config during deserialization, in
      # which case from_config restores the embedding layer directly.
      self._embedding_layer = TPUEmbedding(
          feature_config=config.embedding_config,
          optimizer=optimizer)

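    # Build one Hashing layer per feature chunk, keyed identically to the
    # embedding feature config so hashed inputs line up with table lookups.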
    self._hash_config = config.hashing_config
    self._hashing_layers = {}
    for name, components in self._hash_config.items():
      self._hashing_layers[name] = {}
      for component_name, component_params in components.items():
        self._hashing_layers[name][component_name] = Hashing(**component_params)

  def get_config(self):
    config = super().get_config()
    config.update({
        "embed_layer": tf.keras.saving.serialize_keras_object(
            self._embedding_layer),
        "hash_config": self._hash_config,
    })
    return config

  @classmethod
  def from_config(cls, config: Dict[str, Any]) -> "UnifiedEmbedding":
    # The only parameters we need to re-construct the layer are hash_config,
    # to rebuild the Hashing layers, and the serialized embed_layer. For the
    # other arguments to the initializer, we use empty "dummy" values.
    ue_config = UnifiedEmbeddingConfig(0, 0, 0, "")
    ue_config.hashing_config = config.pop("hash_config")
    embed_layer = tf.keras.saving.deserialize_keras_object(
        config.pop("embed_layer"))
    config["config"] = ue_config
    config["optimizer"] = None  # Optimizer is stored by the embed_layer.
    ue_layer = cls(**config)
    ue_layer._embedding_layer = embed_layer
    return ue_layer

  def call(self, features: Dict[str, tf.Tensor]):
    """Hashes inputs, looks up embedding components, and concatenates them.

    Args:
      features: Input feature values as a {feature name: Tensor} dictionary.
        The dictionary keys must contain all of the feature names in the
        UnifiedEmbeddingConfig. The dictionary may also contain other features,
        but these will be ignored in the output.

    Returns:
      A list of embeddings, ordered according to the order in which the
      features were added to the UnifiedEmbeddingConfig.
    """
    # 1. Hash the features using per-chunk hash layers.
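    # Each chunk's Hashing layer uses a distinct salt, so the same raw value
    # generally maps to different buckets in each backend table.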
    hashed_features = {}
    for name, hashing_layers in self._hashing_layers.items():
      hashed_features[name] = {}
      feature = features[name]
      for component_name, hashing_layer in hashing_layers.items():
        hashed_features[name][component_name] = hashing_layer(feature)
    # 2. Embed the features using the embedding layer.
    embed_features = self._embedding_layer(hashed_features)
    # 3. Concatenate the sub-components of each feature (in order).
    output_features = []
    for name, components in embed_features.items():
      # Sort chunks numerically by the trailing `_lookup_{chunk_id}` index;
      # a plain lexicographic sort would misplace chunk 10 before chunk 2.
      component_keys = sorted(
          components, key=lambda k: int(k.rsplit("_", 1)[-1]))
      component_values = [components[k] for k in component_keys]
      embedding = tf.concat(component_values, axis=-1)
      output_features.append(embedding)
    return output_features