Skip to content

Commit b33afb8

Browse files
MoFHekarhdong
authored and committed
[fix] Rename function de_hvd_save_model to de_save_model.
Also make de_save_model easier to use by requiring less boilerplate code. Fix a bug where de_hvd_save_model and CheckpointManager could not be used together. It was caused by the CheckpointManager compatibility code in the DE hack changing the storage path in the DE saveable object at runtime. The modified storage path was not written back after the Checkpoint save, so a subsequent call to the saved_model saving function would dump to an unexpected directory that had been set in the DECheckpoint class.
1 parent 9515585 commit b33afb8

File tree

4 files changed

+189
-58
lines changed

4 files changed

+189
-58
lines changed

demo/dynamic_embedding/movielens-1m-keras-with-horovod/movielens-1m-keras-with-horovod.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -466,12 +466,12 @@ def export_to_savedmodel(model, savedmodel_dir):
466466

467467
# TFRA modify the Keras save function with a patch.
468468
# !!!! Run save_model function in all rank !!!!
469-
de.keras.models.de_hvd_save_model(model,
470-
savedmodel_dir,
471-
overwrite=True,
472-
include_optimizer=True,
473-
save_traces=True,
474-
options=save_options)
469+
de.keras.models.de_save_model(model,
470+
savedmodel_dir,
471+
overwrite=True,
472+
include_optimizer=True,
473+
save_traces=True,
474+
options=save_options)
475475

476476

477477
def export_for_serving(model, export_dir):
@@ -521,7 +521,7 @@ def serve(*args, **kwargs):
521521

