
Commit f325fbb

committed
ut for ps
1 parent b3bc3d4 commit f325fbb

File tree

13 files changed: +390 -64 lines changed
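In short: this commit adds a multi-process ParameterServerStrategy unit test for dynamic_embedding (runnable under bazel only), teaches the Keras Embedding layer to take a parameter-server-aware path when building its shadow variables, and updates the movielens-1m PS demo and the Horovod broadcast callback accordingly.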

demo/dynamic_embedding/movielens-1m-keras-ps/movielens-1m-keras-ps.py

Lines changed: 21 additions & 7 deletions
@@ -2,14 +2,15 @@
 import tensorflow as tf
 import tensorflow_datasets as tfds
 
-from absl import flags
-from absl import app
 from tensorflow_recommenders_addons import dynamic_embedding as de
+
 try:
   from tensorflow.keras.optimizers.legacy import Adam
 except:
   from tensorflow.keras.optimizers import Adam
 
+from tensorflow import distribute as tf_dist
+
 flags = tf.compat.v1.app.flags
 FLAGS = flags.FLAGS
 flags.DEFINE_string(
@@ -34,6 +35,18 @@
     ], dtype=tf.int64, name='movie_id')
 }
 
+gpus = tf.config.list_physical_devices('GPU')
+if gpus:
+  try:
+    # Currently, memory growth needs to be the same across GPUs
+    for gpu in gpus:
+      tf.config.experimental.set_memory_growth(gpu, True)
+    logical_gpus = tf.config.list_logical_devices('GPU')
+    print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs")
+  except RuntimeError as e:
+    # Memory growth must be set before GPUs have been initialized
+    print(e)
+
 
 class DualChannelsDeepModel(tf.keras.Model):
 
@@ -59,11 +72,13 @@ def __init__(self,
         user_embedding_size,
         initializer=embedding_initializer,
         devices=self.devices,
+        with_unique=False,
         name='user_embedding')
     self.movie_embedding = de.keras.layers.SquashedEmbedding(
         movie_embedding_size,
         initializer=embedding_initializer,
         devices=self.devices,
+        with_unique=False,
         name='movie_embedding')
 
     self.dnn1 = tf.keras.layers.Dense(
@@ -94,7 +109,6 @@ def call(self, features):
     user_latent = self.user_embedding(user_id)
     movie_latent = self.movie_embedding(movie_id)
     latent = tf.concat([user_latent, movie_latent], axis=1)
-
     x = self.dnn1(latent)
     x = self.dnn2(x)
     x = self.dnn3(x)
@@ -208,6 +222,7 @@ def test(self):
 
     dataset = self.get_dataset(batch_size=self.test_bs)
     dataset = self.strategy.experimental_distribute_dataset(dataset)
+
     with self.strategy.scope():
       model = tf.keras.models.load_model(self.export_dir)
       signature = model.signatures['serving_default']
@@ -237,13 +252,12 @@ def start_chief(config):
   cluster_spec = tf.train.ClusterSpec(config["cluster"])
   cluster_resolver = tf.distribute.cluster_resolver.SimpleClusterResolver(
       cluster_spec, task_type="chief", task_id=0)
-  strategy = tf.distribute.experimental.ParameterServerStrategy(
-      cluster_resolver)
+  strategy = tf_dist.experimental.ParameterServerStrategy(cluster_resolver)
   runner = Runner(strategy=strategy,
                   train_bs=64,
                   test_bs=1,
-                  epochs=2,
-                  steps_per_epoch=10,
+                  epochs=1,
+                  steps_per_epoch=1000,
                   model_dir=None,
                   export_dir=None)
   runner.train()
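Note: `start_chief(config)` builds its ParameterServerStrategy from `config["cluster"]`. A minimal sketch of the cluster layout the demo expects; the hosts and ports below are illustrative placeholders, not values from this commit:

# Hypothetical cluster spec for the movielens PS demo (addresses are
# illustrative only).
config = {
    "cluster": {
        "chief": ["localhost:2220"],
        "ps": ["localhost:2222"],
        "worker": ["localhost:2223", "localhost:2224"],
    }
}
# start_chief(config) then resolves this spec:
#   cluster_spec = tf.train.ClusterSpec(config["cluster"])
#   strategy = tf_dist.experimental.ParameterServerStrategy(cluster_resolver)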

tensorflow_recommenders_addons/dynamic_embedding/python/keras/callbacks.py

Lines changed: 2 additions & 2 deletions
@@ -68,9 +68,9 @@ def on_batch_end(self, batch, logs=None):
     with ops.device(self.device):
       if hvd._executing_eagerly() and hasattr(self.model, 'variables'):
         # TensorFlow 2.0 or TensorFlow eager
+        from tensorflow_recommenders_addons.dynamic_embedding.python.ops.shadow_embedding_ops import is_de_resource_variable
         filter_lambda = lambda x: (x.ref() not in self._local_vars) and (
-            not isinstance(x, de.TrainableWrapper)) and (not isinstance(
-                x, de.DEResourceVariable))
+            not is_de_resource_variable(x))
         broadcast_vars = [
             var for var in self.model.variables if filter_lambda(var)
         ]
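The inlined pair of isinstance checks is replaced by `is_de_resource_variable` from shadow_embedding_ops, so the "skip DE variables during Horovod broadcast" rule lives in one place. A rough sketch of what such a predicate plausibly covers, assuming it centralizes the two checks removed here (the real helper may also handle distributed wrappers):

# Sketch only; the actual implementation lives in
# dynamic_embedding/python/ops/shadow_embedding_ops.py.
def is_de_resource_variable(var):
  # DE trainable shadow types must not be broadcast across Horovod ranks.
  return isinstance(var, (de.TrainableWrapper, de.DEResourceVariable))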

tensorflow_recommenders_addons/dynamic_embedding/python/keras/layers/embedding.py

Lines changed: 27 additions & 13 deletions
@@ -17,7 +17,6 @@
 Dynamic Embedding is designed for Large-scale Sparse Weights Training.
 See [Sparse Domain Isolation](https://github.com/tensorflow/community/pull/237)
 """
-
 from packaging import version
 
 import tensorflow as tf
@@ -29,6 +28,8 @@
 from tensorflow.python.keras.utils import tf_utils
 
 from tensorflow_recommenders_addons.dynamic_embedding.python.ops.shadow_embedding_ops import HvdVariable
+from tensorflow_recommenders_addons.dynamic_embedding.python.train.utils import \
+    is_parameter_server_strategy
 
 if version.parse(tf.__version__) >= version.parse("2.14"):
   from tensorflow.python.distribute import distribute_lib as distribute_ctx
@@ -225,7 +226,8 @@ def __init__(self,
     shadow_name = name + '-shadow' if name else 'ShadowVariable'
     if distribute_ctx.has_strategy():
       self.distribute_strategy = distribute_ctx.get_strategy()
-    if self.distribute_strategy:
+    if self.distribute_strategy and not is_parameter_server_strategy(
+        self.distribute_strategy):
       strategy_devices = self.distribute_strategy.extended.worker_devices
       self.shadow_impl = tf_utils.ListWrapper([])
       for i, strategy_device in enumerate(strategy_devices):
@@ -242,12 +244,23 @@ def __init__(self,
                     trainable=trainable,
                     distribute_strategy=self.distribute_strategy))
     else:
-      self.shadow_impl = tf_utils.ListWrapper([
-          de.shadow_ops.ShadowVariable(self.params,
-                                       name=shadow_name,
-                                       max_norm=self.max_norm,
-                                       trainable=trainable)
-      ])
+      if is_parameter_server_strategy(self.distribute_strategy):
+        self.shadow_impl = tf_utils.ListWrapper([
+            de.shadow_ops.ShadowVariable(
+                self.params,
+                name=shadow_name,
+                max_norm=self.max_norm,
+                distribute_strategy=self.distribute_strategy,
+                trainable=trainable)
+        ])
+      else:
+        self.shadow_impl = tf_utils.ListWrapper([
+            de.shadow_ops.ShadowVariable(self.params,
+                                         name=shadow_name,
+                                         max_norm=self.max_norm,
+                                         trainable=trainable)
+        ])
+
     if len(self.shadow_impl.as_list()) > 1:
       self._current_ids = data_structures.NoDependency(
           [shadow_i.ids for shadow_i in self.shadow_impl.as_list()])
@@ -261,24 +274,25 @@ def __init__(self,
       self._current_exists = data_structures.NoDependency(
           self.shadow_impl.as_list()[0].exists)
       self.optimizer_vars = self.shadow_impl.as_list()[0]._optimizer_vars
-    if distribute_ctx.has_strategy(
-    ) and self.distribute_strategy and 'OneDeviceStrategy' not in str(
-        self.distribute_strategy) and not values_util.is_saving_non_distributed(
-    ) and values_util.get_current_replica_id_as_int() is not None:
+    if distribute_ctx.has_strategy() and self.distribute_strategy and \
+        'OneDeviceStrategy' not in str(self.distribute_strategy) and \
+        not values_util.is_saving_non_distributed() and \
+        values_util.get_current_replica_id_as_int() is not None:
       self.shadow = de.DistributedVariableWrapper(
           self.distribute_strategy, self.shadow_impl.as_list(),
           VariableAggregation.NONE,
           TrainableWrapperDistributedPolicy(VariableAggregation.NONE))
     else:
       self.shadow = self.shadow_impl.as_list()[0]
+
     self.params._created_in_class = self  # To facilitate access to the primitive class through params
     super(Embedding, self).__init__(name=name,
                                     trainable=trainable,
                                     dtype=value_dtype)
 
   def call(self, ids):
     """
-    Compute embedding output for feature ids. The output shape will be (shape(ids), 
+    Compute embedding output for feature ids. The output shape will be (shape(ids),
     embedding_size).
 
     Args:
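Under a ParameterServerStrategy the layer now creates a single ShadowVariable with the strategy attached, instead of one shadow per worker device. The gate is `is_parameter_server_strategy` from `dynamic_embedding/python/train/utils`; a minimal sketch of such a check, assuming a simple class-name test:

# Assumption: the real helper may walk the strategy's class hierarchy
# instead of matching on the type name.
def is_parameter_server_strategy(strategy) -> bool:
  return 'ParameterServerStrategy' in type(strategy).__name__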

tensorflow_recommenders_addons/dynamic_embedding/python/kernel_tests/BUILD

Lines changed: 12 additions & 0 deletions
@@ -72,6 +72,18 @@ py_test(
     ],
 )
 
+# This test is not for pytest, it requires
+# bazel test //tensorflow_recommenders_addons/dynamic_embedding/python/kernel_tests:parameter_server_bzl
+py_test(
+    name = "parameter_server_bzl",
+    srcs = ["parameter_server_bzl.py"],
+    python_version = "PY3",
+    srcs_version = "PY3",
+    deps = [
+        "//tensorflow_recommenders_addons",
+    ],
+)
+
 # This test will be banned by GitHub and cause account violations, please run the test manually locally.
 # py_test(
 #     name = "redis_table_variable_test",
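Why bazel-only: the test added below spins up a MultiProcessCluster, and its spawned ps/worker subprocesses must re-exec the test binary. The `custom_set_spawn_exe_path` hook in the test locates that binary through bazel's `TEST_TARGET` environment variable and `bazel-out` paths, neither of which exists when running under pytest.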
tensorflow_recommenders_addons/dynamic_embedding/python/kernel_tests/parameter_server_bzl.py

Lines changed: 183 additions & 0 deletions
@@ -0,0 +1,183 @@
+# pytest: skip
+import os
+import sys
+
+from absl.testing import parameterized
+from tensorflow.python.distribute import multi_process_lib
+import multiprocessing
+import tensorflow as tf
+from tensorflow.python.framework import constant_op
+
+from tensorflow.python.training import server_lib
+
+from tensorflow_recommenders_addons import dynamic_embedding as de
+
+import numpy as np
+from tensorflow.python.compat import v2_compat
+from tensorflow.python.distribute import multi_process_runner
+from tensorflow.python.distribute import multi_worker_test_base
+from tensorflow.python.distribute import parameter_server_strategy_v2
+from tensorflow.python.distribute.cluster_resolver import cluster_resolver as cluster_resolver_lib
+
+from tensorflow.python.eager import test
+from packaging import version
+from tensorflow.python.distribute.coordinator import cluster_coordinator as coordinator_lib
+from tensorflow.python.eager import def_function
+from tensorflow.python.ops import variables
+
+if version.parse(tf.__version__) >= version.parse("2.16"):
+  from tf_keras import layers
+  from tf_keras import Sequential
+  from tf_keras.optimizers import Adam
+else:
+  from tensorflow.python.keras import layers
+  from tensorflow.python.keras import Sequential
+  try:
+    from tensorflow.keras.optimizers import Adam
+  except:
+    from tensorflow.keras.optimizers.legacy import Adam
+
+
+def create_multi_process_cluster(cluster_spec,
+                                 rpc_layer='grpc',
+                                 stream_output=False,
+                                 collective_leader=None):
+
+  cluster = multi_worker_test_base.MultiProcessCluster(
+      cluster_resolver_lib.SimpleClusterResolver(
+          server_lib.ClusterSpec(cluster_spec), rpc_layer=rpc_layer),
+      stream_output=stream_output,
+      collective_leader=collective_leader)
+  cluster.start()
+  return cluster
+
+
+class ParameterServerStrategyV2Test(test.TestCase):
+
+  @classmethod
+  def setUpClass(cls):
+    super(ParameterServerStrategyV2Test, cls).setUpClass()
+    cluster_spec = {
+        "worker": ["localhost:2223", "localhost:2224"],
+        "ps": ["localhost:2222"]
+    }
+    cls.cluster = create_multi_process_cluster(cluster_spec)
+    cls.cluster_resolver = cls.cluster.cluster_resolver
+    # cls.strategy = DEParameterServerStrategy(cls.cluster_resolver)
+    cls.strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
+        cls.cluster_resolver)
+    cls.coordinator = coordinator_lib.ClusterCoordinator(cls.strategy)
+
+  @classmethod
+  def tearDownClass(cls):
+    super(ParameterServerStrategyV2Test, cls).tearDownClass()
+    cls.cluster.stop()
+
+  #@parameterized.parameters(True, False)
+  def testPerWorkerVariableCreation(self):
+    var_dtype = tf.dtypes.float32
+    var_name = 'var'
+    shape = [1]  #if define_shape else None
+
+    # with self.strategy.scope():
+    var = variables.Variable(initial_value=[0.0],
+                             shape=shape,
+                             dtype=var_dtype,
+                             name=var_name,
+                             per_worker_de_variable=True)
+
+    # Use per-worker variable as a capture
+    @def_function.function
+    def worker_fn():
+      var.assign_add(constant_op.constant([1.0]))
+      return var
+
+    num_closures = 10
+    for ix in range(num_closures):
+      self.coordinator.schedule(worker_fn)
+      # Read the PWV many times to ensure result is up-to-date
+      self.coordinator.join()
+      result_sum = sum(var.read_all()).numpy()
+      self.assertEqual(result_sum, ix + 1)
+
+    for _ in range(num_closures):
+      self.coordinator.schedule(worker_fn)
+    self.coordinator.join()
+
+    # Verify placement of variables
+    devices = [wv._get_values().device for wv in var._per_worker_vars._values]
+    expected_devices = [
+        f'/job:worker/replica:0/task:{ix}/device:CPU:0'
+        for ix in range(self.strategy._num_workers)
+    ]  # pylint: disable=protected-access
+    self.assertAllEqual(devices, expected_devices)
+
+    result_sum = sum(var.read_all()).numpy()
+    self.assertEqual(result_sum, num_closures * 2)
+
+  def testKerasFit(self):
+    embed_dim = 8
+    with self.strategy.scope():
+      model = Sequential([
+          layers.Input(shape=(1,), dtype=tf.int32),
+          de.keras.layers.Embedding(embed_dim, key_dtype=tf.int32),
+          layers.Flatten(),
+          layers.Dense(1, activation='sigmoid')
+      ])
+      optimizer = Adam(1E-3)
+      optimizer = de.DynamicEmbeddingOptimizer(optimizer)
+      model.compile(loss='binary_crossentropy',
+                    optimizer=optimizer,
+                    metrics=['accuracy'])
+
+    ids = np.random.randint(0, 100, size=(64 * 2, 1))
+    labels = np.random.randint(0, 2, size=(64 * 2, 1))
+
+    def dataset_fn(input_context):
+      global_batch_size = 32
+      batch_size = input_context.get_per_replica_batch_size(global_batch_size)
+      dataset = tf.data.Dataset.from_tensor_slices((ids, labels))
+      dataset = dataset.shard(input_context.num_input_pipelines,
+                              input_context.input_pipeline_id)
+      dataset = dataset.batch(batch_size).repeat()
+      return dataset
+
+    dataset = self.strategy.distribute_datasets_from_function(dataset_fn)
+
+    history = model.fit(dataset, epochs=1, steps_per_epoch=len(ids) // 64)
+    self.assertIn('loss', history.history)
+
+
+# borrow from multi_process_lib._set_spawn_exe_path and modify it for tf_recommenders_addons
+def custom_set_spawn_exe_path():
+  if sys.argv[0].endswith('.py'):
+
+    def guess_path(package_root):
+      # If all we have is a python module path, we'll need to make a guess for
+      # the actual executable path.
+      if 'bazel-out' in sys.argv[0] and package_root in sys.argv[0]:
+        package_root_base = sys.argv[0][:sys.argv[0].rfind(package_root)]
+        binary = os.environ['TEST_TARGET'][2:].replace(':', '/', 1)
+        print(f"package_root_base {package_root_base} binary {binary}")
+        possible_path = os.path.join(package_root_base, package_root, binary)
+        print('Guessed test binary path: %s', possible_path)
+        if os.access(possible_path, os.X_OK):
+          return possible_path
+      return None
+
+    path = guess_path('tf_recommenders_addons')
+    if path is None:
+      print('Cannot determine binary path. sys.argv[0]=%s os.environ=%s',
+            sys.argv[0], os.environ)
+      raise RuntimeError('Cannot determine binary path')
+    sys.argv[0] = path
+  # Note that this sets the executable for *all* contexts.
+  multiprocessing.get_context().set_executable(sys.argv[0])
+
+
+# This is not for pytest
+# bazel test //tensorflow_recommenders_addons/dynamic_embedding/python/kernel_tests:parameter_server_bzl
+if __name__ == "__main__":
+  multi_process_lib._set_spawn_exe_path = custom_set_spawn_exe_path
+  v2_compat.enable_v2_behavior()
+  multi_process_runner.test_main()
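Two things worth noting about the new test. In testPerWorkerVariableCreation, each scheduled worker_fn increments the per-worker copy on whichever worker runs the closure, and var.read_all() gathers one tensor per worker; so after the first loop of ten joined closures the per-worker values sum to 10, and after ten more they sum to 20, which is exactly what the assertions check. And the custom_set_spawn_exe_path override is what makes the bazel invocation work: it rewrites sys.argv[0] to the guessed test-binary path before multiprocessing spawns the ps/worker subprocesses.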

tensorflow_recommenders_addons/dynamic_embedding/python/ops/distributed_embedding_variable.py

Lines changed: 2 additions & 2 deletions
@@ -10,10 +10,10 @@ class DistributedVariableWrapper(EmbeddingWeights,
   def __init__(self, strategy, values, aggregation, var_policy=None):
     super(DistributedVariableWrapper, self).__init__(strategy, values,
                                                      aggregation, var_policy)
-    self.shadow = self._get_on_device_or_primary()
+    self._shadow = self._get_on_device_or_primary()
 
   def verify_embedding_weights(self, sparse_ids, sparse_weights=None):
-    EmbeddingWeights.verify_embedding_param_weights(self.shadow.params,
+    EmbeddingWeights.verify_embedding_param_weights(self._shadow.params,
                                                     sparse_ids, sparse_weights)
 
   def embedding_lookup(self,
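The rename from self.shadow to self._shadow makes the wrapper's on-device shadow private; a plausible motivation, not stated in the commit, is avoiding a clash with the public shadow attribute that the Keras Embedding layer assigns when it wraps these values in DistributedVariableWrapper.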
