Skip to content

Commit 9d20bac

Browse files
update
1 parent 3855831 commit 9d20bac

File tree

3 files changed

+9
-5
lines changed

3 files changed

+9
-5
lines changed

demo/dynamic_embedding/movielens-1m-keras-with-horovod/movielens-1m-keras-with-horovod.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,12 +29,13 @@ def has_horovod() -> bool:
2929
def config():
3030
# callback calls hvd.rank() so we need to initialize horovod here
3131
hvd.init()
32+
print("Size: ", hvd.size())
3233
if has_horovod():
3334
print("Horovod is enabled.")
3435
if hvd.rank() > 0:
3536
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
3637
# Horovod: pin GPU to be used to process local rank (one GPU per process)
37-
config_gpu(hvd.local_rank())
38+
# config_gpu(hvd.local_rank())
3839
else:
3940
config_gpu()
4041

@@ -460,8 +461,8 @@ def call(self, features):
460461
def get_dataset(batch_size=1):
461462
ds = tfds.load("movielens/1m-ratings",
462463
split="train",
463-
data_dir="~/dataset",
464-
download=True)
464+
# data_dir="~/dataset",
465+
download=False)
465466
features = ds.map(
466467
lambda x: {
467468
"movie_id":
Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#!/bin/bash
22
rm -rf ./export_dir
3-
gpu_num=$(nvidia-smi --query-gpu=name --format=csv,noheader | wc -l)
3+
#gpu_num=$(nvidia-smi --query-gpu=name --format=csv,noheader | wc -l)
4+
gpu_num=5
45
export gpu_num
56
horovodrun -np $gpu_num python movielens-1m-keras-with-horovod.py --mode="train" --model_dir="./model_dir" --export_dir="./export_dir" \
6-
--steps_per_epoch=${1:-20000} --shuffle=${2:-True}
7+
--steps_per_epoch=${1:-10} --shuffle=${2:-False}

tensorflow_recommenders_addons/dynamic_embedding/python/keras/models.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,8 @@ def _save_de_var(de_var, proc_size=1, proc_rank=0):
122122
proc_rank=proc_rank)
123123

124124
def _maybe_save_restrict_policy_params(var, proc_size=1, proc_rank=0):
125+
if not hasattr(var, "restrict_policy"):
126+
return
125127
if var.restrict_policy is not None:
126128
de_var = var.restrict_policy._restrict_var
127129
_save_de_var(de_var, proc_size=proc_size, proc_rank=proc_rank)

0 commit comments

Comments
 (0)