Skip to content

Commit bbab571

Browse files
PWZERrhdong
authored and committed
Fix: de.save_to_file_system not eager mode
1 parent 4e21368 commit bbab571

File tree

2 files changed

+90
-76
lines changed

2 files changed

+90
-76
lines changed

tensorflow_recommenders_addons/dynamic_embedding/python/ops/dynamic_embedding_variable.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -195,7 +195,7 @@ def _insert_de_shard_from_file_system(
195195
Returns:
196196
traverse_files_result: A tensor from loop result, return False if success.
197197
"""
198-
control_flow_ops.Assert(
198+
check_size_op = control_flow_ops.Assert(
199199
math_ops.equal(array_ops.size(shard_keys_file_list),
200200
array_ops.size(shard_values_file_list)),
201201
[
@@ -228,6 +228,7 @@ def _insert_de_shard_from_file_system(
228228
drop_remainder=False)
229229

230230
iterator_init_list = tf_utils.ListWrapper([])
231+
iterator_init_list.as_list().append(check_size_op)
231232
if context.executing_eagerly():
232233
keys_tensor_iterator = iter(_keys_tensor_dataset)
233234
values_tensor_iterator = iter(_values_tensor_dataset)

tensorflow_recommenders_addons/dynamic_embedding/python/ops/tf_save_restore_patch.py

Lines changed: 88 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,70 @@ def restore(self, file_prefix, options=None):
172172

173173
class _DynamicEmbeddingSaver(saver.Saver):
174174

175+
def _iter_file_system_saved_de_variables(self):
  """Yield the dynamic-embedding variables handled by a `de.FileSystemSaver`.

  Walks `self._var_list` and, for every entry that is either a file-system
  saveable wrapper or a `de.Variable` with a saveable-object creator, resolves
  the underlying variable and the directory it should be saved to or restored
  from.

  Yields:
    Tuples `(de_var, config, dirpath)` where `config` is the
    `FileSystemSaver` config of `de_var` and `dirpath` is the config's
    `save_path` when set, otherwise `self._de_var_fs_save_dir`.
  """
  saveable_types = (
      de.FileSystemSaver._DynamicEmbeddingShardFileSystemSaveable,
      de.FileSystemSaver._DynamicEmbeddingVariabelFileSystemSaveable,
  )
  # `_var_list` may be empty or None; in both cases nothing is yielded.
  for var in self._var_list or ():
    if isinstance(var, saveable_types):
      de_var = var._de_variable
    elif isinstance(var, de.Variable) and var._saveable_object_creator:
      de_var = var
    else:
      continue

    if de_var is None:
      continue
    creator = de_var._saveable_object_creator
    if not isinstance(creator, de.FileSystemSaver):
      continue

    config = creator.config
    # An explicit per-variable save path takes precedence over the shared
    # directory tensor.
    dirpath = config.save_path or self._de_var_fs_save_dir
    yield de_var, config, dirpath

def _get_dynamic_embedding_save_ops(self):
  """Build one grouped op that saves every file-system-backed variable.

  Returns:
    The result of `control_flow_ops.group` over the per-variable
    `save_to_file_system` ops. The same kind of object is returned even when
    there is nothing to save, so callers can always pass it to `sess.run`.
    NOTE(review): relies on `group` accepting an empty list and producing a
    no-op - confirm against the TensorFlow version in use.
  """
  save_ops = [
      de_var.save_to_file_system(dirpath=dirpath,
                                 proc_size=config.proc_size,
                                 proc_rank=config.proc_rank,
                                 buffer_size=config.buffer_size)
      for de_var, config, dirpath in self._iter_file_system_saved_de_variables()
  ]
  return control_flow_ops.group(save_ops)

def _get_dynamic_embedding_restore_ops(self):
  """Build one grouped op that restores every file-system-backed variable.

  Returns:
    The result of `control_flow_ops.group` over the per-variable
    `load_from_file_system_with_restore_function` ops. The same kind of
    object is returned even when there is nothing to restore, so callers can
    always pass it to `sess.run`.
    NOTE(review): relies on `group` accepting an empty list and producing a
    no-op - confirm against the TensorFlow version in use.
  """
  restore_ops = [
      de_var.load_from_file_system_with_restore_function(
          dirpath=dirpath,
          proc_size=config.proc_size,
          proc_rank=config.proc_rank,
          buffer_size=config.buffer_size)
      for de_var, config, dirpath in self._iter_file_system_saved_de_variables()
  ]
  return control_flow_ops.group(restore_ops)
228+
229+
def _build(self, checkpoint_path, build_save, build_restore):
  """Build the base saver, then the dynamic-embedding file-system ops.

  After delegating to the parent `_build`, this creates a scalar string
  placeholder for the dynamic-embedding directory and builds the grouped
  save and restore ops that read it. `save()` and `restore()` feed the
  placeholder through `self._de_var_fs_save_dir` when they run
  `self._de_save_ops` / `self._de_restore_ops`.

  Args:
    checkpoint_path: Forwarded unchanged to the parent `_build`.
    build_save: Forwarded unchanged to the parent `_build`.
    build_restore: Forwarded unchanged to the parent `_build`.
  """
  super(_DynamicEmbeddingSaver, self)._build(checkpoint_path, build_save,
                                             build_restore)

  # Both op groups are built regardless of `build_save` / `build_restore`,
  # so the attributes always exist once `_build` has run.
  with ops.name_scope("FileSystemSaver", "save_to_file_system", []):
    self._de_var_fs_save_dir = array_ops.placeholder(
        dtype=dtypes.string, shape=(), name="de_var_file_system_save_dir")
    self._de_save_ops = self._get_dynamic_embedding_save_ops()
    self._de_restore_ops = self._get_dynamic_embedding_restore_ops()
238+
175239
def save(self,
176240
sess,
177241
save_path,
@@ -271,52 +335,25 @@ def save(self,
271335

272336
save_path_parent = os.path.dirname(save_path)
273337

274-
def _get_save_ops_list():
275-
save_ops = tf_utils.ListWrapper([])
276-
if self._var_list:
277-
for var in self._var_list:
278-
if isinstance(var, de.Variable):
279-
if var._saveable_object_creator:
280-
if type(
281-
var._saveable_object_creator).__name__ == 'FileSystemSaver':
282-
if var._saveable_object_creator.config.save_path:
283-
de_variable_folder_dir = var._saveable_object_creator.config.save_path
284-
elif global_step is not None:
285-
de_variable_folder_dir = "TFRADynamicEmbedding-%d" % (
286-
save_path_parent, global_step)
287-
if self._pad_step_number:
288-
# Zero-pads the step numbers, so that they are sorted when listed.
289-
de_variable_folder_dir = "TFRADynamicEmbedding-%s" % (
290-
save_path_parent, "{:08d}".format(global_step))
291-
else:
292-
de_variable_folder_dir = os.path.join(save_path_parent,
293-
'TFRADynamicEmbedding')
294-
proc_size = var._saveable_object_creator.config.proc_size
295-
proc_rank = var._saveable_object_creator.config.proc_rank
296-
buffer_size = var._saveable_object_creator.config.buffer_size
297-
save_ops.as_list().append(
298-
var.save_to_file_system(dirpath=de_variable_folder_dir,
299-
proc_size=proc_size,
300-
proc_rank=proc_rank,
301-
buffer_size=buffer_size))
302-
return save_ops
338+
if global_step is not None:
339+
de_variable_folder_dir = os.path.join(
340+
save_path_parent, "TFRADynamicEmbedding-{}".format(global_step))
341+
if self._pad_step_number:
342+
# Zero-pads the step numbers, so that they are sorted when listed.
343+
de_variable_folder_dir = os.path.join(
344+
save_path_parent, "TFRADynamicEmbedding-{:08d}".format(global_step))
345+
else:
346+
de_variable_folder_dir = os.path.join(
347+
save_path_parent, "TFRADynamicEmbedding")
303348

304349
if not self._is_empty:
305350
try:
306-
if context.executing_eagerly():
307-
self._build_eager(checkpoint_file,
308-
build_save=True,
309-
build_restore=False)
310-
model_checkpoint_path = self.saver_def.save_tensor_name
311-
save_ops = _get_save_ops_list().as_list()
312-
else:
351+
if not context.executing_eagerly():
313352
model_checkpoint_path = sess.run(
314353
self.saver_def.save_tensor_name,
315354
{self.saver_def.filename_tensor_name: checkpoint_file})
316-
save_ops_list = _get_save_ops_list()
317-
if save_ops_list.as_list():
318-
for save_op in save_ops_list.as_list():
319-
sess.run(save_op)
355+
sess.run(self._de_save_ops,
356+
{self._de_var_fs_save_dir: de_variable_folder_dir})
320357

321358
model_checkpoint_path = compat.as_str(model_checkpoint_path)
322359
if write_state:
@@ -380,45 +417,21 @@ def restore(self, sess, save_path):
380417
tf_logging.info("Restoring parameters from %s", checkpoint_prefix)
381418
save_path_parent = os.path.dirname(save_path)
382419

383-
def _get_restore_ops_list():
384-
restore_ops = tf_utils.ListWrapper([])
385-
if self._var_list:
386-
for var in self._var_list:
387-
if isinstance(var, de.Variable):
388-
if var._saveable_object_creator:
389-
if type(
390-
var._saveable_object_creator).__name__ == 'FileSystemSaver':
391-
maybe_global_step = (os.path.basename(save_path)).split('-')[-1]
392-
matched_de_dir = os.path.join(
393-
save_path_parent,
394-
"TFRADynamicEmbedding-" + maybe_global_step)
395-
if var._saveable_object_creator.config.save_path:
396-
de_variable_folder_dir = var._saveable_object_creator.config.save_path
397-
elif os.path.exists(matched_de_dir):
398-
de_variable_folder_dir = matched_de_dir
399-
else:
400-
de_variable_folder_dir = os.path.join(save_path_parent,
401-
'TFRADynamicEmbedding')
402-
proc_rank = var._saveable_object_creator.config.proc_rank
403-
proc_size = var._saveable_object_creator.config.proc_size
404-
buffer_size = var._saveable_object_creator.config.buffer_size
405-
restore_ops.as_list().append(
406-
var.load_from_file_system_with_restore_function(
407-
de_variable_folder_dir, proc_size, proc_rank,
408-
buffer_size))
409-
return restore_ops
420+
maybe_global_step = os.path.basename(save_path).split('-')[-1]
421+
matched_de_dir = os.path.join(save_path_parent,
422+
'TFRADynamicEmbedding-' + maybe_global_step)
423+
if os.path.exists(matched_de_dir):
424+
de_variable_folder_dir = matched_de_dir
425+
else:
426+
de_variable_folder_dir = os.path.join(save_path_parent,
427+
'TFRADynamicEmbedding')
410428

411429
try:
412-
if context.executing_eagerly():
413-
self._build_eager(save_path, build_save=False, build_restore=True)
414-
restore_ops = _get_restore_ops_list().as_list()
415-
else:
430+
if not context.executing_eagerly():
416431
sess.run(self.saver_def.restore_op_name,
417432
{self.saver_def.filename_tensor_name: save_path})
418-
restore_ops_list = _get_restore_ops_list()
419-
if restore_ops_list.as_list():
420-
for restore_op in restore_ops_list.as_list():
421-
sess.run(restore_op)
433+
sess.run(self._de_restore_ops,
434+
{self._de_var_fs_save_dir: de_variable_folder_dir})
422435
except errors.NotFoundError as err:
423436
# There are three common conditions that might cause this error:
424437
# 0. The file is missing. We ignore here, as this is checked above.

0 commit comments

Comments
 (0)