
Conversation

jatinsharechat

Description

  • Issue
    • The current saving logic only saves restrict policy parameters for rank == 0.
    • For rank != 0, the restrict_var is not restored, leading to an unsynchronized state between the embedding variable and restrict_var.
    • As a result, the restrict policy does not work correctly in distributed training.
  • Fixes
    • Added _maybe_save_restrict_policy_params function:
      • Checks if de_var has a restrict_policy.
      • Saves the associated restrict policy variable to the file system for each rank.
    • Updated _traverse_emb_layers_and_save to:
      • Call _maybe_save_restrict_policy_params for each distributed embedding variable (de_var).
      • Ensure restrict policy parameters are saved and restored for all Horovod ranks (hvd.rank()).

Type of change

  • Bug fix
  • New Tutorial
  • Updated or additional documentation
  • Additional Testing
  • New Feature

Checklist:

  • I've properly formatted my code according to the guidelines
    • By running yapf
    • By running clang-format
  • This PR addresses an already submitted issue for TensorFlow Recommenders-Addons
  • I have made corresponding changes to the documentation
  • I have added tests that prove my fix is effective or that my feature works

How Has This Been Tested?

  • Wrote a test case, tensorflow_recommenders_addons/dynamic_embedding/python/kernel_tests/horovod_embedding_restrict_save_test.py, that trains a dummy model for a few steps and then saves it.
  • The model's embedding table is created with a restrict policy.
  • The test case checks that the restrict policy is saved for each rank.
  • To simulate a distributed environment, the test case was run with horovodrun in a CPU-based distributed setup:
    $ horovodrun -np 2 python tensorflow_recommenders_addons/dynamic_embedding/python/kernel_tests/horovod_distributed_restrict_policy_save.py
    
  • Output of saved-model from above test-case:
    $ ls /tmp/hvd_distributed_restrict_policy_save_timestamp210/variables/TFRADynamicEmbedding/
    all2all_emb-parameter_DynamicEmbedding_all2all_emb-shadow_m_mht_1of1_rank0_size2-keys    all2all_emb-parameter_mht_1of1_rank0_size2-keys
    all2all_emb-parameter_DynamicEmbedding_all2all_emb-shadow_m_mht_1of1_rank0_size2-values  all2all_emb-parameter_mht_1of1_rank0_size2-values
    all2all_emb-parameter_DynamicEmbedding_all2all_emb-shadow_m_mht_1of1_rank1_size2-keys    all2all_emb-parameter_mht_1of1_rank1_size2-keys
    all2all_emb-parameter_DynamicEmbedding_all2all_emb-shadow_m_mht_1of1_rank1_size2-values  all2all_emb-parameter_mht_1of1_rank1_size2-values
    all2all_emb-parameter_DynamicEmbedding_all2all_emb-shadow_v_mht_1of1_rank0_size2-keys    all2all_emb-parameter_timestamp_mht_1of1_rank0_size2-keys
    all2all_emb-parameter_DynamicEmbedding_all2all_emb-shadow_v_mht_1of1_rank0_size2-values  all2all_emb-parameter_timestamp_mht_1of1_rank0_size2-values
    all2all_emb-parameter_DynamicEmbedding_all2all_emb-shadow_v_mht_1of1_rank1_size2-keys    all2all_emb-parameter_timestamp_mht_1of1_rank1_size2-keys
    all2all_emb-parameter_DynamicEmbedding_all2all_emb-shadow_v_mht_1of1_rank1_size2-values  all2all_emb-parameter_timestamp_mht_1of1_rank1_size2-values
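
The per-rank check the test case performs can be sketched with the stdlib alone: given the save directory listed above, assert that the restrict policy's status files (the `timestamp` keys/values pair in this listing) exist for every rank. The file-name pattern below is inferred from the `ls` output and is illustrative only.

```python
# Minimal sketch of the per-rank assertion: every rank must have written
# a -keys and a -values file for the restrict policy's timestamp table.
# The glob pattern is inferred from the saved-model listing above.
import glob
import os


def restrict_policy_saved_for_all_ranks(save_dir, world_size):
  """Return True if every rank wrote its restrict-policy status files."""
  for rank in range(world_size):
    pattern = os.path.join(
        save_dir, f"*timestamp_mht_1of1_rank{rank}_size{world_size}-*")
    # Expect at least the -keys and -values files for this rank.
    if len(glob.glob(pattern)) < 2:
      return False
  return True
```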
    
    

@jatinsharechat jatinsharechat requested a review from rhdong as a code owner April 3, 2025 09:58

google-cla bot commented Apr 3, 2025

Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).

View this failed invocation of the CLA check for more information.

For the most up to date status, view the checks section at the bottom of the pull request.

@jatinsharechat
Author

Hey @rhdong, @jq, @MoFHeka, if I can get some help with review of this would be great! Thanks

@rhdong rhdong requested review from jq and MoFHeka and removed request for rhdong April 10, 2025 18:39
@rhdong
Member

rhdong commented Apr 10, 2025

Hey @rhdong, @jq, @MoFHeka, if I can get some help with review of this would be great! Thanks

Hi @jatinsharechat , thanks for your contribution! The CLA needed to be signed; please follow the guidance: https://github.com/tensorflow/recommenders-addons/pull/491/checks?check_run_id=39908185786. cc @jq @MoFHeka

@jatinsharechat
Author

Hey @rhdong, @jq, @MoFHeka, if I can get some help with review of this would be great! Thanks

Hi @jatinsharechat , thanks for your contribution! The CLA needed to be signed; please follow the guidance: https://github.com/tensorflow/recommenders-addons/pull/491/checks?check_run_id=39908185786. cc @jq @MoFHeka

I've signed off the CLA and the rescan is green.

rhdong
rhdong previously approved these changes Apr 10, 2025
Member

@rhdong rhdong left a comment


Trigger CI

@jatinsharechat jatinsharechat requested a review from rhdong April 11, 2025 05:19
@jq
Collaborator

jq commented Apr 15, 2025

The code format check is failing; you may run yapf.


def _save_de_model(self, filepath):

def _maybe_save_restrict_policy_params(de_var, proc_size=1, proc_rank=0):
Collaborator


use one _maybe_save_restrict_policy_params?

Author


Since the code is pretty minimal and just calls de_var.save_to_file_system under the hood, I thought it might be okay to duplicate the function.

Any suggestions on where to move the util function so it can be shared between the two? Just import _maybe_save_restrict_policy_params from tensorflow_recommenders_addons.dynamic_embedding.python.keras.models in callbacks.py, or something else?

@jatinsharechat jatinsharechat requested a review from jq April 21, 2025 05:15
@jatinsharechat
Copy link
Author

Gentle ping on this @jq @rhdong
