
Commit 6f7bbb8

MoFHeka authored and rhdong committed
[feat] Add support for tf.train.Checkpoint and tf.train.CheckpointManager when using HvdAllToAllEmbedding, via de.train.DEHvdCheckpoint.
1 parent 0336e59 commit 6f7bbb8
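
For orientation, a minimal usage sketch of the new API (assuming a Horovod-initialized program and a Keras model `model` containing a de.keras.layers.HvdAllToAllEmbedding layer backed by a FileSystemSaver kv_creator; `model` and the paths below are illustrative, not part of this commit):

import tensorflow as tf
import horovod.tensorflow as hvd
from tensorflow_recommenders_addons import dynamic_embedding as de

hvd.init()

# DEHvdCheckpoint subclasses tf.train.Checkpoint, so it also works with
# tf.train.CheckpointManager: rank 0 writes the regular TF checkpoint files,
# while every rank writes its own KV shards into a sibling
# TFRADynamicEmbedding directory.
ckpt = de.train.DEHvdCheckpoint(model)
manager = tf.train.CheckpointManager(ckpt, '/tmp/ckpt', max_to_keep=3)
manager.save()

# Restore runs on every rank; each rank reloads its own KV shards.
ckpt.restore(manager.latest_checkpoint)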

File tree

5 files changed: +347 -0 lines changed


tensorflow_recommenders_addons/dynamic_embedding/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -39,12 +39,14 @@
     'enable_train_mode',
     'get_model_mode',
     'trainable_wrapper_filter',
+    'train',
     'keras',
     'math',
     'data_flow',
     'shadow_ops',
 ]
 
+from tensorflow_recommenders_addons.dynamic_embedding.python import train
 from tensorflow_recommenders_addons.dynamic_embedding.python import keras
 from tensorflow_recommenders_addons.dynamic_embedding.python.ops import math_ops as math
 from tensorflow_recommenders_addons.dynamic_embedding.python.ops import data_flow_ops as data_flow
tensorflow_recommenders_addons/dynamic_embedding/python/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -1 +1,2 @@
 from tensorflow_recommenders_addons.dynamic_embedding.python import keras
+from tensorflow_recommenders_addons.dynamic_embedding.python import train

tensorflow_recommenders_addons/dynamic_embedding/python/kernel_tests/horovod_sync_train_test.py

Lines changed: 10 additions & 0 deletions
@@ -326,16 +326,26 @@ def common_all_to_all_embedding_trainable_v2(self, base_opt, test_opt, name):
     de.keras.models.de_hvd_save_model(base_model,
                                       save_dir,
                                       options=save_options)
+    ckpt = de.train.DEHvdCheckpoint(base_model)
+    ckpt.save(save_dir + '/ckpt/test')
+    tf.keras.backend.clear_session()
     del base_model
     new_base_model = get_emb_sequential_model(
         de.keras.layers.HvdAllToAllEmbedding,
         base_opt,
+        dense_init='ones',
         embedding_size=dim,
         initializer=init,
         bp_v2=False,
         kv_creator=kv_creator,
         name='all2all_emb')
+    ckpt = de.train.DEHvdCheckpoint(new_base_model)
     hvd.join()  # Sync for avoiding files conflict
+    ckpt.restore(tf.train.latest_checkpoint(save_dir + '/ckpt/'))
+    new_a2aemb_size = new_base_model.layers[0].params.size()
+    self.assertEqual(a2aemb_size, new_a2aemb_size)
+    hvd.join()  # Sync for avoiding files conflict
+    tf.keras.backend.clear_session()
     new_base_model.load_weights(save_dir + '/variables/variables')
     new_a2aemb_size = new_base_model.layers[0].params.size()
     self.assertEqual(a2aemb_size, new_a2aemb_size)
tensorflow_recommenders_addons/dynamic_embedding/python/train/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -1 +1,2 @@
 from tensorflow_recommenders_addons.dynamic_embedding.python.train.saver import DEHvdSaver
+from tensorflow_recommenders_addons.dynamic_embedding.python.train.checkpoint import DEHvdCheckpoint
tensorflow_recommenders_addons/dynamic_embedding/python/train/checkpoint.py

Lines changed: 333 additions & 0 deletions
@@ -0,0 +1,333 @@
# Copyright 2023 The TensorFlow Recommenders-Addons Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# lint-as: python3

import os.path
import re

from tensorflow_recommenders_addons import dynamic_embedding as de
from tensorflow_recommenders_addons.dynamic_embedding.python.keras.layers import HvdAllToAllEmbedding
from tensorflow_recommenders_addons.dynamic_embedding.python.ops.dynamic_embedding_ops import TrainableWrapper, DEResourceVariable

from tensorflow.python.framework import constant_op
try:
  from tensorflow.python.checkpoint.checkpoint import Checkpoint
except:
  from tensorflow.python.training.tracking.util import Checkpoint
from tensorflow.python.lib.io import file_io
from tensorflow.python.platform import tf_logging

class DEHvdCheckpoint(Checkpoint):
  """Overrides the tf.train.Checkpoint class.

  Calling the TF save API on all ranks causes file conflicts, so the KV files
  of ranks other than rank 0 need to be saved by calling the underlying API
  separately. This is a convenience class for saving HvdAllToAllEmbedding to
  KV files on each rank.
  """

  def __init__(self, root=None, **kwargs):
    """Creates a training checkpoint for a single or group of objects.

    Args:
      root: The root object to checkpoint. `root` may be a trackable object or
        `WeakRef` of a trackable object.
      **kwargs: Keyword arguments are set as attributes of this object, and are
        saved with the checkpoint. All `kwargs` must be trackable objects, or a
        nested structure of trackable objects (`list`, `dict`, or `tuple`).

    Raises:
      ValueError: If `root` or the objects in `kwargs` are not trackable. A
        `ValueError` is also raised if the `root` object tracks different
        objects from the ones listed in attributes in kwargs (e.g.
        `root.child = A` and `tf.train.Checkpoint(root, child=B)` are
        incompatible).
    """
    try:
      import horovod.tensorflow as hvd
      try:
        hvd.rank()
        self._hvd = hvd
      except:
        self._hvd = None
    except:
      self._hvd = None

    self._tmp_var_key_set = set({})
    for k, _ in sorted(kwargs.items(), key=lambda item: item[0]):
      self._tmp_var_key_set.add(k)
    super(DEHvdCheckpoint, self).__init__(root, **kwargs)

  def _get_de_variable_folder_dir(self,
                                  save_path: str,
                                  global_step: str = None):
    save_path_parent = os.path.dirname(save_path)
    if global_step is not None:
      de_variable_folder_dir = os.path.join(
          save_path_parent, "TFRADynamicEmbedding-{}".format(global_step))
    else:
      de_variable_folder_dir = os.path.join(save_path_parent,
                                            "TFRADynamicEmbedding")
    return de_variable_folder_dir

  def _delete_redundant_de_dir(self, ckpt_index_list: list):
    if not len(ckpt_index_list) > 0:
      return
    save_path_parent = os.path.dirname(ckpt_index_list[0])
    de_dir_pattern = os.path.join(save_path_parent, "TFRADynamicEmbedding-*")
    found_de_dir_set = set(file_io.get_matching_files(de_dir_pattern))
    keep_de_dir_set = set([])
    for file_path in ckpt_index_list:
      global_step = file_path.split('.index')[-2].split('-')[-1]
      de_dir = os.path.join(save_path_parent,
                            "TFRADynamicEmbedding-{}".format(global_step))
      keep_de_dir_set.add(de_dir)
    delete_de_dir_set = found_de_dir_set - keep_de_dir_set
    for de_dir in delete_de_dir_set:
      if file_io.is_directory(de_dir):
        file_io.delete_recursively(de_dir)

  def _de_var_fs_save_funtion(self, de_var, de_dir: str):
    a2a_emb = de_var._created_in_class
    hvd_size = 1 if self._hvd is None else self._hvd.size()
    hvd_rank = 0 if self._hvd is None else self._hvd.rank()
    if issubclass(a2a_emb.__class__, HvdAllToAllEmbedding):
      if de_var._saveable_object_creator is None:
        tf_logging.warning(
            "Please use FileSystemSaver when using HvdAllToAllEmbedding. "
            "It allows TFRA to load KV files when the embedding is tensor-parallel. "
            f"The embedding shards at each Horovod rank are now temporarily stored in {de_dir}"
        )
      else:
        # save Dynamic Embedding Parameters
        de_var.save_to_file_system(dirpath=de_dir,
                                   proc_size=hvd_size,
                                   proc_rank=hvd_rank)
        # save optimizer parameters of Dynamic Embedding
        de_opt_vars = a2a_emb.optimizer_vars.as_list() if hasattr(
            a2a_emb.optimizer_vars, "as_list") else a2a_emb.optimizer_vars
        for de_opt_var in de_opt_vars:
          de_opt_var.save_to_file_system(dirpath=de_dir,
                                         proc_size=hvd_size,
                                         proc_rank=hvd_rank)

  def _de_var_fs_restore_funtion(self, de_var, de_dir: str):
    a2a_emb = de_var._created_in_class
    hvd_size = 1 if self._hvd is None else self._hvd.size()
    hvd_rank = 0 if self._hvd is None else self._hvd.rank()
    if issubclass(a2a_emb.__class__, HvdAllToAllEmbedding):
      if de_var._saveable_object_creator is None:
        tf_logging.warning(
            "Please use FileSystemSaver when using HvdAllToAllEmbedding. "
            "It allows TFRA to load KV files when the embedding is tensor-parallel. "
            f"The embedding shards at each Horovod rank are now temporarily stored in {de_dir}"
        )
      else:
        # restore Dynamic Embedding Parameters
        de_var.load_from_file_system_with_restore_function(dirpath=de_dir,
                                                           proc_size=hvd_size,
                                                           proc_rank=hvd_rank)
        # restore optimizer parameters of Dynamic Embedding
        de_opt_vars = a2a_emb.optimizer_vars.as_list() if hasattr(
            a2a_emb.optimizer_vars, "as_list") else a2a_emb.optimizer_vars
        for de_opt_var in de_opt_vars:
          de_opt_var.load_from_file_system_with_restore_function(
              dirpath=de_dir, proc_size=hvd_size, proc_rank=hvd_rank)

  def _de_handle_root_and_var_with_func(self, de_dir: str, func):

    def _filter_de_hvd_a2a_tw(var):
      if not hasattr(var, "params") or not isinstance(var, TrainableWrapper):
        return False
      if not hasattr(var.params, "_created_in_class"):
        return False
      return True

    if _filter_de_hvd_a2a_tw(self.root):
      func(self.root.params, de_dir)
    if hasattr(self.root, 'variables'):
      for var in self.root.variables:
        if _filter_de_hvd_a2a_tw(var):
          func(var.params, de_dir)
    if len(self._tmp_var_key_set):
      for var_key in self._tmp_var_key_set:
        var = getattr(self, var_key)
        if _filter_de_hvd_a2a_tw(var):
          func(var.params, de_dir)

  def _de_hvd_write_fs_func(self, file_prefix, tf_write_func):

    def _get_de_dir_from_file_path(file_path):
      file_prefix_split = file_path.split('-')
      file_prefix_pattern = ''.join(file_prefix_split[0:-1])
      global_step = file_prefix_split[-1]
      if not global_step.isdigit():
        global_step = None
      de_dir = self._get_de_variable_folder_dir(file_path, global_step)
      return file_prefix_pattern, global_step, de_dir

    if self._hvd is None:
      file_path = tf_write_func()
      file_prefix_pattern, global_step, de_dir = _get_de_dir_from_file_path(
          file_path)
      self._de_handle_root_and_var_with_func(de_dir=de_dir,
                                             func=self._de_var_fs_save_funtion)
    else:
      file_path = ''
      if self._hvd.rank() == 0:
        file_path = tf_write_func()
        self._hvd.broadcast_object(file_path,
                                   root_rank=0,
                                   name='de_hvd_broadcast_file_path')
        file_prefix_pattern, global_step, de_dir = _get_de_dir_from_file_path(
            file_path)
        if global_step is not None:
          ckpt_index_list = file_io.get_matching_files(file_prefix_pattern +
                                                       '-*.index')
          self._delete_redundant_de_dir(
              ckpt_index_list
          )  # Compatible with the automatic sweep function of CheckpointManager
        self._hvd.join()  # Sync for avoiding files conflict
        self._de_handle_root_and_var_with_func(
            de_dir=de_dir, func=self._de_var_fs_save_funtion)
        self._hvd.join(
        )  # Sync for avoiding files conflict and ranks finishing early
      else:
        file_path = self._hvd.broadcast_object(
            None, root_rank=0, name='de_hvd_broadcast_file_path')
        file_prefix_pattern, global_step, de_dir = _get_de_dir_from_file_path(
            file_path)
        self._hvd.join()  # Sync for avoiding files conflict
        self._de_handle_root_and_var_with_func(
            de_dir=de_dir, func=self._de_var_fs_save_funtion)
        self._hvd.join(
        )  # Sync for avoiding files conflict and ranks finishing early
    return file_path

  def _write(self, file_prefix, options=None, *args, **kwargs):
    """Internal method that implements Checkpoint.write().

    Args:
      file_prefix: A prefix to use for the checkpoint filenames
        (/path/to/directory/and_a_prefix).
      options: Optional `tf.train.CheckpointOptions` object.
      write_done_callback: Optional callback function to be executed once
        the underlying checkpoint saving is finished. Example usage includes
        updating the checkpoint internal state.

    Returns:
      The full path to the checkpoint (i.e. `file_prefix`).
    """

    def tf_write_func_impl():
      return super(DEHvdCheckpoint, self)._write(file_prefix=file_prefix,
                                                 options=options,
                                                 *args,
                                                 **kwargs)

    return self._de_hvd_write_fs_func(file_prefix=file_prefix,
                                      tf_write_func=tf_write_func_impl)

  def write(self, file_prefix, options=None, *args, **kwargs):
    """
    Args:
      file_prefix: A prefix to use for the checkpoint filenames
        (/path/to/directory/and_a_prefix).
      options: Optional `tf.train.CheckpointOptions` object.

    Returns:
      The full path to the checkpoint (i.e. `file_prefix`).
    """

    def tf_write_func_impl():
      if hasattr(super(DEHvdCheckpoint, self), '_write'):
        return super(DEHvdCheckpoint, self)._write(file_prefix=file_prefix,
                                                   options=options,
                                                   *args,
                                                   **kwargs)
      else:
        return super(DEHvdCheckpoint, self).write(file_prefix=file_prefix,
                                                  options=options,
                                                  *args,
                                                  **kwargs)

    return self._de_hvd_write_fs_func(file_prefix=file_prefix,
                                      tf_write_func=tf_write_func_impl)

  def restore(self, save_path, options=None, *args, **kwargs):
    """
    Args:
      save_path: The path to the checkpoint, as returned by `save` or
        `tf.train.latest_checkpoint`. If None (as when there is no latest
        checkpoint for `tf.train.latest_checkpoint` to return), returns an
        object which may run initializers for objects in the dependency graph.
        If the checkpoint was written by the name-based
        `tf.compat.v1.train.Saver`, names are used to match variables.
      options: Optional `tf.train.CheckpointOptions` object.

    Returns:
      A load status object, which can be used to make assertions about the
      status of checkpoint restoration and run initialization/restore ops
      (of type `CheckpointLoadStatus`, or `InitializationOnlyStatus` if
      `save_path` is `None`).

      If `save_path` points to a name-based checkpoint, a `NameBasedSaverStatus`
      object is returned which runs restore ops from a name-based saver.

    Raises:
      RuntimeError: When a checkpoint file saved by async checkpoint is not
        available upon restore().
    """
    save_path_split = save_path.split('-')
    save_path_pattern = ''.join(save_path_split[0:-1])
    global_step = save_path_split[-1]
    if not global_step.isdigit():
      global_step = None
    de_dir = self._get_de_variable_folder_dir(save_path, global_step)

    impl_save_path = save_path
    if 'TFRADynamicEmbedding' in save_path:
      tf_logging.warning(
          f'''Arg save_path is {save_path}. Please do not name a checkpoint \'TFRADynamicEmbedding\'; it is a reserved term.
          If you are sure this is not the name of a checkpoint,
          this is an unfixed bug related to tf.train.latest_checkpoint.
          Please call the restore function directly with the name of the checkpoint.''')
      if global_step is not None:
        corresponding_ckpt_index = file_io.get_matching_files(
            os.path.join(os.path.dirname(save_path), f'*-{global_step}.index'))
      else:
        corresponding_ckpt_index = file_io.get_matching_files(
            os.path.join(os.path.dirname(save_path), '*.index'))
      if len(corresponding_ckpt_index) > 0:
        de_dir = self._get_de_variable_folder_dir(
            save_path,
            (corresponding_ckpt_index[0].split('-')[-1].split('.index')[0]))
        impl_save_path = corresponding_ckpt_index[0].split('.index')[0]
        if global_step is None:
          tf_logging.warning(
              f'Arg save_path {save_path} is illegal or does not exist. Now using index {impl_save_path}'
          )

    result = super(DEHvdCheckpoint, self).restore(save_path=impl_save_path,
                                                  options=options,
                                                  *args,
                                                  **kwargs)
    if os.path.exists(de_dir):
      self._de_handle_root_and_var_with_func(
          de_dir=de_dir, func=self._de_var_fs_restore_funtion)
    else:
      tf_logging.warning(
          f'TFRADynamicEmbedding directory {de_dir} does not exist.')
    if self._hvd is not None:
      self._hvd.join()  # Sync for avoiding files conflict
    return result
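
For readers new to the FileSystemSaver flow, the per-rank write that _de_var_fs_save_funtion performs reduces to roughly the sketch below. Here `emb` stands for an HvdAllToAllEmbedding layer built with a FileSystemSaver kv_creator, and '/tmp/ckpt/test-1' for a checkpoint prefix at global step 1; both are illustrative assumptions, not names from this commit.

import os
import horovod.tensorflow as hvd

# DEHvdCheckpoint derives the KV directory from the checkpoint prefix, e.g.
# "<ckpt_dir>/TFRADynamicEmbedding-<global_step>" when a global step is present.
de_dir = os.path.join(os.path.dirname('/tmp/ckpt/test-1'), 'TFRADynamicEmbedding-1')

# `emb` is assumed to be an HvdAllToAllEmbedding layer; its `.params` attribute
# is the dynamic-embedding variable. Every rank writes only its own shard of
# the table; rank 0 alone writes the regular TF checkpoint files.
emb.params.save_to_file_system(dirpath=de_dir,
                               proc_size=hvd.size(),
                               proc_rank=hvd.rank())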