522522
# TFRA modify the Keras save function with a patch.
523523
# !!!! Run save_model function in all rank !!!!
524-
de.keras.models.de_hvd_save_model(
524+
de.keras.models.de_save_model(
525525
model,
526526
export_dir,
527527
overwrite=True,

docs/api_docs/tfra/dynamic_embedding/keras/layers/HvdAllToAllEmbedding.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ In addition, we also provide parameter initialization and save callback related
8383

8484
[`dynamic_embedding.keras.callbacks.DEHvdModelCheckpoint`](https://github.com/tensorflow/recommenders-addons/blob/master/tensorflow_recommenders_addons/dynamic_embedding/python/keras/callbacks.py)
8585

86-
[`dynamic_embedding.keras.models.de_hvd_save_model`](https://github.com/tensorflow/recommenders-addons/blob/master/tensorflow_recommenders_addons/dynamic_embedding/python/keras/models.py)
86+
[`dynamic_embedding.keras.models.de_save_model`](https://github.com/tensorflow/recommenders-addons/blob/master/tensorflow_recommenders_addons/dynamic_embedding/python/keras/models.py)
8787

8888
[`dynamic_embedding.train.DEHvdModelCheckpoint`](https://github.com/tensorflow/recommenders-addons/blob/master/tensorflow_recommenders_addons/dynamic_embedding/python/train/checkpoint.py)
8989

tensorflow_recommenders_addons/dynamic_embedding/python/keras/models.py

Lines changed: 72 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from tensorflow.python.keras.saving.saved_model import save as tf_saved_model_save
2727
from tensorflow.python.ops import array_ops
2828
from tensorflow.python.platform import tf_logging
29+
from tensorflow.python.saved_model.save_options import SaveOptions
2930

3031
tf_original_save_func = tf_saved_model_save.save
3132
if keras_saved_model_save is not None:
@@ -56,6 +57,11 @@ def _de_keras_save_func(original_save_func,
5657
except:
5758
hvd = None
5859

60+
if hvd is not None:
61+
filepath = hvd.broadcast_object(filepath,
62+
root_rank=0,
63+
name='de_hvd_broadcast_filepath')
64+
5965
call_original_save_func = functools.partial(
6066
original_save_func,
6167
model=model,
@@ -68,8 +74,9 @@ def _de_keras_save_func(original_save_func,
6874
*args,
6975
**kwargs)
7076

71-
def _traverse_emb_layers_and_save(hvd_rank):
72-
de_dir = os.path.join(filepath, "variables", "TFRADynamicEmbedding")
77+
de_dir = os.path.join(filepath, "variables", "TFRADynamicEmbedding")
78+
79+
def _check_saveable_and_redirect_new_de_dir():
7380
for var in model.variables:
7481
if not hasattr(var, "params"):
7582
continue
@@ -85,33 +92,50 @@ def _traverse_emb_layers_and_save(hvd_rank):
8592
"It will allow TFRA load KV files when Embedding tensor parallel. "
8693
f"The embedding shards at each horovod rank are now temporarily stored in {de_dir}"
8794
)
88-
else:
89-
if not isinstance(de_var.kv_creator.saver, de.FileSystemSaver):
90-
# This function only serves FileSystemSaver.
91-
continue
92-
if hvd_rank == 0:
93-
# FileSystemSaver works well at rank 0.
94-
continue
95-
# save Dynamic Embedding Parameters
96-
de_var.save_to_file_system(dirpath=de_dir,
97-
proc_size=hvd.size(),
98-
proc_rank=hvd.rank())
99-
# save optimizer parameters of Dynamic Embedding
100-
if include_optimizer is True:
101-
de_opt_vars = a2a_emb.optimizer_vars.as_list() if hasattr(
102-
a2a_emb.optimizer_vars, "as_list") else a2a_emb.optimizer_vars
103-
for de_opt_var in de_opt_vars:
104-
de_opt_var.save_to_file_system(dirpath=de_dir,
105-
proc_size=hvd.size(),
106-
proc_rank=hvd.rank())
95+
if not isinstance(de_var.kv_creator.saver, de.FileSystemSaver):
96+
# This function only serves FileSystemSaver.
97+
continue
98+
# Redirect new de_dir
99+
if hasattr(de_var, 'saveable'):
100+
de_var.saveable._saver_config.save_path = de_dir
107101

102+
def _traverse_emb_layers_and_save(hvd_rank=0):
103+
for var in model.variables:
104+
if not hasattr(var, "params"):
105+
continue
106+
if not hasattr(var.params, "_created_in_class"):
107+
continue
108+
de_var = var.params
109+
a2a_emb = de_var._created_in_class
110+
if de_var._saveable_object_creator is not None:
111+
if not isinstance(de_var.kv_creator.saver, de.FileSystemSaver):
112+
# This function only serves FileSystemSaver.
113+
continue
114+
# save optimizer parameters of Dynamic Embedding
115+
if include_optimizer is True:
116+
de_opt_vars = a2a_emb.optimizer_vars.as_list() if hasattr(
117+
a2a_emb.optimizer_vars, "as_list") else a2a_emb.optimizer_vars
118+
for de_opt_var in de_opt_vars:
119+
de_opt_var.save_to_file_system(dirpath=de_dir,
120+
proc_size=hvd.size(),
121+
proc_rank=hvd.rank())
122+
if hvd_rank == 0:
123+
# FileSystemSaver works well at rank 0.
124+
continue
125+
# save Dynamic Embedding Parameters
126+
de_var.save_to_file_system(dirpath=de_dir,
127+
proc_size=hvd.size(),
128+
proc_rank=hvd.rank())
129+
130+
_check_saveable_and_redirect_new_de_dir()
108131
if hvd is None:
109132
call_original_save_func()
133+
_traverse_emb_layers_and_save(0)
110134
else:
111135
if hvd.rank() == 0:
112136
call_original_save_func()
113137
_traverse_emb_layers_and_save(hvd.rank())
114-
hvd.join() # Sync for avoiding data conflict
138+
hvd.join() # Sync for avoiding rank conflict
115139

116140

117141
def de_hvd_save_model(model,
@@ -123,11 +147,37 @@ def de_hvd_save_model(model,
123147
save_traces=True,
124148
*args,
125149
**kwargs):
150+
return de_save_model(model=model,
151+
filepath=filepath,
152+
overwrite=True,
153+
include_optimizer=True,
154+
signatures=None,
155+
options=None,
156+
save_traces=True,
157+
*args,
158+
**kwargs)
159+
160+
161+
def de_save_model(model,
162+
filepath,
163+
overwrite=True,
164+
include_optimizer=True,
165+
signatures=None,
166+
options=None,
167+
save_traces=True,
168+
*args,
169+
**kwargs):
126170
if keras_saved_model_save is not None:
127171
_save_handle = functools.partial(_de_keras_save_func,
128172
keras_original_save_func)
129173
else:
130174
_save_handle = functools.partial(_de_keras_save_func, tf_original_save_func)
175+
if options is None:
176+
options = SaveOptions(namespace_whitelist=['TFRA'])
177+
elif isinstance(options, SaveOptions) and hasattr(options,
178+
'namespace_whitelist'):
179+
options.namespace_whitelist.append('TFRA')
180+
131181
return _save_handle(model,
132182
filepath,
133183
overwrite,

tensorflow_recommenders_addons/dynamic_embedding/python/kernel_tests/horovod_sync_train_test.py

Lines changed: 109 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -327,9 +327,9 @@ def common_all_to_all_embedding_trainable_v2(self, base_opt, test_opt, name):
327327
shutil.rmtree(save_dir)
328328
hvd.join() # Sync for avoiding files conflict
329329
# base_model.save(save_dir, options=save_options)
330-
de.keras.models.de_hvd_save_model(base_model,
331-
save_dir,
332-
options=save_options)
330+
de.keras.models.de_save_model(base_model,
331+
save_dir,
332+
options=save_options)
333333
ckpt = de.train.DECheckpoint(
334334
my_model=base_model) # Test custom model key "my_model"
335335
ckpt.save(save_dir + '/ckpt/test')
@@ -407,31 +407,38 @@ def call(self, x):
407407
return self.l2(out)
408408

409409
def check_TFRADynamicEmbedding_directory(save_dir,
410-
save_it,
410+
save_it=None,
411411
should_be_exist=True):
412412
hvd_size = hvd.size()
413413
if hvd_size <= 1:
414414
hvd_size = 1
415+
base_dir = os.path.join(save_dir, 'variables', 'TFRADynamicEmbedding')
416+
if save_it is not None:
417+
base_dir = os.path.join(save_dir, f'TFRADynamicEmbedding-{save_it}')
415418
for tag in ['keys', 'values']:
416419
for rank in range(hvd_size):
417420
self.assertTrue(not (os.path.exists(
418-
save_dir +
419-
f'/TFRADynamicEmbedding-{save_it}/{name}-parameter_mht_1of1_rank{rank}_size{hvd_size}-{tag}'
420-
) ^ should_be_exist))
421+
base_dir +
422+
f'/{name}-parameter_mht_1of1_rank{rank}_size{hvd_size}-{tag}') ^
423+
should_be_exist))
421424
self.assertTrue(not (os.path.exists(
422-
save_dir +
423-
f'/TFRADynamicEmbedding-{save_it}/{name}-parameter_DynamicEmbedding_keras_adam_lazy_build-shadow_m_mht_1of1_rank{rank}_size{hvd_size}-{tag}'
425+
base_dir +
426+
f'/{name}-parameter_DynamicEmbedding_keras_adam_lazy_build-shadow_m_mht_1of1_rank{rank}_size{hvd_size}-{tag}'
424427
) ^ should_be_exist))
425-
# f'/TFRADynamicEmbedding-{save_it}/{name}-parameter_no_compile_model_DynamicEmbedding_keras_adam_lazy_build-shadow_m_mht_1of1_rank{rank}_size{hvd_size}-{tag}'
428+
# f'/{name}-parameter_no_compile_model_DynamicEmbedding_keras_adam_lazy_build-shadow_m_mht_1of1_rank{rank}_size{hvd_size}-{tag}'
426429
self.assertTrue(not (os.path.exists(
427-
save_dir +
428-
f'/TFRADynamicEmbedding-{save_it}/{name}-parameter_DynamicEmbedding_keras_adam_lazy_build-shadow_v_mht_1of1_rank{rank}_size{hvd_size}-{tag}'
430+
base_dir +
431+
f'/{name}-parameter_DynamicEmbedding_keras_adam_lazy_build-shadow_v_mht_1of1_rank{rank}_size{hvd_size}-{tag}'
429432
) ^ should_be_exist))
430-
# f'/TFRADynamicEmbedding-{save_it}/{name}-parameter_no_compile_model_DynamicEmbedding_keras_adam_lazy_build-shadow_v_mht_1of1_rank{rank}_size{hvd_size}-{tag}'
433+
# f'/{name}-parameter_no_compile_model_DynamicEmbedding_keras_adam_lazy_build-shadow_v_mht_1of1_rank{rank}_size{hvd_size}-{tag}'
431434

432435
with tf.device("/{}:{}".format(_device, _device_id)):
433436
x = tf.reshape(tf.range(0, 32, dtype=tf.int64), [32, 1])
434437
y = tf.random.uniform(shape=[32, 1])
438+
base_de_emb_standard = {}
439+
base_de_opt_standard = {}
440+
new_de_emb_compared = {}
441+
new_de_opt_compared = {}
435442

436443
save_dir = self.get_temp_dir()
437444

@@ -454,13 +461,16 @@ def check_TFRADynamicEmbedding_directory(save_dir,
454461
l.params.upsert(x * 10, tf.random.uniform(shape=[32, 1, dim]))
455462
emb_size = l.params.size()
456463
emb_keys, emb_values = l.params.export()
464+
base_de_emb_standard[l.name] = (emb_size, emb_keys, emb_values)
457465
break
458466
for v in base_opt.variables():
459467
if name in v.name:
460468
v.params.upsert(x * 10, tf.random.uniform(shape=[32, 1, dim]))
461469
opt_size = v.params.size()
462-
opt_keys, opt_values = l.params.export()
463-
break
470+
opt_keys, opt_values = v.params.export()
471+
base_de_opt_standard[v._shared_name.split('/')[-1]] = (opt_size,
472+
opt_keys,
473+
opt_values)
464474
manager.save()
465475
if hvd.rank() == 0:
466476
check_TFRADynamicEmbedding_directory(save_dir,
@@ -491,31 +501,102 @@ def check_TFRADynamicEmbedding_directory(save_dir,
491501
new_model.compile(optimizer=new_opt, loss='mean_absolute_error')
492502
new_model(x) # Build vairiables
493503
try:
494-
new_opt._create_all_weights(new_model.variables)
504+
new_opt._create_all_weights([
505+
new_model.variables[0]
506+
]) # Create DE slot variable from DE shadow variable
495507
except:
496508
#TODO(MoFHejia) raise ValueError: Cannot convert a partially known TensorShape <unknown> to a Tensor.
497509
pass
498510
for l in new_model.layers:
499511
if name in l.name:
500512
new_emb_size = l.params.size()
501513
new_emb_keys, new_emb_values = l.params.export()
514+
new_de_emb_compared[l.name] = (new_emb_size, new_emb_keys,
515+
new_emb_values)
502516
break
503517
for v in new_opt.variables():
504518
if name in v.name:
505519
new_opt_size = v.params.size()
506-
new_opt_keys, new_opt_values = l.params.export()
520+
new_opt_keys, new_opt_values = v.params.export()
521+
new_de_opt_compared[v._shared_name.split('/')[-1]] = (new_opt_size,
522+
new_opt_keys,
523+
new_opt_values)
524+
525+
for de_l_name in base_de_emb_standard.keys():
526+
self.assertEqual(base_de_emb_standard[de_l_name][0],
527+
new_de_emb_compared[de_l_name][0])
528+
self.assertAllEqual(np.sort(base_de_emb_standard[de_l_name][1], axis=0),
529+
np.sort(new_de_emb_compared[de_l_name][1], axis=0))
530+
self.assertAllClose(np.sort(base_de_emb_standard[de_l_name][2], axis=0),
531+
np.sort(new_de_emb_compared[de_l_name][2], axis=0))
532+
for opt_v_name in base_de_opt_standard.keys():
533+
self.assertEqual(base_de_opt_standard[opt_v_name][0],
534+
new_de_opt_compared[opt_v_name][0])
535+
self.assertAllEqual(
536+
np.sort(base_de_opt_standard[opt_v_name][1], axis=0),
537+
np.sort(new_de_opt_compared[opt_v_name][1], axis=0))
538+
self.assertAllClose(
539+
np.sort(base_de_opt_standard[opt_v_name][2], axis=0),
540+
np.sort(new_de_opt_compared[opt_v_name][2], axis=0))
541+
542+
extra_save_dir = self.get_temp_dir() + '/extra_save_dir'
543+
de.keras.models.de_save_model(new_model, extra_save_dir)
544+
if hvd.rank() == 0:
545+
check_TFRADynamicEmbedding_directory(extra_save_dir)
546+
del new_opt
547+
del new_model
548+
del new_ckpt
549+
tf.keras.backend.clear_session()
550+
tf.compat.v1.reset_default_graph()
551+
new_saved_model = NoCompileModel('zeros')
552+
new_saved_opt = Adam(1.2)
553+
new_saved_opt = de.DynamicEmbeddingOptimizer(new_saved_opt,
554+
synchronous=True)
555+
new_saved_model.compile(optimizer=new_saved_opt,
556+
loss='mean_absolute_error')
557+
new_saved_model(x) # Build vairiables
558+
try:
559+
new_opt._create_all_weights([
560+
new_model.variables[0]
561+
]) # Create DE slot variable from DE shadow variable
562+
except:
563+
#TODO(MoFHejia) raise ValueError: Cannot convert a partially known TensorShape <unknown> to a Tensor.
564+
pass
565+
extra_save_dir = hvd.broadcast_object(
566+
extra_save_dir, root_rank=0, name='de_utest_hvd_broadcast_filepath'
567+
) # All ranks should share same save directory
568+
new_saved_model.load_weights(extra_save_dir + '/variables/variables')
569+
for l in new_saved_model.layers:
570+
if name in l.name:
571+
new_emb_size = l.params.size()
572+
new_emb_keys, new_emb_values = l.params.export()
573+
new_de_emb_compared[l.name] = (new_emb_size, new_emb_keys,
574+
new_emb_values)
507575
break
508-
509-
self.assertEqual(emb_size, new_emb_size)
510-
self.assertEqual(opt_size, new_opt_size)
511-
self.assertAllEqual(np.sort(emb_keys, axis=0),
512-
np.sort(new_emb_keys, axis=0))
513-
self.assertAllClose(np.sort(emb_values, axis=0),
514-
np.sort(new_emb_values, axis=0))
515-
self.assertAllEqual(np.sort(opt_keys, axis=0),
516-
np.sort(new_opt_keys, axis=0))
517-
self.assertAllClose(np.sort(opt_values, axis=0),
518-
np.sort(new_opt_values, axis=0))
576+
for v in new_saved_opt.variables():
577+
if name in v.name:
578+
new_opt_size = v.params.size()
579+
new_opt_keys, new_opt_values = l.params.export()
580+
new_de_opt_compared[v._shared_name.split('/')[-1]] = (new_opt_size,
581+
new_opt_keys,
582+
new_opt_values)
583+
584+
for de_l_name in base_de_emb_standard.keys():
585+
self.assertEqual(base_de_emb_standard[de_l_name][0],
586+
new_de_emb_compared[de_l_name][0])
587+
self.assertAllEqual(np.sort(base_de_emb_standard[de_l_name][1], axis=0),
588+
np.sort(new_de_emb_compared[de_l_name][1], axis=0))
589+
self.assertAllClose(np.sort(base_de_emb_standard[de_l_name][2], axis=0),
590+
np.sort(new_de_emb_compared[de_l_name][2], axis=0))
591+
for opt_v_name in base_de_opt_standard.keys():
592+
self.assertEqual(base_de_opt_standard[opt_v_name][0],
593+
new_de_opt_compared[opt_v_name][0])
594+
self.assertAllEqual(
595+
np.sort(base_de_opt_standard[opt_v_name][1], axis=0),
596+
np.sort(new_de_opt_compared[opt_v_name][1], axis=0))
597+
self.assertAllClose(
598+
np.sort(base_de_opt_standard[opt_v_name][2], axis=0),
599+
np.sort(new_de_opt_compared[opt_v_name][2], axis=0))
519600

520601

521602
if __name__ == "__main__":

0 commit comments

Comments
 (0)