Skip to content

Commit 52ef375

Browse files
MoFHekarhdong
authored andcommitted
[fix] Rename DECheckpoint and de_save_model api to make them more looks like TF style.
1 parent 4866bbd commit 52ef375

File tree

6 files changed

+22
-19
lines changed

6 files changed

+22
-19
lines changed

demo/dynamic_embedding/movielens-1m-keras-with-horovod/movielens-1m-keras-with-horovod.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -466,12 +466,12 @@ def export_to_savedmodel(model, savedmodel_dir):
466466

467467
# TFRA modify the Keras save function with a patch.
468468
# !!!! Run save_model function in all rank !!!!
469-
de.keras.models.de_save_model(model,
470-
savedmodel_dir,
471-
overwrite=True,
472-
include_optimizer=True,
473-
save_traces=True,
474-
options=save_options)
469+
de.keras.models.save_model(model,
470+
savedmodel_dir,
471+
overwrite=True,
472+
include_optimizer=True,
473+
save_traces=True,
474+
options=save_options)
475475

476476

477477
def export_for_serving(model, export_dir):
@@ -521,7 +521,7 @@ def serve(*args, **kwargs):
521521

522522
# TFRA modify the Keras save function with a patch.
523523
# !!!! Run save_model function in all rank !!!!
524-
de.keras.models.de_save_model(
524+
de.keras.models.save_model(
525525
model,
526526
export_dir,
527527
overwrite=True,
@@ -572,7 +572,7 @@ def train():
572572
# horovod callback is used to broadcast the value generated by initializer of rank0.
573573
hvd_opt_init_callback = de.keras.callbacks.DEHvdBroadcastGlobalVariablesCallback(
574574
root_rank=0)
575-
ckpt_callback = de.keras.callbacks.DEHvdModelCheckpoint(
575+
ckpt_callback = de.keras.callbacks.ModelCheckpoint(
576576
filepath=FLAGS.model_dir + '/weights_epoch{epoch:03d}_loss{loss:.4f}',
577577
options=save_options)
578578
callbacks_list = [hvd_opt_init_callback, ckpt_callback]

docs/api_docs/tfra/dynamic_embedding/keras/layers/HvdAllToAllEmbedding.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -81,11 +81,11 @@ In addition, we also provide parameter initialization and save callback related
8181

8282
[`dynamic_embedding.keras.callbacks.DEHvdBroadcastGlobalVariablesCallback`](https://github.com/tensorflow/recommenders-addons/blob/master/tensorflow_recommenders_addons/dynamic_embedding/python/keras/callbacks.py)
8383

84-
[`dynamic_embedding.keras.callbacks.DEHvdModelCheckpoint`](https://github.com/tensorflow/recommenders-addons/blob/master/tensorflow_recommenders_addons/dynamic_embedding/python/keras/callbacks.py)
84+
[`dynamic_embedding.keras.callbacks.ModelCheckpoint.`](https://github.com/tensorflow/recommenders-addons/blob/master/tensorflow_recommenders_addons/dynamic_embedding/python/keras/callbacks.py)
8585

86-
[`dynamic_embedding.keras.models.de_save_model`](https://github.com/tensorflow/recommenders-addons/blob/master/tensorflow_recommenders_addons/dynamic_embedding/python/keras/models.py)
86+
[`dynamic_embedding.keras.models.save_model`](https://github.com/tensorflow/recommenders-addons/blob/master/tensorflow_recommenders_addons/dynamic_embedding/python/keras/models.py)
8787

88-
[`dynamic_embedding.train.DEHvdModelCheckpoint`](https://github.com/tensorflow/recommenders-addons/blob/master/tensorflow_recommenders_addons/dynamic_embedding/python/train/checkpoint.py)
88+
[`dynamic_embedding.train.ModelCheckpoint.`](https://github.com/tensorflow/recommenders-addons/blob/master/tensorflow_recommenders_addons/dynamic_embedding/python/train/checkpoint.py)
8989

9090
You could inherit the `HvdAllToAllEmbedding` class to implement a custom embedding
9191
layer with other fixed shape output.
Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
11
from tensorflow_recommenders_addons.dynamic_embedding.python.keras import layers
22
from tensorflow_recommenders_addons.dynamic_embedding.python.keras import callbacks
3-
from tensorflow_recommenders_addons.dynamic_embedding.python.keras import models
3+
from tensorflow_recommenders_addons.dynamic_embedding.python.keras import models
4+
5+
setattr(models, 'save_model', models.de_save_model)
6+
setattr(callbacks, 'ModelCheckpoint', callbacks.DEHvdModelCheckpoint)

tensorflow_recommenders_addons/dynamic_embedding/python/kernel_tests/horovod_sync_train_test.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -330,9 +330,7 @@ def common_all_to_all_embedding_trainable_v2(self, base_opt, test_opt, name):
330330
shutil.rmtree(save_dir)
331331
hvd.join() # Sync for avoiding files conflict
332332
# base_model.save(save_dir, options=save_options)
333-
de.keras.models.de_save_model(base_model,
334-
save_dir,
335-
options=save_options)
333+
de.keras.models.save_model(base_model, save_dir, options=save_options)
336334
ckpt = de.train.DECheckpoint(
337335
my_model=base_model) # Test custom model key "my_model"
338336
ckpt.save(save_dir + '/ckpt/test')
@@ -542,7 +540,7 @@ def check_TFRADynamicEmbedding_directory(save_dir,
542540
np.sort(new_de_opt_compared[opt_v_name][2], axis=0))
543541

544542
extra_save_dir = self.get_temp_dir() + '/extra_save_dir'
545-
de.keras.models.de_save_model(new_model, extra_save_dir)
543+
de.keras.models.save_model(new_model, extra_save_dir)
546544
if hvd.rank() == 0:
547545
check_TFRADynamicEmbedding_directory(extra_save_dir)
548546
del new_opt
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,4 @@
11
from tensorflow_recommenders_addons.dynamic_embedding.python.train.saver import DEHvdSaver
22
from tensorflow_recommenders_addons.dynamic_embedding.python.train.checkpoint import DECheckpoint
3+
4+
Checkpoint = DECheckpoint

tensorflow_recommenders_addons/dynamic_embedding/python/train/checkpoint.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,16 +26,16 @@
2626
from tensorflow.python.framework import constant_op
2727
from tensorflow.python.framework import ops
2828
try: # tf version >= 2.10.0
29-
from tensorflow.python.checkpoint.checkpoint import Checkpoint
29+
from tensorflow.python.checkpoint.checkpoint import Checkpoint as TFCheckpoint
3030
from tensorflow.python.checkpoint import restore as ckpt_base
3131
except:
32-
from tensorflow.python.training.tracking.util import Checkpoint
32+
from tensorflow.python.training.tracking.util import Checkpoint as TFCheckpoint
3333
from tensorflow.python.training.tracking import base as ckpt_base
3434
from tensorflow.python.lib.io import file_io
3535
from tensorflow.python.platform import tf_logging
3636

3737

38-
class DECheckpoint(Checkpoint):
38+
class DECheckpoint(TFCheckpoint):
3939
"""Overwrite tf.train.Saver class
4040
Calling the TF save API for all ranks causes file conflicts,
4141
so KV files other than rank0 need to be saved by calling the underlying API separately.

0 commit comments

Comments
 (0)