@@ -99,34 +99,15 @@ def _check_saveable_and_redirect_new_de_dir(hvd_rank=0):
99
99
if hasattr (de_var , 'saveable' ):
100
100
de_var .saveable ._saver_config .save_path = de_dir
101
101
102
- def _save_de_var (de_var , proc_size = 1 , proc_rank = 0 ):
103
- a2a_emb = de_var ._created_in_class
104
- if de_var ._saveable_object_creator is not None :
105
- if not isinstance (de_var .kv_creator .saver , de .FileSystemSaver ):
106
- # This function only serves FileSystemSaver.
107
- return
108
- # save optimizer parameters of Dynamic Embedding
109
- if include_optimizer is True :
110
- de_opt_vars = a2a_emb .optimizer_vars .as_list () if hasattr (
111
- a2a_emb .optimizer_vars , "as_list" ) else a2a_emb .optimizer_vars
112
- for de_opt_var in de_opt_vars :
113
- de_opt_var .save_to_file_system (dirpath = de_dir ,
114
- proc_size = proc_size ,
115
- proc_rank = proc_rank )
116
- if proc_rank == 0 :
117
- # FileSystemSaver works well at rank 0.
118
- return
119
- # save Dynamic Embedding Parameters
120
- de_var .save_to_file_system (dirpath = de_dir ,
121
- proc_size = proc_size ,
122
- proc_rank = proc_rank )
123
-
124
102
def _maybe_save_restrict_policy_params (de_var , proc_size = 1 , proc_rank = 0 ):
125
103
if not hasattr (de_var , "restrict_policy" ):
126
104
return
127
105
if de_var .restrict_policy is not None :
106
+ # Only save restrict policy var if policy created
128
107
de_var = de_var .restrict_policy ._restrict_var
129
- _save_de_var (de_var , proc_size = proc_size , proc_rank = proc_rank )
108
+ de_var .save_to_file_system (dirpath = de_dir ,
109
+ proc_size = proc_size ,
110
+ proc_rank = proc_rank )
130
111
131
112
def _traverse_emb_layers_and_save (proc_size = 1 , proc_rank = 0 ):
132
113
for var in model .variables :
@@ -135,7 +116,27 @@ def _traverse_emb_layers_and_save(proc_size=1, proc_rank=0):
135
116
if not hasattr (var .params , "_created_in_class" ):
136
117
continue
137
118
de_var = var .params
138
- _save_de_var (de_var , proc_size = proc_size , proc_rank = proc_rank )
119
+ a2a_emb = de_var ._created_in_class
120
+ if de_var ._saveable_object_creator is not None :
121
+ if not isinstance (de_var .kv_creator .saver , de .FileSystemSaver ):
122
+ # This function only serves FileSystemSaver.
123
+ return
124
+ # save optimizer parameters of Dynamic Embedding
125
+ if include_optimizer is True :
126
+ de_opt_vars = a2a_emb .optimizer_vars .as_list () if hasattr (
127
+ a2a_emb .optimizer_vars , "as_list" ) else a2a_emb .optimizer_vars
128
+ for de_opt_var in de_opt_vars :
129
+ de_opt_var .save_to_file_system (dirpath = de_dir ,
130
+ proc_size = proc_size ,
131
+ proc_rank = proc_rank )
132
+ if proc_rank == 0 :
133
+ # FileSystemSaver works well at rank 0.
134
+ return
135
+ # save Dynamic Embedding Parameters
136
+ de_var .save_to_file_system (dirpath = de_dir ,
137
+ proc_size = proc_size ,
138
+ proc_rank = proc_rank )
139
+ # Save restrict policy for each hvd.rank()
139
140
_maybe_save_restrict_policy_params (de_var , proc_size = proc_size , proc_rank = proc_rank )
140
141
141
142
if hvd is None :
0 commit comments