
Commit e10e0c5

Author: TensorFlow Recommenders Authors
Add Unified Embedding layer.
PiperOrigin-RevId: 812103382
1 parent 3f45506 commit e10e0c5

File tree: 3 files changed, +404 −0 lines changed

tensorflow_recommenders/layers/feature_multiplexing/__init__.py
Lines changed: 18 additions & 0 deletions
@@ -0,0 +1,18 @@
# Copyright 2025 The TensorFlow Recommenders Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Feature Multiplexing layers."""

from tensorflow_recommenders.layers.feature_multiplexing.unified_embedding import UnifiedEmbedding
from tensorflow_recommenders.layers.feature_multiplexing.unified_embedding import UnifiedEmbeddingConfig
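
The unified_embedding.py module below multiplexes many features into a few shared tables by giving every (feature, chunk) pair its own hash salt. As a primer on why that works, here is a minimal standalone sketch (the bucket count and input values are illustrative, not part of this commit): two tf.keras.layers.Hashing layers with the same num_bins but different salts behave as independent hash functions over the same inputs.

```python
import tensorflow as tf

# Same number of buckets, different salts: each layer is an independent
# hash function over the same input values.
hash_a = tf.keras.layers.Hashing(num_bins=100, salt=[0, 0])
hash_b = tf.keras.layers.Hashing(num_bins=100, salt=[0, 1])

ids = tf.constant(["movie_42", "movie_7"])
print(hash_a(ids).numpy())  # Bucket ids in [0, 100).
print(hash_b(ids).numpy())  # Usually different bucket ids for the same inputs.
```

This is the mechanism UnifiedEmbeddingConfig.add_feature relies on: the salt [self._num_features, chunk_id] gives each chunk of each feature its own lookup into the shared backend tables.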
tensorflow_recommenders/layers/feature_multiplexing/unified_embedding.py
Lines changed: 215 additions & 0 deletions
@@ -0,0 +1,215 @@
# Copyright 2025 The TensorFlow Recommenders Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Unified Embedding Module.

This module implements the feature multiplexing framework introduced in
"Unified Embedding: Battle-Tested Feature Representations for Web-Scale ML
Systems" by Coleman et al., using a regular embedding layer as the backend
embedding table format. The intended usage is to construct a
UnifiedEmbeddingConfig object that describes the number and size of the
features to be embedded together in a shared table, and then construct a
UnifiedEmbedding layer from the config.

Example:
```python
buckets_per_table = 100  # Number of embedding buckets in each backend table.
num_tables = 3  # Number of backend tables to spread the features across.
table_dimension = 16  # Dimension of each unified embedding chunk.

# Construct the configuration object.
embed_config = UnifiedEmbeddingConfig(
    buckets_per_table=buckets_per_table,
    dim_per_table=table_dimension,
    num_tables=num_tables,
    name="unified_table",
)

# Add some features to the config, with different table sizes.
embed_config.add_feature("movie_genre", 2)  # 2 chunks = 32 dimensions.
embed_config.add_feature("movie_id", 3)  # 3 chunks = 48 dimensions.
embed_config.add_feature("user_zip_code", 1)  # 1 chunk = 16 dimensions.

# Construct the embedding layer, which takes a feature dict as input and
# returns a list of embeddings, one for each feature in the config. The
# optimizer can be any of the supported TPU embedding optimizers.
embed_optimizer = tf.tpu.experimental.embedding.Adagrad(learning_rate=0.05)
embed_layer = UnifiedEmbedding(embed_config, embed_optimizer)
```
"""

from typing import Any, Dict, Union

import tensorflow as tf
from tensorflow_recommenders.layers.embedding import tpu_embedding_layer


FeatureConfig = tf.tpu.experimental.embedding.FeatureConfig
Hashing = tf.keras.layers.Hashing
TableConfig = tf.tpu.experimental.embedding.TableConfig
TPUEmbedding = tpu_embedding_layer.TPUEmbedding
ValidTPUOptimizer = Union[
    tf.tpu.experimental.embedding.SGD,
    tf.tpu.experimental.embedding.Adagrad,
    tf.tpu.experimental.embedding.Adam,
    tf.tpu.experimental.embedding.FTRL,
]

class UnifiedEmbeddingConfig:
  """Unified Embedding Config."""

  def __init__(
      self,
      buckets_per_table: int,
      dim_per_table: int,
      num_tables: int,
      name: str,
      **kwargs,
  ):
    self._buckets_per_table = buckets_per_table
    self._dim_per_table = dim_per_table
    self._num_tables = num_tables
    self._current_table = 0
    self._num_features = 0
    self._name = name
    self._table_configs = [
        TableConfig(
            vocabulary_size=self._buckets_per_table,
            dim=self._dim_per_table,
            name=f"{self._name}_{i}",
            **kwargs,
        )
        for i in range(self._num_tables)
    ]
    # Store TPU embedding configs for each feature component (sub-feature).
    self._embed_configs = {}
    self._hashing_configs = {}

  def add_feature(self, name: str, num_chunks: int, **kwargs):
    """Adds a categorical feature to the unified embedding config.

    Args:
      name: Feature name, used to feed inputs from a feature dict and to track
        the sub-components of the embedding.
      num_chunks: Integer number of chunks to use for the embedding. The final
        dimension will be num_chunks * dim_per_table.
      **kwargs: Arguments to pass through to the underlying FeatureConfig.
    """
    chunk_embed_configs = {}
    chunk_hashing_configs = {}
    for chunk_id in range(num_chunks):
      chunk_name = f"{self._name}_{name}_lookup_{chunk_id}"
      chunk_embed_config = FeatureConfig(
          table=self._table_configs[self._current_table],
          name=chunk_name,
          **kwargs,
      )
      chunk_embed_configs[chunk_name] = chunk_embed_config
      # Each (feature, chunk) pair gets its own hash salt, so the same input
      # value maps to an independent bucket for every chunk.
      chunk_hashing_configs[chunk_name] = {
          "num_bins": self._buckets_per_table,
          "salt": [self._num_features, chunk_id],
      }
      # Assign chunks to the backend tables round-robin.
      self._current_table += 1
      self._current_table %= self._num_tables
    self._num_features += 1
    self._embed_configs[name] = chunk_embed_configs
    self._hashing_configs[name] = chunk_hashing_configs

  @property
  def embedding_config(self):
    return self._embed_configs

  @property
  def hashing_config(self):
    return self._hashing_configs


@tf.keras.utils.register_keras_serializable()
class UnifiedEmbedding(tf.keras.layers.Layer):
  """Post-processing layer to concatenate unified embedding components."""

  def __init__(
      self,
      config: UnifiedEmbeddingConfig,
      optimizer: ValidTPUOptimizer,
      **kwargs,
  ):
    super().__init__(**kwargs)
    if config.embedding_config:
      # __init__ is called with a blank config during serialization/
      # deserialization; in that case from_config restores the embedding
      # layer afterwards.
      self._embedding_layer = TPUEmbedding(
          feature_config=config.embedding_config,
          optimizer=optimizer)

    self._hash_config = config.hashing_config
    self._hashing_layers = {}
    for name in self._hash_config:
      self._hashing_layers[name] = {}
      for component_name, component_params in self._hash_config[name].items():
        self._hashing_layers[name][component_name] = Hashing(**component_params)

  def get_config(self):
    config = super().get_config()
    config.update({
        "embed_layer": tf.keras.saving.serialize_keras_object(
            self._embedding_layer),
        "hash_config": self._hash_config,
    })
    return config

  @classmethod
  def from_config(cls, config: Dict[str, Any]) -> "UnifiedEmbedding":
    # The only parameters we need to re-construct the layer are hash_config,
    # to rebuild the Hashing layers, and the serialized embed_layer. For the
    # other arguments to the initializer, we use empty "dummy" values.
    ue_config = UnifiedEmbeddingConfig(0, 0, 0, "")
    # hashing_config is a read-only property, so restore the private field
    # directly; the embedding config stays blank so __init__ skips building
    # a new TPUEmbedding layer.
    ue_config._hashing_configs = config.pop("hash_config")
    embed_layer = tf.keras.saving.deserialize_keras_object(
        config.pop("embed_layer"))
    config["config"] = ue_config
    config["optimizer"] = None  # Optimizer is stored by the embed_layer.
    ue_layer = cls(**config)
    ue_layer._embedding_layer = embed_layer
    return ue_layer

  def call(self, features: Dict[str, tf.Tensor]):
    """Hashes inputs, looks up embedding components, and concatenates them.

    Args:
      features: Input feature values as a {feature name: Tensor} dictionary.
        The dictionary keys must contain all of the feature names in the
        UnifiedEmbeddingConfig. The dictionary may also contain other features,
        but these will be ignored in the output.

    Returns:
      A list of embeddings, ordered according to the order in which the
        features were added to the UnifiedEmbeddingConfig.
    """
    # 1. Hash the features using different hash layers.
    hashed_features = {}
    for name, hashing_layers in self._hashing_layers.items():
      hashed_features[name] = {}
      feature = features[name]
      for component_name, hashing_layer in hashing_layers.items():
        hashed_features[name][component_name] = hashing_layer(feature)
    # 2. Embed the features using the embedding layer.
    embed_features = self._embedding_layer(hashed_features)
    # 3. Concatenate the sub-components of each feature (in order).
    output_features = []
    for components in embed_features.values():
      component_values = [components[k] for k in sorted(components.keys())]
      embedding = tf.concat(component_values, axis=-1)
      output_features.append(embedding)
    return output_features
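
To tie the pieces together, here is a minimal end-to-end sketch extending the module docstring's example through an actual call (the optimizer settings and feature values are illustrative; it assumes an environment where the underlying TPUEmbedding layer can run, i.e. a TPU or its CPU path):

```python
import tensorflow as tf
from tensorflow_recommenders.layers.feature_multiplexing import (
    UnifiedEmbedding,
    UnifiedEmbeddingConfig,
)

# Three shared backend tables of 100 buckets x 16 dimensions.
embed_config = UnifiedEmbeddingConfig(
    buckets_per_table=100,
    dim_per_table=16,
    num_tables=3,
    name="unified_table",
)
embed_config.add_feature("movie_genre", 2)    # 2 chunks -> 32-dim output.
embed_config.add_feature("movie_id", 3)       # 3 chunks -> 48-dim output.
embed_config.add_feature("user_zip_code", 1)  # 1 chunk  -> 16-dim output.

embed_layer = UnifiedEmbedding(
    embed_config,
    tf.tpu.experimental.embedding.Adagrad(learning_rate=0.05),
)

# Raw feature values; each one is hashed once per chunk before lookup.
features = {
    "movie_genre": tf.constant(["action", "comedy"]),
    "movie_id": tf.constant(["movie_42", "movie_7"]),
    "user_zip_code": tf.constant(["94043", "10001"]),
}
genre_emb, id_emb, zip_emb = embed_layer(features)
# Shapes: (2, 32), (2, 48), (2, 16) -- in the order features were added.
```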
