Description
System information
- OS Platform and Distribution (e.g., Linux Ubuntu 16.04): Linux 5.15.0-92-generic #102-Ubuntu SMP Wed Jan 10 09:33:48 UTC 2024 x86_64 x86_64 x86_64 GNU/Linux
- TensorFlow version and how it was installed (source or binary): TF 2.15.1, installed via pip3
- TensorFlow-Recommenders-Addons version and how it was installed (source or binary): tensorflow-recommenders-addons 0.7.0, installed via pip3
- Python version: Python 3.11.13
- Is GPU used? (yes/no): no
Describe the bug
As described in the title, when multiple features look up the same dynamic-embedding table, backprop breaks.
Why this matters: in a recommendation system, each sample contains an item_id plus a user behavior sequence, which itself consists of item_ids, so technically both should retrieve from the same embedding table. Moreover, the behavior sequence can be longer than 1, e.g. length 10, 20, or 100, as sketched below.
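For illustration, here is a minimal sketch of what a length-20 behavior sequence sharing the item table would look like (the names item_table and behavior_seq and the output-shape comments are my assumptions for illustration; this is separate from the repro below):

    import tensorflow as tf
    from tensorflow_recommenders_addons import dynamic_embedding as de

    # One dynamic table keyed by item ids, shared by both features.
    item_table = de.keras.layers.BasicEmbedding(embedding_size=8, name='item_table')

    item_id = tf.random.uniform([32, 1], minval=100, maxval=150, dtype=tf.int64)
    behavior_seq = tf.random.uniform([32, 20], minval=100, maxval=150, dtype=tf.int64)

    item_latent = item_table(item_id)        # expected shape (32, 1, 8)
    seq_latent = item_table(behavior_seq)    # expected shape (32, 20, 8), same key space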
Code to reproduce the issue
In the following code, user_id, item_id, and you_like all look up from the same embedding table. The you_like feature just simulates the user behavior sequence; it could equally be a real sequence rather than a length-1 feature.
import sys
import time

import tensorflow as tf
from tensorflow_recommenders_addons import dynamic_embedding as de

try:
    from tensorflow.keras.optimizers.legacy import Adam
except ImportError:
    from tensorflow.keras.optimizers import Adam


def create_sample_data(num_samples=1000):
    data = {
        "user_id": tf.random.uniform([num_samples], minval=1, maxval=100, dtype=tf.int64),
        "item_id": tf.random.uniform([num_samples], minval=100, maxval=150, dtype=tf.int64),
        # you_like shares the item_id key space, simulating a behavior sequence.
        "you_like": tf.random.uniform([num_samples], minval=100, maxval=150, dtype=tf.int64),
        "label": tf.random.uniform([num_samples], maxval=2, dtype=tf.int32)
    }
    return tf.data.Dataset.from_tensor_slices(data).batch(32)


sample_dataset = create_sample_data()
for features in sample_dataset:
    print(features)


class DualChannelsDeepModel(tf.keras.Model):

    def __init__(self,
                 embedding_size=1,
                 embedding_initializer=None,
                 is_training=True):
        if not is_training:
            de.enable_inference_mode()
        super(DualChannelsDeepModel, self).__init__()
        if embedding_initializer is None:
            embedding_initializer = tf.keras.initializers.Zeros()
        # A single dynamic-embedding table shared by user_id, item_id and you_like.
        self.de_embedding = de.keras.layers.BasicEmbedding(
            embedding_size=embedding_size,
            initializer=embedding_initializer,
            name='de_embedding')
        self.dnn1 = tf.keras.layers.Dense(
            64,
            activation='relu',
            kernel_initializer=tf.keras.initializers.RandomNormal(0.0, 0.1),
            bias_initializer=tf.keras.initializers.RandomNormal(0.0, 0.1))
        self.dnn2 = tf.keras.layers.Dense(
            16,
            activation='relu',
            kernel_initializer=tf.keras.initializers.RandomNormal(0.0, 0.1),
            bias_initializer=tf.keras.initializers.RandomNormal(0.0, 0.1))
        self.dnn3 = tf.keras.layers.Dense(
            1,
            activation='sigmoid',
            kernel_initializer=tf.keras.initializers.RandomNormal(0.0, 0.1),
            bias_initializer=tf.keras.initializers.RandomNormal(0.0, 0.1))
        self.bias_net = tf.keras.layers.Dense(
            1,
            activation=None,
            kernel_initializer=tf.keras.initializers.RandomNormal(0.0, 0.1),
            bias_initializer=tf.keras.initializers.RandomNormal(0.0, 0.1))

    @tf.function
    def call(self, features):
        user_id = tf.reshape(features['user_id'], (-1, 1))
        movie_id = tf.reshape(features['item_id'], (-1, 1))
        you_like = tf.reshape(features['you_like'], (-1, 1))
        # Three lookups against the same table within one forward pass.
        user_latent = self.de_embedding(user_id)
        movie_latent = self.de_embedding(movie_id)
        you_latent = self.de_embedding(you_like)
        latent = tf.concat([user_latent, movie_latent, you_latent], axis=1)
        x = self.dnn1(latent)
        x = self.dnn2(x)
        x = self.dnn3(x)
        bias = self.bias_net(latent)
        x = 0.8 * x + 0.2 * bias
        return x


dnnModel = DualChannelsDeepModel(embedding_size=8)
optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)
optimizer = de.DynamicEmbeddingOptimizer(optimizer)


@tf.function
def train_step(features):
    with tf.GradientTape() as tape:
        logits = dnnModel(features)
        y_true = tf.cast(features['label'], tf.float32)
        y_pred = logits
        loss = tf.reduce_mean(tf.keras.losses.binary_crossentropy(y_true, y_pred))
    gradients = tape.gradient(loss, dnnModel.trainable_variables)
    optimizer.apply_gradients(zip(gradients, dnnModel.trainable_variables))
    return loss


for batch_features in sample_dataset:
    loss = train_step(batch_features)
    print(f"[{time.time()}] Loss: {float(loss):.3f}")
Other info / logs
Here is the error log:
2025-08-05 01:42:57.210571: I ./tensorflow_recommenders_addons/dynamic_embedding/core/kernels/lookup_impl/lookup_table_op_cpu.h:157] HashTable on CPU is created on optimized mode: K=l, V=f, DIM=8, init_size=8192
/home/work/miniconda3/envs/tfra2.15/lib/python3.11/site-packages/keras/src/initializers/initializers.py:120: UserWarning: The initializer RandomNormal is unseeded and being called multiple times, which will return identical values each time (even if the initializer is unseeded). Please update your code to provide a seed to the initializer, or avoid using the same initializer instance more than once.
warnings.warn(
2025-08-05 01:42:57.832756: I ./tensorflow_recommenders_addons/dynamic_embedding/core/kernels/lookup_impl/lookup_table_op_cpu.h:157] HashTable on CPU is created on optimized mode: K=l, V=f, DIM=8, init_size=8192
2025-08-05 01:42:57.836715: I ./tensorflow_recommenders_addons/dynamic_embedding/core/kernels/lookup_impl/lookup_table_op_cpu.h:157] HashTable on CPU is created on optimized mode: K=l, V=f, DIM=8, init_size=8192
2025-08-05 01:42:59.279244: W tensorflow/core/framework/op_kernel.cc:1839] OP_REQUIRES failed at resource_variable_ops.cc:1177 : INVALID_ARGUMENT: indices[21] = 21 is not in [0, 21)
2025-08-05 01:42:59.279323: W tensorflow/core/framework/op_kernel.cc:1839] OP_REQUIRES failed at resource_variable_ops.cc:1177 : INVALID_ARGUMENT: indices[21] = 21 is not in [0, 21)
Traceback (most recent call last):
File "/data/aha_train_tn/tf_mwm/tfra_de_test.py", line 121, in <module>
loss = train_step(batch_features)
^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/work/miniconda3/envs/tfra2.15/lib/python3.11/site-packages/tensorflow/python/util/traceback_utils.py", line 153, in error_handler
raise e.with_traceback(filtered_tb) from None
File "/home/work/miniconda3/envs/tfra2.15/lib/python3.11/site-packages/tensorflow/python/eager/execute.py", line 53, in quick_execute
tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
tensorflow.python.framework.errors_impl.InvalidArgumentError: Graph execution error:
Detected at node Adam/ResourceScatterAdd defined at (most recent call last):
File "/data/aha_train_tn/tf_mwm/tfra_de_test.py", line 121, in <module>
File "/data/aha_train_tn/tf_mwm/tfra_de_test.py", line 113, in train_step
File "/home/work/miniconda3/envs/tfra2.15/lib/python3.11/site-packages/tensorflow_recommenders_addons/dynamic_embedding/python/ops/dynamic_embedding_optimizer.py", line 752, in apply_gradients_strategy_v2
File "/home/work/miniconda3/envs/tfra2.15/lib/python3.11/site-packages/tensorflow_recommenders_addons/dynamic_embedding/python/ops/dynamic_embedding_optimizer.py", line 778, in apply_gradients_strategy_v2
File "/home/work/miniconda3/envs/tfra2.15/lib/python3.11/site-packages/keras/src/optimizers/optimizer.py", line 652, in apply_gradients
File "/home/work/miniconda3/envs/tfra2.15/lib/python3.11/site-packages/keras/src/optimizers/optimizer.py", line 1253, in _internal_apply_gradients
File "/home/work/miniconda3/envs/tfra2.15/lib/python3.11/site-packages/tensorflow_recommenders_addons/dynamic_embedding/python/ops/dynamic_embedding_optimizer.py", line 339, in _distributed_apply_gradients_fn
File "/home/work/miniconda3/envs/tfra2.15/lib/python3.11/site-packages/tensorflow_recommenders_addons/dynamic_embedding/python/ops/dynamic_embedding_optimizer.py", line 308, in apply_grad_to_update_var
File "/home/work/miniconda3/envs/tfra2.15/lib/python3.11/site-packages/tensorflow_recommenders_addons/dynamic_embedding/python/ops/dynamic_embedding_optimizer.py", line 311, in apply_grad_to_update_var
File "/home/work/miniconda3/envs/tfra2.15/lib/python3.11/site-packages/tensorflow_recommenders_addons/dynamic_embedding/python/ops/dynamic_embedding_optimizer.py", line 330, in apply_grad_to_update_var
File "/home/work/miniconda3/envs/tfra2.15/lib/python3.11/site-packages/tensorflow_recommenders_addons/dynamic_embedding/python/ops/dynamic_embedding_optimizer.py", line 303, in _update_step_fn
File "/home/work/miniconda3/envs/tfra2.15/lib/python3.11/site-packages/tensorflow_recommenders_addons/dynamic_embedding/python/ops/dynamic_embedding_optimizer.py", line 306, in _update_step_fn
File "/home/work/miniconda3/envs/tfra2.15/lib/python3.11/site-packages/keras/src/optimizers/optimizer.py", line 241, in _update_step
File "/home/work/miniconda3/envs/tfra2.15/lib/python3.11/site-packages/keras/src/optimizers/adam.py", line 179, in update_step
indices[21] = 21 is not in [0, 21)
[[{{node Adam/ResourceScatterAdd}}]] [Op:__inference_train_step_1958]
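For what it's worth, the failing index pattern (indices[21] = 21 is not in [0, 21)) looks as if the optimizer's scatter update is sized from the unique keys of a single lookup while gradients from all three lookups are applied; that is only my guess. A possible workaround (untested, shown only as a sketch against the call() method above) is to merge all keys into one lookup per table, so each step produces a single gradient for the table:

    # Sketch of a single-lookup variant of call(); the assumption is that one
    # lookup per table yields one gradient slice and sidesteps the crash.
    all_ids = tf.concat([user_id, movie_id, you_like], axis=0)   # (3*B, 1)
    all_latent = self.de_embedding(all_ids)                      # (3*B, 1, 8)
    user_latent, movie_latent, you_latent = tf.split(all_latent, 3, axis=0)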