Skip to content

Commit 30319b0

Browse files
udpate
1 parent 9d20bac commit 30319b0

File tree

3 files changed

+12
-9
lines changed

3 files changed

+12
-9
lines changed

README.md

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -165,10 +165,8 @@ cd recommenders-addons
165165

166166
# This script links project with TensorFlow dependency
167167
python configure.py
168-
169168
bazel build --enable_runfiles build_pip_pkg
170169
bazel-bin/build_pip_pkg artifacts
171-
172170
pip install artifacts/tensorflow_recommenders_addons-*.whl
173171
```
174172
#### GPU Support

demo/dynamic_embedding/movielens-1m-keras-with-horovod/movielens-1m-keras-with-horovod.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,8 @@
2323

2424

2525
def has_horovod() -> bool:
26-
return 'OMPI_COMM_WORLD_RANK' in os.environ or 'PMI_RANK' in os.environ
26+
#return 'OMPI_COMM_WORLD_RANK' in os.environ or 'PMI_RANK' in os.environ
27+
return True
2728

2829

2930
def config():
@@ -275,6 +276,7 @@ def __init__(self,
275276
init_capacity=init_capacity,
276277
kv_creator=kv_creator_dense,
277278
short_file_name=True,
279+
restrict_policy=de.TimestampRestrictPolicy,
278280
)
279281

280282
kv_creator_sparse = get_kv_creator(mpi_size, mpi_rank, init_capacity,
@@ -290,6 +292,7 @@ def __init__(self,
290292
init_capacity=init_capacity,
291293
kv_creator=kv_creator_sparse,
292294
short_file_name=True,
295+
restrict_policy=de.TimestampRestrictPolicy,
293296
)
294297

295298
self.dnn = tf.keras.layers.Dense(
@@ -648,7 +651,7 @@ def train():
648651
# horovod callback is used to broadcast the value generated by initializer of rank0.
649652
hvd_opt_init_callback = de.keras.callbacks.DEHvdBroadcastGlobalVariablesCallback(
650653
root_rank=0)
651-
callbacks_list = [hvd_opt_init_callback, ckpt_callback]
654+
callbacks_list = [hvd_opt_init_callback]#, ckpt_callback]
652655
else:
653656
callbacks_list = [ckpt_callback]
654657

@@ -664,6 +667,8 @@ def train():
664667
steps_per_epoch=FLAGS.steps_per_epoch,
665668
verbose=1 if get_rank() == 0 else 0)
666669

670+
print(model.user_embedding.sparse_embedding_layer.params.restrict_policy)
671+
667672
export_to_savedmodel(model, FLAGS.model_dir)
668673
export_for_serving(model, FLAGS.export_dir)
669674

tensorflow_recommenders_addons/dynamic_embedding/python/keras/models.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -121,11 +121,11 @@ def _save_de_var(de_var, proc_size=1, proc_rank=0):
121121
proc_size=proc_size,
122122
proc_rank=proc_rank)
123123

124-
def _maybe_save_restrict_policy_params(var, proc_size=1, proc_rank=0):
125-
if not hasattr(var, "restrict_policy"):
124+
def _maybe_save_restrict_policy_params(de_var, proc_size=1, proc_rank=0):
125+
if not hasattr(de_var, "restrict_policy"):
126126
return
127-
if var.restrict_policy is not None:
128-
de_var = var.restrict_policy._restrict_var
127+
if de_var.restrict_policy is not None:
128+
de_var = de_var.restrict_policy._restrict_var
129129
_save_de_var(de_var, proc_size=proc_size, proc_rank=proc_rank)
130130

131131
def _traverse_emb_layers_and_save(proc_size=1, proc_rank=0):
@@ -136,7 +136,7 @@ def _traverse_emb_layers_and_save(proc_size=1, proc_rank=0):
136136
continue
137137
de_var = var.params
138138
_save_de_var(de_var, proc_size=proc_size, proc_rank=proc_rank)
139-
_maybe_save_restrict_policy_params(var, proc_size=proc_size, proc_rank=proc_rank)
139+
_maybe_save_restrict_policy_params(de_var, proc_size=proc_size, proc_rank=proc_rank)
140140

141141
if hvd is None:
142142
call_original_save_func()

0 commit comments

Comments
 (0)