Skip to content

Commit b078929

Browse files
MoFHeka authored and rhdong committed
[fix] When the result of importing horovod is None, TFRA DEHvdCheckpoint would not save DE variables or sweep redundant DE files.
1 parent 6f7bbb8 commit b078929

File tree

1 file changed

+14
-10
lines changed
  • tensorflow_recommenders_addons/dynamic_embedding/python/train

1 file changed

+14
-10
lines changed

tensorflow_recommenders_addons/dynamic_embedding/python/train/checkpoint.py

Lines changed: 14 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -178,8 +178,20 @@ def _get_de_dir_from_file_path(file_path):
178178
de_dir = self._get_de_variable_folder_dir(file_path, global_step)
179179
return file_prefix_pattern, global_step, de_dir
180180

181+
def _rank0_delete_files_and_return_de_dir(file_path):
182+
file_prefix_pattern, global_step, de_dir = _get_de_dir_from_file_path(
183+
file_path)
184+
if global_step is not None:
185+
ckpt_index_list = file_io.get_matching_files(file_prefix_pattern +
186+
'-*.index')
187+
self._delete_redundant_de_dir(
188+
ckpt_index_list
189+
) # Compatible with automatic sweep function of checkpointmanager
190+
return de_dir
191+
181192
if self._hvd is None:
182193
file_path = tf_write_func()
194+
de_dir = _rank0_delete_files_and_return_de_dir(file_path)
183195
self._de_handle_root_and_var_with_func(de_dir=de_dir,
184196
func=self._de_var_fs_save_funtion)
185197
else:
@@ -189,14 +201,7 @@ def _get_de_dir_from_file_path(file_path):
189201
self._hvd.broadcast_object(file_path,
190202
root_rank=0,
191203
name='de_hvd_broadcast_file_path')
192-
file_prefix_pattern, global_step, de_dir = _get_de_dir_from_file_path(
193-
file_path)
194-
if global_step is not None:
195-
ckpt_index_list = file_io.get_matching_files(file_prefix_pattern +
196-
'-*.index')
197-
self._delete_redundant_de_dir(
198-
ckpt_index_list
199-
) # Compatible with automatic sweep function of checkpointmanager
204+
de_dir = _rank0_delete_files_and_return_de_dir(file_path)
200205
self._hvd.join() # Sync for avoiding files conflict
201206
self._de_handle_root_and_var_with_func(
202207
de_dir=de_dir, func=self._de_var_fs_save_funtion)
@@ -205,8 +210,7 @@ def _get_de_dir_from_file_path(file_path):
205210
else:
206211
file_path = self._hvd.broadcast_object(
207212
None, root_rank=0, name='de_hvd_broadcast_file_path')
208-
file_prefix_pattern, global_step, de_dir = _get_de_dir_from_file_path(
209-
file_path)
213+
_, _, de_dir = _get_de_dir_from_file_path(file_path)
210214
self._hvd.join() # Sync for avoiding files conflict
211215
self._de_handle_root_and_var_with_func(
212216
de_dir=de_dir, func=self._de_var_fs_save_funtion)

0 commit comments

Comments
 (0)