
Commit 8f0951c

MoFHeka authored and rhdong committed
[fix] The default TensorFlow save function is no longer monkey-patched, since patching it can lead to unexpected errors.
de.keras.models.de_hvd_save_model now replaces tf.keras.models.save_model.
1 parent e2c51d7 commit 8f0951c
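
At the call site this is a drop-in swap. A minimal sketch of the new usage, assuming a Horovod-enabled TFRA model (model, savedmodel_dir, and save_options are stand-ins taken from the demo below):

from tensorflow_recommenders_addons import dynamic_embedding as de

# Previously the demo relied on TFRA monkey-patching tf.keras.models.save_model:
#   tf.keras.models.save_model(model, savedmodel_dir, options=save_options)
# Now the Horovod-aware TFRA saver is called explicitly, on every rank:
de.keras.models.de_hvd_save_model(model,
                                  savedmodel_dir,
                                  overwrite=True,
                                  include_optimizer=True,
                                  save_traces=True,
                                  options=save_options)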

File tree

7 files changed: +184 -133 lines changed


demo/dynamic_embedding/movielens-1m-keras-with-horovod/movielens-1m-keras-with-horovod.py

Lines changed: 12 additions & 9 deletions

@@ -451,14 +451,14 @@ def export_to_savedmodel(model, savedmodel_dir):
   # proc_size=hvd.size(),
   # proc_rank=hvd.rank())

-  # TFRA modify the Keras save function with a monkey patch.
+  # TFRA modify the Keras save function with a patch.
   # !!!! Run save_model function in all rank !!!!
-  tf.keras.models.save_model(model,
-                             savedmodel_dir,
-                             overwrite=True,
-                             include_optimizer=True,
-                             save_traces=True,
-                             options=save_options)
+  de.keras.models.de_hvd_save_model(model,
+                                    savedmodel_dir,
+                                    overwrite=True,
+                                    include_optimizer=True,
+                                    save_traces=True,
+                                    options=save_options)


 def export_for_serving(model, export_dir):

@@ -506,9 +506,9 @@ def serve(*args, **kwargs):
   # proc_size=hvd.size(),
   # proc_rank=hvd.rank())

-  # TFRA modify the Keras save function with a monkey patch.
+  # TFRA modify the Keras save function with a patch.
   # !!!! Run save_model function in all rank !!!!
-  tf.keras.models.save_model(
+  de.keras.models.de_hvd_save_model(
       model,
       export_dir,
       overwrite=True,

@@ -559,6 +559,9 @@ def train():
   # horovod callback is used to broadcast the value generated by initializer of rank0.
   hvd_opt_init_callback = de.keras.callbacks.DEHvdBroadcastGlobalVariablesCallback(
       root_rank=0)
+  ckpt_callback = de.keras.callbacks.DEHvdModelCheckpoint(
+      filepath=FLAGS.model_dir + '/weights_epoch{epoch:03d}_loss{loss:.4f}',
+      options=save_options)
   callbacks_list = [hvd_opt_init_callback, ckpt_callback]
   # The log class callback only takes effect in rank0 for convenience
   if hvd.rank() == 0:
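
Taken together, the demo's train() now wires both TFRA callbacks before fitting. A condensed sketch of the wiring (FLAGS.model_dir and save_options come from the surrounding demo code; the model.fit call is a stand-in):

hvd_opt_init_callback = de.keras.callbacks.DEHvdBroadcastGlobalVariablesCallback(
    root_rank=0)
ckpt_callback = de.keras.callbacks.DEHvdModelCheckpoint(
    filepath=FLAGS.model_dir + '/weights_epoch{epoch:03d}_loss{loss:.4f}',
    options=save_options)
callbacks_list = [hvd_opt_init_callback, ckpt_callback]
model.fit(dataset, callbacks=callbacks_list)  # every rank runs fit; the DEHvd* callbacks coordinate ranks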
tensorflow_recommenders_addons/dynamic_embedding/python/keras/__init__.py

Lines changed: 2 additions & 1 deletion

@@ -1,2 +1,3 @@
 from tensorflow_recommenders_addons.dynamic_embedding.python.keras import layers
-from tensorflow_recommenders_addons.dynamic_embedding.python.keras import callbacks
+from tensorflow_recommenders_addons.dynamic_embedding.python.keras import callbacks
+from tensorflow_recommenders_addons.dynamic_embedding.python.keras import models
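
With the added re-export, the new module resolves through the public alias. A quick sanity check, assuming TFRA built from this commit:

from tensorflow_recommenders_addons import dynamic_embedding as de

assert callable(de.keras.models.de_hvd_save_model)  # reachable once keras/__init__.py imports models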

tensorflow_recommenders_addons/dynamic_embedding/python/keras/callbacks.py

Lines changed: 17 additions & 17 deletions

@@ -21,6 +21,7 @@
 from tensorflow.python.keras import backend as K
 from tensorflow.python.keras import callbacks
 from tensorflow.python.keras.utils import tf_utils
+from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import variables
 from tensorflow.python.platform import tf_logging as logging
 from tensorflow.python.util.deprecation import deprecated

@@ -108,10 +109,6 @@ def __init__(self, root_rank=0, device='', local_variables=None):
       self.register_local_var(var)


-@deprecated(
-    None, "\n!!!! Using this callback will cause a save twice error. !!!!\n"
-    "The callbacks.ModelCheckpoint for HvdAllToAllEmbedding has been deprecated, use original ModelCheckpoint instead.\n"
-    "!!!! Using this callback will cause a save twice error. !!!!\n")
 class DEHvdModelCheckpoint(callbacks.ModelCheckpoint):

   def __init__(self, *args, **kwargs):

@@ -129,23 +126,26 @@ def _save_de_model(self, filepath):
                               options=self._options)
     else:
       de_dir = os.path.join(filepath, "variables", "TFRADynamicEmbedding")
-      for layer in self.model.layers:
-        if hasattr(layer, "params") and isinstance(layer, HvdAllToAllEmbedding):
+      for var in self.model.variables:
+        if not hasattr(var, "params") or not isinstance(var, TrainableWrapper):
+          continue
+        if not hasattr(var.params, "_created_in_class"):
+          continue
+        de_var = var.params
+        a2a_emb = de_var._created_in_class
+        if issubclass(a2a_emb.__class__, HvdAllToAllEmbedding):
           # save Dynamic Embedding Parameters
-          logging.warning(
-              "!!!! Using this callback will cause a save twice error. !!!!\n"
-              "The callbacks.ModelCheckpoint for HvdAllToAllEmbedding has been deprecated, use original ModelCheckpoint instead.\n"
-              "!!!! Using this callback will cause a save twice error. !!!!\n")
-          layer.params.save_to_file_system(dirpath=de_dir,
-                                           proc_size=hvd.size(),
-                                           proc_rank=hvd.rank())
+          de_var.save_to_file_system(dirpath=de_dir,
+                                     proc_size=hvd.size(),
+                                     proc_rank=hvd.rank())
           # save optimizer parameters of Dynamic Embedding
-          opt_de_vars = layer.optimizer_vars.as_list() if hasattr(
-              layer.optimizer_vars, "as_list") else layer.optimizer_vars
-          for opt_de_var in opt_de_vars:
-            opt_de_var.save_to_file_system(dirpath=de_dir,
+          de_opt_vars = a2a_emb.optimizer_vars.as_list() if hasattr(
+              a2a_emb.optimizer_vars, "as_list") else a2a_emb.optimizer_vars
+          for de_opt_var in de_opt_vars:
+            de_opt_var.save_to_file_system(dirpath=de_dir,
                                            proc_size=hvd.size(),
                                            proc_rank=hvd.rank())
+      hvd.join()  # Sync for avoiding data conflict or missing rank

   def _save_model(self, epoch, logs):
     """Saves the model.

tensorflow_recommenders_addons/dynamic_embedding/python/keras/layers/embedding.py

Lines changed: 1 addition & 0 deletions

@@ -550,6 +550,7 @@ def __init__(self,
     else:
       self._mpi_size = mpi_size
     super(HvdAllToAllEmbedding, self).__init__(*args, **kwargs)
+    self.params._created_in_class = self

   def __relocate_dense_feature__(self, ids, batch_size=None):
     """
tensorflow_recommenders_addons/dynamic_embedding/python/keras/models.py

Lines changed: 139 additions & 0 deletions

@@ -0,0 +1,139 @@
+# Copyright 2023 The TensorFlow Recommenders-Addons Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# lint-as: python3
+
+import functools
+import os.path
+
+from tensorflow_recommenders_addons import dynamic_embedding as de
+
+try:
+  from keras.saving.saved_model import save as keras_saved_model_save
+except:
+  keras_saved_model_save = None
+from tensorflow.python.keras.saving.saved_model import save as tf_saved_model_save
+from tensorflow.python.ops import array_ops
+from tensorflow.python.platform import tf_logging
+
+tf_original_save_func = tf_saved_model_save.save
+if keras_saved_model_save is not None:
+  keras_original_save_func = keras_saved_model_save.save
+
+
+def _de_keras_save_func(original_save_func,
+                        model,
+                        filepath,
+                        overwrite,
+                        include_optimizer,
+                        signatures=None,
+                        options=None,
+                        save_traces=True,
+                        *args,
+                        **kwargs):
+  """Overwrite TF Keras save function
+  Calling the TF save API for all ranks causes file conflicts,
+  so KV files other than rank0 need to be saved by calling the underlying API separately.
+  This is a convenience function for saving HvdAllToAllEmbedding to KV files in different rank.
+  """
+  try:
+    import horovod.tensorflow as hvd
+    try:
+      hvd.rank()
+    except:
+      hvd = None
+  except:
+    hvd = None
+
+  call_original_save_func = functools.partial(
+      original_save_func,
+      model=model,
+      filepath=filepath,
+      overwrite=overwrite,
+      include_optimizer=include_optimizer,
+      signatures=signatures,
+      options=options,
+      save_traces=save_traces,
+      *args,
+      **kwargs)
+
+  def _traverse_emb_layers_and_save(hvd_rank):
+    de_dir = os.path.join(filepath, "variables", "TFRADynamicEmbedding")
+    for var in model.variables:
+      if not hasattr(var, "params"):
+        continue
+      if not hasattr(var.params, "_created_in_class"):
+        continue
+      de_var = var.params
+      a2a_emb = de_var._created_in_class
+      if issubclass(a2a_emb.__class__, de.keras.layers.HvdAllToAllEmbedding):
+        if de_var._saveable_object_creator is None:
+          if hvd_rank == 0:
+            tf_logging.warning(
+                "Please use FileSystemSaver when use HvdAllToAllEmbedding. "
+                "It will allow TFRA load KV files when Embedding tensor parallel. "
+                f"The embedding shards at each horovod rank are now temporarily stored in {de_dir}"
+            )
+        else:
+          if not isinstance(de_var.kv_creator.saver, de.FileSystemSaver):
+            # This function only serves FileSystemSaver.
+            continue
+          if hvd_rank == 0:
+            # FileSystemSaver works well at rank 0.
+            continue
+          # save Dynamic Embedding Parameters
+          de_var.save_to_file_system(dirpath=de_dir,
+                                     proc_size=hvd.size(),
+                                     proc_rank=hvd.rank())
+          # save optimizer parameters of Dynamic Embedding
+          if include_optimizer is True:
+            de_opt_vars = a2a_emb.optimizer_vars.as_list() if hasattr(
+                a2a_emb.optimizer_vars, "as_list") else a2a_emb.optimizer_vars
+            for de_opt_var in de_opt_vars:
+              de_opt_var.save_to_file_system(dirpath=de_dir,
+                                             proc_size=hvd.size(),
+                                             proc_rank=hvd.rank())
+
+  if hvd is None:
+    call_original_save_func()
+  else:
+    if hvd.rank() == 0:
+      call_original_save_func()
+    _traverse_emb_layers_and_save(hvd.rank())
+    hvd.join()  # Sync for avoiding data conflict
+
+
+def de_hvd_save_model(model,
+                      filepath,
+                      overwrite=True,
+                      include_optimizer=True,
+                      signatures=None,
+                      options=None,
+                      save_traces=True,
+                      *args,
+                      **kwargs):
+  if keras_saved_model_save is not None:
+    _save_handle = functools.partial(_de_keras_save_func,
+                                     keras_original_save_func)
+  else:
+    _save_handle = functools.partial(_de_keras_save_func, tf_original_save_func)
+  return _save_handle(model,
+                      filepath,
+                      overwrite,
+                      include_optimizer,
+                      signatures=signatures,
+                      options=options,
+                      save_traces=save_traces,
+                      *args,
+                      **kwargs)
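
de_hvd_save_model mirrors the tf.keras.models.save_model signature, dispatching to the Keras saved_model implementation when it is importable and to the TF one otherwise, so callers use it like the stock API. A hedged usage sketch (model and export_dir are stand-ins; the SaveOptions values follow the TFRA demos):

import tensorflow as tf
from tensorflow_recommenders_addons import dynamic_embedding as de

save_options = tf.saved_model.SaveOptions(namespace_whitelist=['TFRA'])
de.keras.models.de_hvd_save_model(model,
                                  export_dir,
                                  overwrite=True,
                                  include_optimizer=True,
                                  save_traces=True,
                                  options=save_options)

With Horovod initialized, rank 0 writes the SavedModel (its own KV shard is covered by FileSystemSaver), the remaining ranks write their shards under variables/TFRADynamicEmbedding, and hvd.join() barriers all ranks at the end; without Horovod the call degrades to a plain save.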

tensorflow_recommenders_addons/dynamic_embedding/python/kernel_tests/horovod_sync_train_test.py

Lines changed: 7 additions & 7 deletions

@@ -321,9 +321,11 @@ def common_all_to_all_embedding_trainable_v2(self, base_opt, test_opt, name):
     if hvd.rank() == 0:
       if os.path.exists(save_dir):
         shutil.rmtree(save_dir)
-    hvd.broadcast(tensor=tf.constant(1),
-                  root_rank=0)  # Sync for avoiding files conflict
-    base_model.save(save_dir, options=save_options)
+    hvd.join()  # Sync for avoiding files conflict
+    # base_model.save(save_dir, options=save_options)
+    de.keras.models.de_hvd_save_model(base_model,
+                                      save_dir,
+                                      options=save_options)
     del base_model
     new_base_model = get_emb_sequential_model(
         de.keras.layers.HvdAllToAllEmbedding,

@@ -333,13 +335,11 @@ def common_all_to_all_embedding_trainable_v2(self, base_opt, test_opt, name):
         bp_v2=False,
         kv_creator=kv_creator,
         name='all2all_emb')
-    hvd.broadcast(tensor=tf.constant(1),
-                  root_rank=0)  # Sync for avoiding files conflict
+    hvd.join()  # Sync for avoiding files conflict
     new_base_model.load_weights(save_dir + '/variables/variables')
     new_a2aemb_size = new_base_model.layers[0].params.size()
     self.assertEqual(a2aemb_size, new_a2aemb_size)
-    hvd.broadcast(tensor=tf.constant(1),
-                  root_rank=0)  # Sync for avoiding files conflict
+    hvd.join()  # Sync for avoiding files conflict


 if __name__ == "__main__":
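
The tests also swap their synchronization primitive: instead of broadcasting a dummy tensor from rank 0, they call hvd.join(), which blocks each rank until all ranks reach it. A minimal sketch of the barrier pattern (prepare_dir is a hypothetical stand-in for the rank-0-only filesystem work):

import horovod.tensorflow as hvd

hvd.init()
if hvd.rank() == 0:
  prepare_dir()  # hypothetical: e.g. shutil.rmtree(save_dir) as in the test above
hvd.join()  # barrier: no rank proceeds until every rank arrives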
