Skip to content

Commit e8abbe6

Browse files
Add test-case for restrict policy save
1 parent e53c069 commit e8abbe6

File tree

2 files changed

+132
-0
lines changed

2 files changed

+132
-0
lines changed

pytest.txt

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
pytest~=6.2.5
2+
pytest-xdist~=1.31
3+
pytest-extra-durations~=0.1.3
4+
scikit-learn<=1.2.2
5+
scikit-image<=0.20.0
6+
Pillow~=9.4.0
7+
tqdm>=4.36.1
Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
"""
2+
Unit tests for saving a model that uses HvdAllToAllEmbedding with a restrict policy.
3+
"""
4+
from __future__ import absolute_import
5+
from __future__ import division
6+
from __future__ import print_function
7+
8+
import os
9+
import shutil
10+
from time import sleep
11+
12+
import tensorflow as tf
13+
14+
from tensorflow_recommenders_addons import dynamic_embedding as de
15+
16+
from tensorflow.python.framework import dtypes
17+
from tensorflow.python.framework.errors_impl import NotFoundError
18+
from tensorflow.python.ops import math_ops
19+
from tensorflow.python.platform import test
20+
21+
# Keras symbols are taken from the `tf_keras` package when it can be imported,
# and from `tensorflow.keras` otherwise.
# `except Exception` (rather than a bare `except`) keeps the fallback
# best-effort without also swallowing KeyboardInterrupt / SystemExit.
try:
  from tf_keras import layers, Sequential, models, backend
  from tf_keras.initializers import Zeros
  from tf_keras.optimizers import Adam
except Exception:  # pylint: disable=broad-except
  from tensorflow.keras import layers, Sequential, models, backend
  from tensorflow.keras.initializers import Zeros
  try:
    from tensorflow.keras.optimizers import Adam
  except Exception:  # pylint: disable=broad-except
    # The legacy optimizers are exposed as `optimizers.legacy`; there is no
    # `tensorflow.keras.legacy.optimizers` module.
    from tensorflow.keras.optimizers.legacy import Adam
32+
33+
34+
def get_all_to_all_emb_model(emb_t, opt, *args, **kwargs):
  """Build and compile a small Sequential model around an embedding layer.

  The model is: int64 input of shape (None,) -> `emb_t(*args, **kwargs)` ->
  Dense(8, relu) -> Dense(1, sigmoid), with both Dense kernels initialized to
  zeros, compiled with the 'mean_absolute_error' loss.

  Args:
    emb_t: Embedding layer class. Only `de.keras.layers.HvdAllToAllEmbedding`
      is accepted.
    opt: Optimizer passed to `model.compile`.
    *args: Positional arguments forwarded to the `emb_t` constructor.
    **kwargs: Keyword arguments forwarded to the `emb_t` constructor.

  Returns:
    The compiled Sequential model.

  Raises:
    TypeError: If `emb_t` is not `de.keras.layers.HvdAllToAllEmbedding`.
  """
  # Fail fast: reject an unsupported layer type before any layer is built, so
  # the caller gets this TypeError rather than whatever instantiating an
  # unknown `emb_t` would do.
  if emb_t != de.keras.layers.HvdAllToAllEmbedding:
    raise TypeError('Unsupported embedding layer {}'.format(emb_t))

  model = Sequential([
      layers.InputLayer(input_shape=(None,), dtype=dtypes.int64),
      emb_t(*args, **kwargs),
      layers.Dense(8, 'relu', kernel_initializer='zeros'),
      layers.Dense(1, 'sigmoid', kernel_initializer='zeros'),
  ])
  model.compile(optimizer=opt, loss='mean_absolute_error')
  return model
46+
47+
48+
class HorovodAllToAllRestrictPolicyTest(test.TestCase):
  """Save test for an HvdAllToAllEmbedding model built with a restrict policy."""

  def test_all_to_all_embedding_restrict_policy_save(self):
    """Train briefly, save the model, and check the per-rank table files.

    For every rank and for both the 'keys' and 'values' tags, the test expects
    files for the embedding parameter table, the two optimizer slot tables
    (`shadow_m`, `shadow_v`) and the restrict-policy `timestamp` table under
    `<save_dir>/variables/TFRADynamicEmbedding`.
    """
    # Only a TensorFlow NotFoundError raised by the import is turned into a
    # skip; any other import failure still fails the test.
    try:
      import horovod.tensorflow as hvd
    except NotFoundError:
      self.skipTest(
          "Skip the test for horovod import error with Tensorflow-2.7.0 on MacOS-12."
      )

    hvd.init()

    # Single source of truth for the layer name: it is passed to the layer and
    # also used to build the expected file names below.
    name = "all2all_emb"
    keras_base_opt = Adam(1.0)
    base_opt = de.DynamicEmbeddingOptimizer(keras_base_opt, synchronous=True)

    init = Zeros()
    kv_creator = de.CuckooHashTableCreator(
        saver=de.FileSystemSaver(proc_size=hvd.size(), proc_rank=hvd.rank()))
    batch_size = 8
    start = 0
    dim = 10
    run_step = 10

    # All ranks should share same save directory
    save_dir = "/tmp/hvd_distributed_restrict_policy_save" + str(
        hvd.size()) + str(dim)

    base_model = get_all_to_all_emb_model(
        de.keras.layers.HvdAllToAllEmbedding,
        base_opt,
        embedding_size=dim,
        initializer=init,
        bp_v2=False,
        kv_creator=kv_creator,
        restrict_policy=de.
        TimestampRestrictPolicy,  # Embedding table with restrict policy
        name=name)

    # Each step feeds a fresh, consecutive block of `batch_size` int64 ids.
    for _ in range(1, run_step):
      x = math_ops.range(start, start + batch_size, dtype=dtypes.int64)
      x = tf.reshape(x, (batch_size, -1))
      start += batch_size
      y = tf.zeros((batch_size, 1), dtype=dtypes.float32)
      base_model.fit(x, y, verbose=0)

    save_options = tf.saved_model.SaveOptions(namespace_whitelist=['TFRA'])
    # Only rank 0 clears a stale directory left by a previous run.
    if hvd.rank() == 0 and os.path.exists(save_dir):
      shutil.rmtree(save_dir)
    hvd.join()  # Sync for avoiding files conflict
    base_model.save(save_dir, options=save_options)
    de.keras.models.save_model(base_model, save_dir, options=save_options)

    # Every rank asserts on the files of *all* ranks below, so wait until all
    # ranks have finished saving instead of relying on the sleep alone.
    hvd.join()
    sleep(4)  # Wait for filesystem operation

    hvd_size = max(hvd.size(), 1)
    base_dir = os.path.join(save_dir, "variables", "TFRADynamicEmbedding")
    table_prefixes = (
        f'{name}-parameter',
        f'{name}-parameter_DynamicEmbedding_{name}-shadow_m',
        f'{name}-parameter_DynamicEmbedding_{name}-shadow_v',
        # Restrict policy var saved for all ranks
        f'{name}-parameter_timestamp',
    )
    for tag in ['keys', 'values']:
      for rank in range(hvd_size):
        for prefix in table_prefixes:
          path = os.path.join(
              base_dir, f'{prefix}_mht_1of1_rank{rank}_size{hvd_size}-{tag}')
          self.assertTrue(os.path.exists(path),
                          msg='Missing saved table file: {}'.format(path))
122+
123+
124+
# Run the test cases through the TensorFlow test runner when this file is
# executed as a script.
if __name__ == "__main__":
  test.main()

0 commit comments

Comments
 (0)