Skip to content

Commit 0d24b09

Browse files
update logic
1 parent 30319b0 commit 0d24b09

File tree

2 files changed

+27
-28
lines changed

2 files changed

+27
-28
lines changed

demo/dynamic_embedding/movielens-1m-keras-with-horovod/movielens-1m-keras-with-horovod.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
import tensorflow_datasets as tfds
2020
import horovod.tensorflow as hvd
2121
# optimal performance
22-
os.environ['TF_XLA_FLAGS'] = '--tf_xla_auto_jit=2 --tf_xla_cpu_global_jit'
22+
# os.environ['TF_XLA_FLAGS'] = '--tf_xla_auto_jit=2 --tf_xla_cpu_global_jit'
2323

2424

2525
def has_horovod() -> bool:
@@ -665,9 +665,7 @@ def train():
665665
callbacks=callbacks_list,
666666
epochs=FLAGS.epochs,
667667
steps_per_epoch=FLAGS.steps_per_epoch,
668-
verbose=1 if get_rank() == 0 else 0)
669-
670-
print(model.user_embedding.sparse_embedding_layer.params.restrict_policy)
668+
verbose=1)# if get_rank() == 0 else 0)
671669

672670
export_to_savedmodel(model, FLAGS.model_dir)
673671
export_for_serving(model, FLAGS.export_dir)

tensorflow_recommenders_addons/dynamic_embedding/python/keras/models.py

Lines changed: 25 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -99,34 +99,15 @@ def _check_saveable_and_redirect_new_de_dir(hvd_rank=0):
9999
if hasattr(de_var, 'saveable'):
100100
de_var.saveable._saver_config.save_path = de_dir
101101

102-
def _save_de_var(de_var, proc_size=1, proc_rank=0):
103-
a2a_emb = de_var._created_in_class
104-
if de_var._saveable_object_creator is not None:
105-
if not isinstance(de_var.kv_creator.saver, de.FileSystemSaver):
106-
# This function only serves FileSystemSaver.
107-
return
108-
# save optimizer parameters of Dynamic Embedding
109-
if include_optimizer is True:
110-
de_opt_vars = a2a_emb.optimizer_vars.as_list() if hasattr(
111-
a2a_emb.optimizer_vars, "as_list") else a2a_emb.optimizer_vars
112-
for de_opt_var in de_opt_vars:
113-
de_opt_var.save_to_file_system(dirpath=de_dir,
114-
proc_size=proc_size,
115-
proc_rank=proc_rank)
116-
if proc_rank == 0:
117-
# FileSystemSaver works well at rank 0.
118-
return
119-
# save Dynamic Embedding Parameters
120-
de_var.save_to_file_system(dirpath=de_dir,
121-
proc_size=proc_size,
122-
proc_rank=proc_rank)
123-
124102
def _maybe_save_restrict_policy_params(de_var, proc_size=1, proc_rank=0):
125103
if not hasattr(de_var, "restrict_policy"):
126104
return
127105
if de_var.restrict_policy is not None:
106+
# Only save restrict policy var if policy created
128107
de_var = de_var.restrict_policy._restrict_var
129-
_save_de_var(de_var, proc_size=proc_size, proc_rank=proc_rank)
108+
de_var.save_to_file_system(dirpath=de_dir,
109+
proc_size=proc_size,
110+
proc_rank=proc_rank)
130111

131112
def _traverse_emb_layers_and_save(proc_size=1, proc_rank=0):
132113
for var in model.variables:
@@ -135,7 +116,27 @@ def _traverse_emb_layers_and_save(proc_size=1, proc_rank=0):
135116
if not hasattr(var.params, "_created_in_class"):
136117
continue
137118
de_var = var.params
138-
_save_de_var(de_var, proc_size=proc_size, proc_rank=proc_rank)
119+
a2a_emb = de_var._created_in_class
120+
if de_var._saveable_object_creator is not None:
121+
if not isinstance(de_var.kv_creator.saver, de.FileSystemSaver):
122+
# This function only serves FileSystemSaver.
123+
return
124+
# save optimizer parameters of Dynamic Embedding
125+
if include_optimizer is True:
126+
de_opt_vars = a2a_emb.optimizer_vars.as_list() if hasattr(
127+
a2a_emb.optimizer_vars, "as_list") else a2a_emb.optimizer_vars
128+
for de_opt_var in de_opt_vars:
129+
de_opt_var.save_to_file_system(dirpath=de_dir,
130+
proc_size=proc_size,
131+
proc_rank=proc_rank)
132+
if proc_rank == 0:
133+
# FileSystemSaver works well at rank 0.
134+
return
135+
# save Dynamic Embedding Parameters
136+
de_var.save_to_file_system(dirpath=de_dir,
137+
proc_size=proc_size,
138+
proc_rank=proc_rank)
139+
# Save restrict policy for each hvd.rank()
139140
_maybe_save_restrict_policy_params(de_var, proc_size=proc_size, proc_rank=proc_rank)
140141

141142
if hvd is None:

0 commit comments

Comments
 (0)