Skip to content

Commit 7a6bce1

Browse files
MoFHekarhdong
authored and committed
[fix] Now demos of DE Keras Embedding are able to run normally.
1 parent 4aecc4c commit 7a6bce1

File tree

2 files changed

+28
-15
lines changed

2 files changed

+28
-15
lines changed

demo/dynamic_embedding/movielens-1m-keras-with-horovod/movielens-1m-keras-with-horovod.py

Lines changed: 26 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -11,6 +11,8 @@
1111
except:
1212
from tensorflow.keras.optimizers import Adam
1313

14+
import horovod.tensorflow as hvd
15+
1416
os.environ["TF_FORCE_GPU_ALLOW_GROWTH"] = "true" #VERY IMPORTANT!
1517

1618
os.environ["TF_GPU_THREAD_MODE"] = "gpu_private"
@@ -29,7 +31,6 @@
2931

3032
# optimal performance
3133
os.environ['TF_XLA_FLAGS'] = '--tf_xla_auto_jit=2 --tf_xla_cpu_global_jit'
32-
tf.config.experimental.set_synchronous_execution(False)
3334

3435
flags.DEFINE_string('mode', 'train', 'Select the running mode: train or test.')
3536
flags.DEFINE_string('model_dir', 'model_dir',
@@ -181,7 +182,7 @@ def embedding_out_split(embedding_out_concat, input_split_dims):
181182
return embedding_out
182183

183184

184-
class ChannelEmbeddingLayers():
185+
class ChannelEmbeddingLayers(tf.keras.layers.Layer):
185186

186187
def __init__(self,
187188
name='',
@@ -191,6 +192,8 @@ def __init__(self,
191192
mpi_size=1,
192193
mpi_rank=0):
193194

195+
super(ChannelEmbeddingLayers, self).__init__()
196+
194197
self.gpu_device = ["GPU:0"]
195198
self.cpu_device = ["CPU:0"]
196199

@@ -227,6 +230,9 @@ def __init__(self,
227230
kernel_initializer=tf.keras.initializers.RandomNormal(0.0, 0.1),
228231
bias_initializer=tf.keras.initializers.RandomNormal(0.0, 0.1))
229232

233+
def build(self, input_shape):
234+
super(ChannelEmbeddingLayers, self).build(input_shape)
235+
230236
def __call__(self, features_info):
231237
dense_inputs = []
232238
dense_input_dims = []
@@ -301,14 +307,14 @@ def __init__(self,
301307
self.user_embedding = ChannelEmbeddingLayers(
302308
name='user',
303309
dense_embedding_size=user_embedding_size,
304-
user_embedding_size=user_embedding_size * 2,
310+
sparse_embedding_size=user_embedding_size * 2,
305311
embedding_initializer=embedding_initializer,
306312
mpi_size=mpi_size,
307313
mpi_rank=mpi_rank)
308314
self.movie_embedding = ChannelEmbeddingLayers(
309315
name='movie',
310316
dense_embedding_size=movie_embedding_size,
311-
user_embedding_size=movie_embedding_size * 2,
317+
sparse_embedding_size=movie_embedding_size * 2,
312318
embedding_initializer=embedding_initializer,
313319
mpi_size=mpi_size,
314320
mpi_rank=mpi_rank)
@@ -339,13 +345,14 @@ def call(self, features):
339345
# Construct input layers
340346
for fea_name in features.keys():
341347
fea_info = feature_info_spec[fea_name]
342-
input_tensor = tf.keras.layers.Input(shape=(fea_info['dim'],),
343-
dtype=fea_info['dtype'],
344-
name=fea_name)
348+
input_tensor = features[fea_name]
349+
input_tensor = tf.keras.layers.Lambda(lambda x: x,
350+
name=fea_name)(input_tensor)
351+
input_tensor = tf.reshape(input_tensor, (-1, fea_info['dim']))
345352
fea_info['input_tensor'] = input_tensor
346353
if fea_info.__contains__('boundaries'):
347-
input_tensor = tf.raw_ops.Bucketize(input=input_tensor,
348-
boundaries=fea_info['boundaries'])
354+
input_tensor = Bucketize(
355+
boundaries=fea_info['boundaries'])(input_tensor)
349356
# To prepare for GPU table combined queries, use a prefix to distinguish different features in a table.
350357
if fea_info['ptype'] == 'normal_gpu':
351358
if fea_info['dtype'] == tf.int64:
@@ -361,13 +368,15 @@ def call(self, features):
361368
fea_info['pretreated_tensor'] = input_tensor
362369

363370
user_fea = ['user_id', 'user_gender', 'user_occupation_label']
371+
user_fea = [i for i in features.keys() if i in user_fea]
364372
user_fea_info = {
365373
key: value
366374
for key, value in feature_info_spec.items()
367375
if key in user_fea
368376
}
369377
user_latent = self.user_embedding(user_fea_info)
370378
movie_fea = ['movie_id', 'movie_genres', 'user_occupation_label']
379+
movie_fea = [i for i in features.keys() if i in movie_fea]
371380
movie_fea_info = {
372381
key: value
373382
for key, value in feature_info_spec.items()
@@ -382,7 +391,8 @@ def call(self, features):
382391

383392
bias = self.bias_net(latent)
384393
x = 0.2 * x + 0.8 * bias
385-
return x
394+
user_rating = tf.keras.layers.Lambda(lambda x: x, name='user_rating')(x)
395+
return {'user_rating': user_rating}
386396

387397

388398
def get_dataset(batch_size=1):
@@ -408,7 +418,10 @@ def get_dataset(batch_size=1):
408418
tf.cast(x["timestamp"] - 880000000, tf.int32),
409419
})
410420

411-
ratings = ds.map(lambda x: {"user_rating": x["user_rating"]})
421+
ratings = ds.map(lambda x: {
422+
"user_rating":
423+
tf.one_hot(tf.cast(x["user_rating"] - 1, dtype=tf.int64), 5)
424+
})
412425
dataset = tf.data.Dataset.zip((features, ratings))
413426
shuffled = dataset.shuffle(1_000_000,
414427
seed=2021,
@@ -551,8 +564,8 @@ def train():
551564
auc,
552565
])
553566

554-
if os.path.exists(FLAGS.model_dir):
555-
model.load_weights(FLAGS.model_dir)
567+
if os.path.exists(FLAGS.model_dir + '/variables'):
568+
model.load_weights(FLAGS.model_dir + '/variables/variables')
556569

557570
tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=FLAGS.model_dir)
558571
save_options = tf.saved_model.SaveOptions(namespace_whitelist=['TFRA'])

demo/dynamic_embedding/movielens-1m-keras/movielens-1m-keras.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -129,8 +129,8 @@ def train():
129129
auc,
130130
])
131131

132-
if os.path.exists(FLAGS.model_dir):
133-
model.load_weights(FLAGS.model_dir)
132+
if os.path.exists(FLAGS.model_dir + '/variables'):
133+
model.load_weights(FLAGS.model_dir + '/variables/variables')
134134

135135
model.fit(dataset, epochs=FLAGS.epochs, steps_per_epoch=FLAGS.steps_per_epoch)
136136

0 commit comments

Comments (0)