@@ -98,34 +98,43 @@ def _check_saveable_and_redirect_new_de_dir(hvd_rank=0):
98
98
# Redirect new de_dir
99
99
if hasattr (de_var , 'saveable' ):
100
100
de_var .saveable ._saver_config .save_path = de_dir
101
-
101
+
102
+ def _save_de_var (de_var , proc_size = 1 , proc_rank = 0 ):
103
+ a2a_emb = de_var ._created_in_class
104
+ if de_var ._saveable_object_creator is not None :
105
+ if not isinstance (de_var .kv_creator .saver , de .FileSystemSaver ):
106
+ # This function only serves FileSystemSaver.
107
+ return
108
+ # save optimizer parameters of Dynamic Embedding
109
+ if include_optimizer is True :
110
+ de_opt_vars = a2a_emb .optimizer_vars .as_list () if hasattr (
111
+ a2a_emb .optimizer_vars , "as_list" ) else a2a_emb .optimizer_vars
112
+ for de_opt_var in de_opt_vars :
113
+ de_opt_var .save_to_file_system (dirpath = de_dir ,
114
+ proc_size = proc_size ,
115
+ proc_rank = proc_rank )
116
+ if proc_rank == 0 :
117
+ # FileSystemSaver works well at rank 0.
118
+ return
119
+ # save Dynamic Embedding Parameters
120
+ de_var .save_to_file_system (dirpath = de_dir ,
121
+ proc_size = proc_size ,
122
+ proc_rank = proc_rank )
123
+
124
+ def _maybe_save_restrict_policy_params (var , proc_size = 1 , proc_rank = 0 ):
125
+ if var .restrict_policy is not None :
126
+ de_var = var .restrict_policy ._restrict_var
127
+ _save_de_var (de_var , proc_size = proc_size , proc_rank = proc_rank )
128
+
102
129
def _traverse_emb_layers_and_save (proc_size = 1 , proc_rank = 0 ):
103
130
for var in model .variables :
104
131
if not hasattr (var , "params" ):
105
132
continue
106
133
if not hasattr (var .params , "_created_in_class" ):
107
134
continue
108
135
de_var = var .params
109
- a2a_emb = de_var ._created_in_class
110
- if de_var ._saveable_object_creator is not None :
111
- if not isinstance (de_var .kv_creator .saver , de .FileSystemSaver ):
112
- # This function only serves FileSystemSaver.
113
- continue
114
- # save optimizer parameters of Dynamic Embedding
115
- if include_optimizer is True :
116
- de_opt_vars = a2a_emb .optimizer_vars .as_list () if hasattr (
117
- a2a_emb .optimizer_vars , "as_list" ) else a2a_emb .optimizer_vars
118
- for de_opt_var in de_opt_vars :
119
- de_opt_var .save_to_file_system (dirpath = de_dir ,
120
- proc_size = proc_size ,
121
- proc_rank = proc_rank )
122
- if proc_rank == 0 :
123
- # FileSystemSaver works well at rank 0.
124
- continue
125
- # save Dynamic Embedding Parameters
126
- de_var .save_to_file_system (dirpath = de_dir ,
127
- proc_size = proc_size ,
128
- proc_rank = proc_rank )
136
+ _save_de_var (de_var , proc_size = proc_size , proc_rank = proc_rank )
137
+ _maybe_save_restrict_policy_params (var , proc_size = proc_size , proc_rank = proc_rank )
129
138
130
139
if hvd is None :
131
140
call_original_save_func ()
0 commit comments