@@ -172,6 +172,70 @@ def restore(self, file_prefix, options=None):
172
172
173
173
class _DynamicEmbeddingSaver (saver .Saver ):
174
174
175
+ def _get_dynamic_embedding_save_ops (self ):
176
+ save_ops = tf_utils .ListWrapper ([])
177
+ if not self ._var_list :
178
+ return save_ops
179
+
180
+ for var in self ._var_list :
181
+ de_var = None
182
+ if isinstance (var , (de .FileSystemSaver ._DynamicEmbeddingShardFileSystemSaveable ,
183
+ de .FileSystemSaver ._DynamicEmbeddingVariabelFileSystemSaveable )):
184
+ de_var = var ._de_variable
185
+ elif isinstance (var , de .Variable ) and var ._saveable_object_creator :
186
+ de_var = var
187
+
188
+ if de_var and isinstance (de_var ._saveable_object_creator , de .FileSystemSaver ):
189
+ if de_var ._saveable_object_creator .config .save_path :
190
+ de_variable_folder_dir = de_var ._saveable_object_creator .config .save_path
191
+ else :
192
+ de_variable_folder_dir = self ._de_var_fs_save_dir
193
+
194
+ save_op = de_var .save_to_file_system (
195
+ dirpath = de_variable_folder_dir ,
196
+ proc_size = de_var ._saveable_object_creator .config .proc_size ,
197
+ proc_rank = de_var ._saveable_object_creator .config .proc_rank ,
198
+ buffer_size = de_var ._saveable_object_creator .config .buffer_size )
199
+ save_ops .as_list ().append (save_op )
200
+ return control_flow_ops .group (save_ops .as_list ())
201
+
202
+ def _get_dynamic_embedding_restore_ops (self ):
203
+ restore_ops = tf_utils .ListWrapper ([])
204
+ if not self ._var_list :
205
+ return restore_ops
206
+
207
+ for var in self ._var_list :
208
+ de_var = None
209
+ if isinstance (var , (de .FileSystemSaver ._DynamicEmbeddingShardFileSystemSaveable ,
210
+ de .FileSystemSaver ._DynamicEmbeddingVariabelFileSystemSaveable )):
211
+ de_var = var ._de_variable
212
+ elif isinstance (var , de .Variable ) and var ._saveable_object_creator :
213
+ de_var = var
214
+
215
+ if de_var and isinstance (de_var ._saveable_object_creator , de .FileSystemSaver ):
216
+ if de_var ._saveable_object_creator .config .save_path :
217
+ de_variable_folder_dir = de_var ._saveable_object_creator .config .save_path
218
+ else :
219
+ de_variable_folder_dir = self ._de_var_fs_save_dir
220
+
221
+ restore_op = de_var .load_from_file_system_with_restore_function (
222
+ dirpath = de_variable_folder_dir ,
223
+ proc_size = de_var ._saveable_object_creator .config .proc_size ,
224
+ proc_rank = de_var ._saveable_object_creator .config .proc_rank ,
225
+ buffer_size = de_var ._saveable_object_creator .config .buffer_size )
226
+ restore_ops .as_list ().append (restore_op )
227
+ return control_flow_ops .group (restore_ops .as_list ())
228
+
229
+ def _build (self , checkpoint_path , build_save , build_restore ):
230
+ super (_DynamicEmbeddingSaver , self )._build (
231
+ checkpoint_path , build_save , build_restore )
232
+
233
+ with ops .name_scope ("FileSystemSaver" , "save_to_file_system" , []) as name :
234
+ self ._de_var_fs_save_dir = array_ops .placeholder (
235
+ dtype = dtypes .string , shape = (), name = "de_var_file_system_save_dir" )
236
+ self ._de_save_ops = self ._get_dynamic_embedding_save_ops ()
237
+ self ._de_restore_ops = self ._get_dynamic_embedding_restore_ops ()
238
+
175
239
def save (self ,
176
240
sess ,
177
241
save_path ,
@@ -271,52 +335,25 @@ def save(self,
271
335
272
336
save_path_parent = os .path .dirname (save_path )
273
337
274
- def _get_save_ops_list ():
275
- save_ops = tf_utils .ListWrapper ([])
276
- if self ._var_list :
277
- for var in self ._var_list :
278
- if isinstance (var , de .Variable ):
279
- if var ._saveable_object_creator :
280
- if type (
281
- var ._saveable_object_creator ).__name__ == 'FileSystemSaver' :
282
- if var ._saveable_object_creator .config .save_path :
283
- de_variable_folder_dir = var ._saveable_object_creator .config .save_path
284
- elif global_step is not None :
285
- de_variable_folder_dir = "TFRADynamicEmbedding-%d" % (
286
- save_path_parent , global_step )
287
- if self ._pad_step_number :
288
- # Zero-pads the step numbers, so that they are sorted when listed.
289
- de_variable_folder_dir = "TFRADynamicEmbedding-%s" % (
290
- save_path_parent , "{:08d}" .format (global_step ))
291
- else :
292
- de_variable_folder_dir = os .path .join (save_path_parent ,
293
- 'TFRADynamicEmbedding' )
294
- proc_size = var ._saveable_object_creator .config .proc_size
295
- proc_rank = var ._saveable_object_creator .config .proc_rank
296
- buffer_size = var ._saveable_object_creator .config .buffer_size
297
- save_ops .as_list ().append (
298
- var .save_to_file_system (dirpath = de_variable_folder_dir ,
299
- proc_size = proc_size ,
300
- proc_rank = proc_rank ,
301
- buffer_size = buffer_size ))
302
- return save_ops
338
+ if global_step is not None :
339
+ de_variable_folder_dir = os .path .join (
340
+ save_path_parent , "TFRADynamicEmbedding-{}" .format (global_step ))
341
+ if self ._pad_step_number :
342
+ # Zero-pads the step numbers, so that they are sorted when listed.
343
+ de_variable_folder_dir = os .path .join (
344
+ save_path_parent , "TFRADynamicEmbedding-{:08d}" .format (global_step ))
345
+ else :
346
+ de_variable_folder_dir = os .path .join (
347
+ save_path_parent , "TFRADynamicEmbedding" )
303
348
304
349
if not self ._is_empty :
305
350
try :
306
- if context .executing_eagerly ():
307
- self ._build_eager (checkpoint_file ,
308
- build_save = True ,
309
- build_restore = False )
310
- model_checkpoint_path = self .saver_def .save_tensor_name
311
- save_ops = _get_save_ops_list ().as_list ()
312
- else :
351
+ if not context .executing_eagerly ():
313
352
model_checkpoint_path = sess .run (
314
353
self .saver_def .save_tensor_name ,
315
354
{self .saver_def .filename_tensor_name : checkpoint_file })
316
- save_ops_list = _get_save_ops_list ()
317
- if save_ops_list .as_list ():
318
- for save_op in save_ops_list .as_list ():
319
- sess .run (save_op )
355
+ sess .run (self ._de_save_ops ,
356
+ {self ._de_var_fs_save_dir : de_variable_folder_dir })
320
357
321
358
model_checkpoint_path = compat .as_str (model_checkpoint_path )
322
359
if write_state :
@@ -380,45 +417,21 @@ def restore(self, sess, save_path):
380
417
tf_logging .info ("Restoring parameters from %s" , checkpoint_prefix )
381
418
save_path_parent = os .path .dirname (save_path )
382
419
383
- def _get_restore_ops_list ():
384
- restore_ops = tf_utils .ListWrapper ([])
385
- if self ._var_list :
386
- for var in self ._var_list :
387
- if isinstance (var , de .Variable ):
388
- if var ._saveable_object_creator :
389
- if type (
390
- var ._saveable_object_creator ).__name__ == 'FileSystemSaver' :
391
- maybe_global_step = (os .path .basename (save_path )).split ('-' )[- 1 ]
392
- matched_de_dir = os .path .join (
393
- save_path_parent ,
394
- "TFRADynamicEmbedding-" + maybe_global_step )
395
- if var ._saveable_object_creator .config .save_path :
396
- de_variable_folder_dir = var ._saveable_object_creator .config .save_path
397
- elif os .path .exists (matched_de_dir ):
398
- de_variable_folder_dir = matched_de_dir
399
- else :
400
- de_variable_folder_dir = os .path .join (save_path_parent ,
401
- 'TFRADynamicEmbedding' )
402
- proc_rank = var ._saveable_object_creator .config .proc_rank
403
- proc_size = var ._saveable_object_creator .config .proc_size
404
- buffer_size = var ._saveable_object_creator .config .buffer_size
405
- restore_ops .as_list ().append (
406
- var .load_from_file_system_with_restore_function (
407
- de_variable_folder_dir , proc_size , proc_rank ,
408
- buffer_size ))
409
- return restore_ops
420
+ maybe_global_step = os .path .basename (save_path ).split ('-' )[- 1 ]
421
+ matched_de_dir = os .path .join (save_path_parent ,
422
+ 'TFRADynamicEmbedding-' + maybe_global_step )
423
+ if os .path .exists (matched_de_dir ):
424
+ de_variable_folder_dir = matched_de_dir
425
+ else :
426
+ de_variable_folder_dir = os .path .join (save_path_parent ,
427
+ 'TFRADynamicEmbedding' )
410
428
411
429
try :
412
- if context .executing_eagerly ():
413
- self ._build_eager (save_path , build_save = False , build_restore = True )
414
- restore_ops = _get_restore_ops_list ().as_list ()
415
- else :
430
+ if not context .executing_eagerly ():
416
431
sess .run (self .saver_def .restore_op_name ,
417
432
{self .saver_def .filename_tensor_name : save_path })
418
- restore_ops_list = _get_restore_ops_list ()
419
- if restore_ops_list .as_list ():
420
- for restore_op in restore_ops_list .as_list ():
421
- sess .run (restore_op )
433
+ sess .run (self ._de_restore_ops ,
434
+ {self ._de_var_fs_save_dir : de_variable_folder_dir })
422
435
except errors .NotFoundError as err :
423
436
# There are three common conditions that might cause this error:
424
437
# 0. The file is missing. We ignore here, as this is checked above.
0 commit comments