Skip to content

Commit 983d1d3

Browse files
MoFHekarhdong
authored and committed
[feat] Add the DEHvdSaver class, which is similar to tf.train.Saver and is used to save DE KV files with different rank when using horovod all2all training.
1 parent 9d2585e commit 983d1d3

File tree

2 files changed

+129
-0
lines changed

2 files changed

+129
-0
lines changed
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from tensorflow_recommenders_addons.dynamic_embedding.python.train.saver import DEHvdSaver
Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
# Copyright 2023 The TensorFlow Recommenders-Addons Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
# lint-as: python3
16+
17+
import os.path
18+
19+
from tensorflow_recommenders_addons import dynamic_embedding as de
20+
from tensorflow_recommenders_addons.dynamic_embedding.python.ops.tf_save_restore_patch import _DynamicEmbeddingSaver
21+
22+
from tensorflow.python.client import session
23+
from tensorflow.python.eager import context
24+
from tensorflow.python.framework import dtypes
25+
from tensorflow.python.framework import errors
26+
from tensorflow.python.framework import ops
27+
from tensorflow.python.ops import array_ops
28+
from tensorflow.python.platform import gfile
29+
from tensorflow.python.training import training_util
30+
from tensorflow.python.util import compat
31+
32+
33+
class DEHvdSaver(_DynamicEmbeddingSaver):
  """Saver for Horovod all2all training with dynamic embeddings.

  Rank 0 (or a process where Horovod is not usable) saves through the parent
  class. Every other rank only runs the dynamic-embedding save ops, writing
  into a ``TFRADynamicEmbedding[-<step>]`` folder next to ``save_path``.
  """

  def save(self,
           sess,
           save_path,
           global_step=None,
           latest_filename=None,
           meta_graph_suffix="meta",
           write_meta_graph=True,
           write_state=True,
           strip_default_attrs=False,
           save_debug_info=False,
           *args,
           **kwargs):
    """Save a checkpoint on rank 0, or only the DE KV files on other ranks.

    Calling the TF save API on all ranks causes file conflicts, so KV files
    of ranks other than 0 are saved by running the dynamic-embedding save
    ops directly instead of going through the parent ``save``.

    Args:
      sess: Session used to run the save ops. Outside eager execution it must
        be a ``session.SessionInterface``.
      save_path: Checkpoint path prefix. On ranks other than 0 only its parent
        directory is used, as the location of the DE folder.
      global_step: Optional step. A non-integer value is resolved with
        ``training_util.global_step(sess, global_step)``.
      latest_filename: Forwarded to the parent ``save``. On ranks other than 0
        it is only compared against the basename of ``save_path``.
      meta_graph_suffix: Forwarded to the parent ``save``.
      write_meta_graph: Forwarded to the parent ``save``.
      write_state: Forwarded to the parent ``save``.
      strip_default_attrs: Forwarded to the parent ``save``.
      save_debug_info: Forwarded to the parent ``save``.
      *args: Forwarded to the parent ``save``.
      **kwargs: Forwarded to the parent ``save``.

    Returns:
      The parent ``save`` result on rank 0 or when Horovod is unavailable;
      ``None`` on every other rank.

    Raises:
      ValueError: If ``global_step`` is None and the basename of ``save_path``
        equals ``latest_filename`` on a non-sharded saver, or if saving fails
        and the parent directory of ``save_path`` does not exist.
      TypeError: If ``sess`` is not a session outside eager execution.
    """
    # Probe Horovod: any ordinary failure (package missing, rank() not
    # callable yet) means "behave like a plain saver". Catch Exception rather
    # than using a bare except so KeyboardInterrupt/SystemExit still propagate.
    try:
      import horovod.tensorflow as hvd
      hvd.rank()
    except Exception:
      hvd = None

    if hvd is None or hvd.rank() == 0:
      return super(DEHvdSaver,
                   self).save(sess=sess,
                              save_path=save_path,
                              global_step=global_step,
                              latest_filename=latest_filename,
                              meta_graph_suffix=meta_graph_suffix,
                              write_meta_graph=write_meta_graph,
                              write_state=write_state,
                              strip_default_attrs=strip_default_attrs,
                              save_debug_info=save_debug_info,
                              *args,
                              **kwargs)

    # From here on: ranks other than 0 save only their own DE KV files.
    save_path = compat.as_str(save_path)
    if global_step is not None:
      if not isinstance(global_step, compat.integral_types):
        global_step = training_util.global_step(sess, global_step)
    elif (os.path.basename(save_path) == latest_filename and
          not self._sharded):
      # Guard against collision between data file and checkpoint state file.
      raise ValueError(
          "'latest_filename' collides with 'save_path': '%s' and '%s'" %
          (latest_filename, save_path))

    if (not context.executing_eagerly() and
        not isinstance(sess, session.SessionInterface)):
      raise TypeError("'sess' must be a Session; %s" % sess)

    save_path_parent = os.path.dirname(save_path)
    if global_step is None:
      de_folder_name = "TFRADynamicEmbedding"
    elif self._pad_step_number:
      # Zero-pads the step numbers, so that they are sorted when listed.
      de_folder_name = "TFRADynamicEmbedding-{:08d}".format(global_step)
    else:
      de_folder_name = "TFRADynamicEmbedding-{}".format(global_step)
    de_variable_folder_dir = os.path.join(save_path_parent, de_folder_name)

    if self._is_empty:
      return None

    try:
      if context.executing_eagerly():
        # NOTE(review): this branch only (re)builds the save-dir placeholder
        # and the DE save ops; it does not run them, and placeholders are
        # documented as incompatible with eager execution. Confirm that this
        # path is intended and reachable.
        with ops.name_scope("FileSystemSaver", "save_to_file_system", []):
          self._de_var_fs_save_dir = array_ops.placeholder(
              dtype=dtypes.string,
              shape=(),
              name="de_var_file_system_save_dir")
          self._de_save_ops = self._get_dynamic_embedding_save_ops()
      else:
        sess.run(self._de_save_ops,
                 {self._de_var_fs_save_dir: de_variable_folder_dir})
    except (errors.FailedPreconditionError, errors.NotFoundError) as exc:
      if not gfile.IsDirectory(save_path_parent):
        # Report the missing directory, keeping the TF error as the cause.
        raise ValueError(
            "Parent directory of {} doesn't exist, can't save.".format(
                save_path)) from exc
      raise
    return None

0 commit comments

Comments
 (0)