diff --git a/.github/workflows/make_wheel_macOS_arm64.sh b/.github/workflows/make_wheel_macOS_arm64.sh index 3607d0e5..af23c3f5 100644 --- a/.github/workflows/make_wheel_macOS_arm64.sh +++ b/.github/workflows/make_wheel_macOS_arm64.sh @@ -7,6 +7,7 @@ export TF_NEED_CUDA=0 export IGNORE_HKV="--ignore=./tensorflow_recommenders_addons/dynamic_embedding/python/kernel_tests/hkv_hashtable_ops_test.py" export IGNORE_REDIS="--ignore=./tensorflow_recommenders_addons/dynamic_embedding/python/kernel_tests/redis_table_ops_test.py" export IGNORE_REDIS_VAR="--ignore=./tensorflow_recommenders_addons/dynamic_embedding/python/kernel_tests/redis_table_variable_test.py" +export IGNORE_HOROVOD_DIST_TRAINING_TEST="--ignore=./tensorflow_recommenders_addons/dynamic_embedding/python/kernel_tests/horovod_embedding_restrict_save_test.py" export USE_BAZEL_VERSION='5.1.1' # For TensorFlow version 2.12 or earlier: @@ -59,7 +60,7 @@ delocate-wheel -w wheelhouse -v --ignore-missing-dependencies artifacts/*.whl # Test pip install --default-timeout=1000 -r tools/install_deps/pytest.txt cp ./bazel-bin/tensorflow_recommenders_addons/dynamic_embedding/core/_*_ops.so ./tensorflow_recommenders_addons/dynamic_embedding/core/ -python -m pytest -v -s --functions-durations=20 --modules-durations=5 $IGNORE_HKV $IGNORE_REDIS $IGNORE_REDIS_VAR $SKIP_CUSTOM_OP_TESTS ./tensorflow_recommenders_addons/dynamic_embedding/python/kernel_tests/ +python -m pytest -v -s --functions-durations=20 --modules-durations=5 $IGNORE_HKV $IGNORE_HOROVOD_DIST_TRAINING_TEST $IGNORE_REDIS $IGNORE_REDIS_VAR $SKIP_CUSTOM_OP_TESTS ./tensorflow_recommenders_addons/dynamic_embedding/python/kernel_tests/ # Clean bazel clean \ No newline at end of file diff --git a/tensorflow_recommenders_addons/dynamic_embedding/python/keras/callbacks.py b/tensorflow_recommenders_addons/dynamic_embedding/python/keras/callbacks.py index a6a6d9fc..2339873d 100644 --- a/tensorflow_recommenders_addons/dynamic_embedding/python/keras/callbacks.py +++ b/tensorflow_recommenders_addons/dynamic_embedding/python/keras/callbacks.py @@ -113,6 +113,17 @@ def __init__(self, *args, **kwargs): super(DEHvdModelCheckpoint, self).__init__(*args, **kwargs) def _save_de_model(self, filepath): + + def _maybe_save_restrict_policy_params(de_var, proc_size=1, proc_rank=0): + if not hasattr(de_var, "restrict_policy"): + return + if de_var.restrict_policy is not None: + # Only save restrict policy var if policy created + de_var = de_var.restrict_policy._restrict_var + de_var.save_to_file_system(dirpath=de_dir, + proc_size=proc_size, + proc_rank=proc_rank) + if hvd.rank() == 0: if self.save_weights_only: self.model.save_weights(filepath, overwrite=True, options=self._options) @@ -143,6 +154,11 @@ def _save_de_model(self, filepath): de_opt_var.save_to_file_system(dirpath=de_dir, proc_size=hvd.size(), proc_rank=hvd.rank()) + + # Save restrict policy for each hvd.rank() + _maybe_save_restrict_policy_params(de_var, + proc_size=hvd.size(), + proc_rank=hvd.rank()) hvd.join() # Sync for avoiding data conflict or missing rank def _save_model(self, epoch, logs): diff --git a/tensorflow_recommenders_addons/dynamic_embedding/python/keras/models.py b/tensorflow_recommenders_addons/dynamic_embedding/python/keras/models.py index bbec4da2..26d8af29 100644 --- a/tensorflow_recommenders_addons/dynamic_embedding/python/keras/models.py +++ b/tensorflow_recommenders_addons/dynamic_embedding/python/keras/models.py @@ -99,6 +99,16 @@ def _check_saveable_and_redirect_new_de_dir(hvd_rank=0): if hasattr(de_var, 'saveable'): de_var.saveable._saver_config.save_path = de_dir + def _maybe_save_restrict_policy_params(de_var, proc_size=1, proc_rank=0): + if not hasattr(de_var, "restrict_policy"): + return + if de_var.restrict_policy is not None: + # Only save restrict policy var if policy created + de_var = de_var.restrict_policy._restrict_var + de_var.save_to_file_system(dirpath=de_dir, + proc_size=proc_size, + proc_rank=proc_rank) + def _traverse_emb_layers_and_save(proc_size=1, proc_rank=0): for var in model.variables: if not hasattr(var, "params"): @@ -126,6 +136,10 @@ def _traverse_emb_layers_and_save(proc_size=1, proc_rank=0): de_var.save_to_file_system(dirpath=de_dir, proc_size=proc_size, proc_rank=proc_rank) + # Save restrict policy for each hvd.rank() + _maybe_save_restrict_policy_params(de_var, + proc_size=proc_size, + proc_rank=proc_rank) if hvd is None: call_original_save_func() diff --git a/tensorflow_recommenders_addons/dynamic_embedding/python/kernel_tests/horovod_embedding_restrict_save_test.py b/tensorflow_recommenders_addons/dynamic_embedding/python/kernel_tests/horovod_embedding_restrict_save_test.py new file mode 100644 index 00000000..afbbd8d2 --- /dev/null +++ b/tensorflow_recommenders_addons/dynamic_embedding/python/kernel_tests/horovod_embedding_restrict_save_test.py @@ -0,0 +1,131 @@ +""" +unit tests of save model that uses HvdAllToAllEmbedding +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import shutil +from time import sleep + +import tensorflow as tf + +from tensorflow_recommenders_addons import dynamic_embedding as de + +from tensorflow.python.framework import dtypes +from tensorflow.python.framework.errors_impl import NotFoundError +from tensorflow.python.ops import math_ops +from tensorflow.python.platform import test + +try: + from tf_keras import layers, Sequential, models, backend + from tf_keras.initializers import Zeros + from tf_keras.optimizers import Adam +except: + from tensorflow.keras import layers, Sequential, models, backend + from tensorflow.keras.initializers import Zeros + try: + from tensorflow.keras.optimizers import Adam + except: + from tensorflow.keras.legacy.optimizers import Adam + + +def get_all_to_all_emb_model(emb_t, opt, *args, **kwargs): + l0 = layers.InputLayer(input_shape=(None,), dtype=dtypes.int64) + l1 = emb_t(*args, **kwargs) + l2 = layers.Dense(8, 'relu', kernel_initializer='zeros') + l3 = layers.Dense(1, 'sigmoid', kernel_initializer='zeros') + if emb_t == de.keras.layers.HvdAllToAllEmbedding: + model = Sequential([l0, l1, l2, l3]) + else: + raise TypeError('Unsupported embedding layer {}'.format(emb_t)) + + model.compile(optimizer=opt, loss='mean_absolute_error') + return model + + +class HorovodAllToAllRestrictPolicyTest(test.TestCase): + + def test_all_to_all_embedding_restrict_policy_save(self): + try: + import horovod.tensorflow as hvd + except (NotFoundError): + self.skipTest( + "Skip the test for horovod import error with Tensorflow-2.7.0 on MacOS-12." + ) + + hvd.init() + + name = "all2all_emb" + keras_base_opt = Adam(1.0) + base_opt = de.DynamicEmbeddingOptimizer(keras_base_opt, synchronous=True) + + init = Zeros() + kv_creator = de.CuckooHashTableCreator( + saver=de.FileSystemSaver(proc_size=hvd.size(), proc_rank=hvd.rank())) + batch_size = 8 + start = 0 + dim = 10 + run_step = 10 + + save_dir = "/tmp/hvd_distributed_restrict_policy_save" + str( + hvd.size()) + str(dim) # All ranks should share same save directory + + base_model = get_all_to_all_emb_model( + de.keras.layers.HvdAllToAllEmbedding, + base_opt, + embedding_size=dim, + initializer=init, + bp_v2=False, + kv_creator=kv_creator, + restrict_policy=de. + TimestampRestrictPolicy, # Embedding table with restrict policy + name='all2all_emb') + + for i in range(1, run_step): + x = math_ops.range(start, start + batch_size, dtype=dtypes.int64) + x = tf.reshape(x, (batch_size, -1)) + start += batch_size + y = tf.zeros((batch_size, 1), dtype=dtypes.float32) + base_model.fit(x, y, verbose=0) + + save_options = tf.saved_model.SaveOptions(namespace_whitelist=['TFRA']) + if hvd.rank() == 0: + if os.path.exists(save_dir): + shutil.rmtree(save_dir) + hvd.join() # Sync for avoiding files conflict + base_model.save(save_dir, options=save_options) + de.keras.models.save_model(base_model, save_dir, options=save_options) + + sleep(4) # Wait for filesystem operation + hvd_size = hvd.size() + if hvd_size <= 1: + hvd_size = 1 + base_dir = os.path.join(save_dir, "variables", "TFRADynamicEmbedding") + for tag in ['keys', 'values']: + for rank in range(hvd_size): + self.assertTrue( + os.path.exists( + base_dir + + f'/{name}-parameter_mht_1of1_rank{rank}_size{hvd_size}-{tag}')) + self.assertTrue( + os.path.exists( + base_dir + + f'/{name}-parameter_DynamicEmbedding_{name}-shadow_m_mht_1of1_rank{rank}_size{hvd_size}-{tag}' + )) + self.assertTrue( + os.path.exists( + base_dir + + f'/{name}-parameter_DynamicEmbedding_{name}-shadow_v_mht_1of1_rank{rank}_size{hvd_size}-{tag}' + )) + # Restrict policy var saved for all ranks + self.assertTrue( + os.path.exists( + base_dir + + f'/{name}-parameter_timestamp_mht_1of1_rank{rank}_size{hvd_size}-{tag}' + )) + + +if __name__ == "__main__": + test.main() diff --git a/tools/docker/cpu_tests.Dockerfile b/tools/docker/cpu_tests.Dockerfile index 93adefe9..210274ea 100644 --- a/tools/docker/cpu_tests.Dockerfile +++ b/tools/docker/cpu_tests.Dockerfile @@ -35,7 +35,7 @@ RUN python configure.py RUN pip install -e ./ RUN --mount=type=cache,id=cache_bazel,target=/root/.cache/bazel \ bash tools/install_so_files.sh -RUN pytest -v -s -n auto --durations=25 --ignore-glob="*/hkv_hashtable_ops_test.py" --doctest-modules ./tensorflow_recommenders_addons \ +RUN pytest -v -s -n auto --durations=25 --ignore-glob="*/hkv_hashtable_ops_test.py" --ignore-glob="*/horovod_embedding_restrict_save_test.py" --doctest-modules ./tensorflow_recommenders_addons \ --cov=tensorflow_recommenders_addons ./tensorflow_recommenders_addons/ RUN bazel build --enable_runfiles build_pip_pkg diff --git a/tools/testing/build_and_run_tests.sh b/tools/testing/build_and_run_tests.sh index 9d2a6553..d587dc91 100644 --- a/tools/testing/build_and_run_tests.sh +++ b/tools/testing/build_and_run_tests.sh @@ -60,7 +60,7 @@ if [ "$TF_NEED_CUDA" -ne 0 ]; then bash /install/install_horovod.sh $HOROVOD_VERSION --only-cpu fi # TODO(jamesrong): Test on GPU. - CUDA_VISIBLE_DEVICES="" mpirun -np 2 -H localhost:2 --allow-run-as-root pytest -v ./tensorflow_recommenders_addons/dynamic_embedding/python/kernel_tests/horovod_sync_train_test.py + CUDA_VISIBLE_DEVICES="" mpirun -np 2 -H localhost:2 --allow-run-as-root pytest -v ./tensorflow_recommenders_addons/dynamic_embedding/python/kernel_tests/horovod_sync_train_test.py ./tensorflow_recommenders_addons/dynamic_embedding/python/kernel_tests/horovod_embedding_restrict_save_test.py # Reinstall Horovod after tests if [ "$(uname)" != "Darwin" ]; then # Mac only with MPI @@ -74,12 +74,15 @@ if [ "$TF_NEED_CUDA" -eq 0 ]; then IGNORE_HKV="--ignore=./tensorflow_recommenders_addons/dynamic_embedding/python/kernel_tests/hkv_hashtable_ops_test.py" fi +# Test only with horovod on GPU +IGNORE_HOROVOD_DIST_TRAINING_TEST="--ignore=./tensorflow_recommenders_addons/dynamic_embedding/python/kernel_tests/horovod_embedding_restrict_save_test.py" + # Only use GPU 0 if available. if [ -x "$(command -v nvidia-smi)" ]; then export CUDA_VISIBLE_DEVICES=0 fi -python -m pytest -v -s --functions-durations=20 --modules-durations=5 $IGNORE_HKV $SKIP_CUSTOM_OP_TESTS_FLAG $EXTRA_ARGS ./tensorflow_recommenders_addons/dynamic_embedding/python/kernel_tests/ +python -m pytest -v -s --functions-durations=20 --modules-durations=5 $IGNORE_HKV $IGNORE_HOROVOD_DIST_TRAINING_TEST $SKIP_CUSTOM_OP_TESTS_FLAG $EXTRA_ARGS ./tensorflow_recommenders_addons/dynamic_embedding/python/kernel_tests/ # Release disk space bazel clean --expunge