Skip to content

Commit 3855831

Browse files
fix: restrict policy var save for distributed setup
1 parent b3bc3d4 commit 3855831

File tree

1 file changed

+30
-21
lines changed
  • tensorflow_recommenders_addons/dynamic_embedding/python/keras

1 file changed

+30
-21
lines changed

tensorflow_recommenders_addons/dynamic_embedding/python/keras/models.py

Lines changed: 30 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -98,34 +98,43 @@ def _check_saveable_and_redirect_new_de_dir(hvd_rank=0):
9898
# Redirect new de_dir
9999
if hasattr(de_var, 'saveable'):
100100
de_var.saveable._saver_config.save_path = de_dir
101-
101+
102+
def _save_de_var(de_var, proc_size=1, proc_rank=0):
103+
a2a_emb = de_var._created_in_class
104+
if de_var._saveable_object_creator is not None:
105+
if not isinstance(de_var.kv_creator.saver, de.FileSystemSaver):
106+
# This function only serves FileSystemSaver.
107+
return
108+
# save optimizer parameters of Dynamic Embedding
109+
if include_optimizer is True:
110+
de_opt_vars = a2a_emb.optimizer_vars.as_list() if hasattr(
111+
a2a_emb.optimizer_vars, "as_list") else a2a_emb.optimizer_vars
112+
for de_opt_var in de_opt_vars:
113+
de_opt_var.save_to_file_system(dirpath=de_dir,
114+
proc_size=proc_size,
115+
proc_rank=proc_rank)
116+
if proc_rank == 0:
117+
# FileSystemSaver works well at rank 0.
118+
return
119+
# save Dynamic Embedding Parameters
120+
de_var.save_to_file_system(dirpath=de_dir,
121+
proc_size=proc_size,
122+
proc_rank=proc_rank)
123+
124+
def _maybe_save_restrict_policy_params(var, proc_size=1, proc_rank=0):
125+
if var.restrict_policy is not None:
126+
de_var = var.restrict_policy._restrict_var
127+
_save_de_var(de_var, proc_size=proc_size, proc_rank=proc_rank)
128+
102129
def _traverse_emb_layers_and_save(proc_size=1, proc_rank=0):
103130
for var in model.variables:
104131
if not hasattr(var, "params"):
105132
continue
106133
if not hasattr(var.params, "_created_in_class"):
107134
continue
108135
de_var = var.params
109-
a2a_emb = de_var._created_in_class
110-
if de_var._saveable_object_creator is not None:
111-
if not isinstance(de_var.kv_creator.saver, de.FileSystemSaver):
112-
# This function only serves FileSystemSaver.
113-
continue
114-
# save optimizer parameters of Dynamic Embedding
115-
if include_optimizer is True:
116-
de_opt_vars = a2a_emb.optimizer_vars.as_list() if hasattr(
117-
a2a_emb.optimizer_vars, "as_list") else a2a_emb.optimizer_vars
118-
for de_opt_var in de_opt_vars:
119-
de_opt_var.save_to_file_system(dirpath=de_dir,
120-
proc_size=proc_size,
121-
proc_rank=proc_rank)
122-
if proc_rank == 0:
123-
# FileSystemSaver works well at rank 0.
124-
continue
125-
# save Dynamic Embedding Parameters
126-
de_var.save_to_file_system(dirpath=de_dir,
127-
proc_size=proc_size,
128-
proc_rank=proc_rank)
136+
_save_de_var(de_var, proc_size=proc_size, proc_rank=proc_rank)
137+
_maybe_save_restrict_policy_params(var, proc_size=proc_size, proc_rank=proc_rank)
129138

130139
if hvd is None:
131140
call_original_save_func()

0 commit comments

Comments
 (0)