
Commit 9109fa8

ut for ps

1 parent b3bc3d4 commit 9109fa8

File tree: 8 files changed, +252 −15 lines

demo/dynamic_embedding/movielens-1m-keras-ps/movielens-1m-keras-ps.py
Lines changed: 24 additions & 8 deletions

@@ -10,6 +10,8 @@
 except:
   from tensorflow.keras.optimizers import Adam

+from tensorflow.python.distribute import parameter_server_strategy_v2
+
 flags = tf.compat.v1.app.flags
 FLAGS = flags.FLAGS
 flags.DEFINE_string(

@@ -34,6 +36,18 @@
     ], dtype=tf.int64, name='movie_id')
 }

+gpus = tf.config.list_physical_devices('GPU')
+if gpus:
+  try:
+    # Currently, memory growth needs to be the same across GPUs
+    for gpu in gpus:
+      tf.config.experimental.set_memory_growth(gpu, True)
+    logical_gpus = tf.config.list_logical_devices('GPU')
+    print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs")
+  except RuntimeError as e:
+    # Memory growth must be set before GPUs have been initialized
+    print(e)
+

 class DualChannelsDeepModel(tf.keras.Model):

@@ -91,15 +105,17 @@ def __init__(self,
   def call(self, features):
     user_id = tf.reshape(features['user_id'], (-1, 1))
     movie_id = tf.reshape(features['movie_id'], (-1, 1))
-    user_latent = self.user_embedding(user_id)
-    movie_latent = self.movie_embedding(movie_id)
-    latent = tf.concat([user_latent, movie_latent], axis=1)
+    printop = tf.print("partition_x4_index_key_outside ", tf.shape(user_id), user_id, tf.shape(movie_id), movie_id, output_stream=tf.compat.v1.logging.error)
+    with tf.control_dependencies([printop]):
+      user_latent = self.user_embedding(user_id)
+    # movie_latent = self.movie_embedding(movie_id)
+    # latent = tf.concat([user_latent, movie_latent], axis=1)

-    x = self.dnn1(latent)
+    x = self.dnn1(user_latent)
     x = self.dnn2(x)
     x = self.dnn3(x)

-    bias = self.bias_net(latent)
+    bias = self.bias_net(user_latent)
     x = 0.2 * x + 0.8 * bias
     return x

@@ -136,11 +152,11 @@ def get_dataset(self, batch_size=1):
     dataset = dataset.shuffle(4096, reshuffle_each_iteration=False)
     if batch_size > 1:
       dataset = dataset.batch(batch_size)
-    return dataset
+    return dataset  #.repeat()

   def train(self):
     dataset = self.get_dataset(batch_size=self.train_bs)
-    dataset = self.strategy.experimental_distribute_dataset(dataset)
+    # dataset = self.strategy.experimental_distribute_dataset(dataset)
     with self.strategy.scope():
       model = DualChannelsDeepModel(
           self.ps_devices, self.embedding_size, self.embedding_size,

@@ -237,7 +253,7 @@ def start_chief(config):
   cluster_spec = tf.train.ClusterSpec(config["cluster"])
   cluster_resolver = tf.distribute.cluster_resolver.SimpleClusterResolver(
       cluster_spec, task_type="chief", task_id=0)
-  strategy = tf.distribute.experimental.ParameterServerStrategy(
+  strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
       cluster_resolver)
   runner = Runner(strategy=strategy,
                   train_bs=64,
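
Aside on the debugging pattern used throughout this commit: a tf.print is anchored ahead of the op being observed via tf.control_dependencies, so the tensors get logged even in graph mode, where an un-anchored print op may never run if nothing fetches it. A minimal self-contained sketch of the same pattern (the tag string and logging stream match the commit; the toy one-hot lookup is only a stand-in):

import tensorflow as tf

@tf.function
def traced_lookup(ids):
  # Log shape and contents of `ids` before the lookup runs; the control
  # dependency guarantees the print executes ahead of any op created
  # inside the block.
  printop = tf.print("partition_x4_index_key_outside ", tf.shape(ids), ids,
                     output_stream=tf.compat.v1.logging.error)
  with tf.control_dependencies([printop]):
    # Stand-in for self.user_embedding(user_id).
    return tf.one_hot(tf.cast(ids, tf.int32), depth=4)

traced_lookup(tf.constant([[1], [3]], dtype=tf.int64))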
New file (shell launcher for the demo)
Lines changed: 13 additions & 0 deletions

+#!/bin/bash
+rm -rf ./ckpt
+sh stop.sh
+sleep 1
+python movielens-1m-keras-ps.py --ps_list="localhost:2220,localhost:2221" --worker_list="localhost:2231" --chief="localhost:2230" --task_mode="ps" --task_id=0 &
+sleep 1
+python movielens-1m-keras-ps.py --ps_list="localhost:2220,localhost:2221" --worker_list="localhost:2231" --chief="localhost:2230" --task_mode="ps" --task_id=1 &
+sleep 1
+python movielens-1m-keras-ps.py --ps_list="localhost:2220,localhost:2221" --worker_list="localhost:2231" --chief="localhost:2230" --task_mode="worker" --task_id=0 &
+sleep 1
+sleep 1
+python movielens-1m-keras-ps.py --ps_list="localhost:2220,localhost:2221" --worker_list="localhost:2231" --chief="localhost:2230" --task_mode="chief" --task_id=0
+echo "ok"
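
For context: the launcher starts the two parameter servers (ports 2220/2221) and the single worker (2231) in the background, then runs the chief (2230) in the foreground, matching the --ps_list/--worker_list/--chief flags defined in the demo script. The stop.sh it invokes first is not part of this commit; presumably it just kills processes left over from a previous run.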

tensorflow_recommenders_addons/dynamic_embedding/python/keras/layers/embedding.py
Lines changed: 5 additions & 1 deletion

@@ -226,10 +226,13 @@ def __init__(self,
     if distribute_ctx.has_strategy():
       self.distribute_strategy = distribute_ctx.get_strategy()
     if self.distribute_strategy:
-      strategy_devices = self.distribute_strategy.extended.worker_devices
+      # l = ["/job:ps/replica:0/task:0/device:CPU:0", "/job:ps/replica:0/task:1/device:CPU:0"]
+      l = ["/job:ps/replica:0/task:0/device:CPU:0"]
+      strategy_devices = l  # self.distribute_strategy.extended.worker_devices
       self.shadow_impl = tf_utils.ListWrapper([])
       for i, strategy_device in enumerate(strategy_devices):
         with ops.device(strategy_device):
+          print(f"strategy_device {strategy_device}")
           shadow_name_replica = shadow_name
           if i > 0:
             shadow_name_replica = "%s/replica_%d" % (shadow_name, i)

@@ -289,6 +292,7 @@ def call(self, ids):
     Returns:
       A embedding output with shape (shape(ids), embedding_size).
     """
+
     return de.shadow_ops.embedding_lookup_unique(self.shadow, ids,
                                                  self.embedding_size,
                                                  self.with_unique, self.name)
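
Note: this hunk replaces the strategy's extended.worker_devices with a hard-coded single-PS device list (the two-task variant is left commented out), pinning every shadow-variable replica to ps task 0. Together with the added print, this reads as debugging scaffolding for the parameter-server experiment rather than a general-purpose fix.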

tensorflow_recommenders_addons/dynamic_embedding/python/kernel_tests/BUILD
Lines changed: 10 additions & 0 deletions

@@ -22,6 +22,16 @@ py_test(
     ],
 )

+py_test(
+    name = "ps_test",
+    srcs = ["ps_test.py"],
+    python_version = "PY3",
+    srcs_version = "PY3",
+    deps = [
+        "//tensorflow_recommenders_addons",
+    ],
+)
+
 py_test(
     name = "dynamic_embedding_ops_test",
     srcs = ["dynamic_embedding_ops_test.py"],
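
With the target registered, the new test can be run in isolation; assuming a configured Bazel workspace at the repository root, something like:

bazel test //tensorflow_recommenders_addons/dynamic_embedding/python/kernel_tests:ps_test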
tensorflow_recommenders_addons/dynamic_embedding/python/kernel_tests/ps_test.py (new file)
Lines changed: 186 additions & 0 deletions

+import os
+import sys
+from tensorflow.python.distribute import multi_process_lib
+import multiprocessing
+import tensorflow as tf
+import contextlib
+import functools
+from tensorflow_recommenders_addons import dynamic_embedding as de
+
+from absl.testing import parameterized
+import numpy as np
+from tensorflow.core.protobuf import saved_model_pb2
+from tensorflow.python.checkpoint import checkpoint as tracking_util
+from tensorflow.python.compat import v2_compat
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.distribute import distribute_lib
+from tensorflow.python.distribute import multi_process_runner
+from tensorflow.python.distribute import multi_worker_test_base
+from tensorflow.python.distribute import parameter_server_strategy_v2
+from tensorflow.python.distribute import ps_values
+from tensorflow.python.distribute import sharded_variable
+from tensorflow.python.distribute.cluster_resolver import cluster_resolver as cluster_resolver_lib
+from tensorflow.python.eager import context
+from tensorflow.python.eager import def_function
+from tensorflow.python.eager import test
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_spec
+from tensorflow.python.framework import test_util
+from tensorflow.python.module import module
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import embedding_ops
+from tensorflow.python.ops import init_ops_v2
+from tensorflow.python.ops import linalg_ops_impl
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import variable_scope
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import gfile
+from tensorflow.python.saved_model import save as tf_save
+from tensorflow.python.trackable import autotrackable
+from tensorflow.python.training import server_lib
+from packaging import version
+
+if version.parse(tf.__version__) >= version.parse("2.16"):
+  from tf_keras import layers
+  from tf_keras import Sequential
+  from tf_keras.optimizers import Adam
+else:
+  from tensorflow.python.keras import layers
+  from tensorflow.python.keras import Sequential
+  from tensorflow.python.keras.optimizers import Adam
+
+
+class ParameterServerStrategyV2Test(test.TestCase):
+
+  @classmethod
+  def setUpClass(cls):
+    super(ParameterServerStrategyV2Test, cls).setUpClass()
+    cls.cluster = multi_worker_test_base.create_multi_process_cluster(
+        num_workers=2, num_ps=3, rpc_layer="grpc")
+    cls.cluster_resolver = cls.cluster.cluster_resolver
+
+  @classmethod
+  def tearDownClass(cls):
+    super(ParameterServerStrategyV2Test, cls).tearDownClass()
+    cls.cluster.stop()
+
+  def testKerasFit(self):
+    strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
+        self.cluster_resolver)
+    # vocab_size = 100
+    embed_dim = 8
+    with strategy.scope():
+      model = Sequential([
+          layers.Input(shape=(1,), dtype=tf.int32),
+          de.keras.layers.Embedding(embed_dim, key_dtype=tf.int32),
+          # layers.Embedding(input_dim=vocab_size, output_dim=embed_dim),
+          layers.Flatten(),
+          layers.Dense(1, activation='sigmoid')
+      ])
+      optimizer = Adam(1E-3)
+      optimizer = de.DynamicEmbeddingOptimizer(optimizer)
+      model.compile(loss='binary_crossentropy',
+                    optimizer=optimizer,
+                    metrics=['accuracy'])
+
+    ids = np.random.randint(0, 100, size=(64 * 2, 1))
+    labels = np.random.randint(0, 2, size=(64 * 2, 1))
+
+    def dataset_fn(input_context):
+      global_batch_size = 32
+      batch_size = input_context.get_per_replica_batch_size(global_batch_size)
+      dataset = tf.data.Dataset.from_tensor_slices((ids, labels))
+      dataset = dataset.shard(input_context.num_input_pipelines,
+                              input_context.input_pipeline_id)
+      dataset = dataset.batch(batch_size).repeat()
+      return dataset
+
+    dataset = strategy.distribute_datasets_from_function(dataset_fn)
+
+    history = model.fit(dataset, epochs=1, steps_per_epoch=len(ids) // 64)
+    self.assertIn('loss', history.history)
+
+  # def testSparselyReadForEmbeddingLookup(self):
+  #   strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
+  #       self.cluster_resolver)
+  #
+  #   class FakeModel(module.Module):
+  #
+  #     def __init__(self):
+  #       self._var0 = variables.Variable([1.0, 2.0, 3.0, 4.0])
+  #       self._var1 = variables.Variable([5.0, 6.0, 7.0, 8.0])
+  #
+  #     @def_function.function(input_signature=[
+  #         tensor_spec.TensorSpec(shape=[2], dtype=dtypes.int32, name="inputs")
+  #     ])
+  #     def func(self, x):
+  #       return embedding_ops.embedding_lookup([self._var0, self._var1], x)
+  #
+  #   with strategy.scope():
+  #     model = FakeModel()
+  #
+  #   # Assert that ResourceGather op exists instead of Gather in training function.
+  #   found_resource_gather = False
+  #   found_gather = False
+  #
+  #   for n in model.func.get_concrete_function().graph.as_graph_def().node:
+  #     if n.op == "ResourceGather":
+  #       found_resource_gather = True
+  #     elif n.op == "Gather":
+  #       found_gather = True
+  #   self.assertTrue(found_resource_gather)
+  #   self.assertFalse(found_gather)
+  #
+  #   # Assert that ResourceGather op exists instead of Gather in saved_model.
+  #   found_resource_gather = False
+  #   found_gather = False
+  #
+  #   tmp_dir = self.get_temp_dir()
+  #   tf_save.save(model, tmp_dir, signatures=model.func)
+  #
+  #   with gfile.Open("%s/saved_model.pb" % tmp_dir, "rb") as f:
+  #     saved_model_proto = saved_model_pb2.SavedModel().FromString(f.read())
+  #
+  #   for function in saved_model_proto.meta_graphs[0].graph_def.library.function:
+  #     for n in function.node_def:
+  #       if n.op == "ResourceGather":
+  #         found_resource_gather = True
+  #         resource_gather_device = n.device
+  #       elif n.op == "Gather":
+  #         found_gather = True
+  #   self.assertTrue(found_resource_gather)
+  #   self.assertFalse(found_gather)
+  #
+  #   # We also assert that the colocate_with in embedding_ops will not result in
+  #   # a hard-coded device string.
+  #   self.assertEmpty(resource_gather_device)
+
+
+def custom_set_spawn_exe_path():
+  print(f"custom_set_spawn_exe_path {sys.argv[0]} {os.environ['TEST_TARGET']}")
+  if sys.argv[0].endswith('.py'):
+
+    def guess_path(package_root):
+      # If all we have is a python module path, we'll need to make a guess for
+      # the actual executable path.
+      if 'bazel-out' in sys.argv[0] and package_root in sys.argv[0]:
+        package_root_base = sys.argv[0][:sys.argv[0].rfind(package_root)]
+        binary = os.environ['TEST_TARGET'][2:].replace(':', '/', 1)
+        print(f"package_root_base {package_root_base} binary {binary}")
+        possible_path = os.path.join(package_root_base, package_root, binary)
+        print('Guessed test binary path: %s', possible_path)
+        if os.access(possible_path, os.X_OK):
+          return possible_path
+      return None
+
+    path = guess_path('tf_recommenders_addons')
+    if path is None:
+      print('Cannot determine binary path. sys.argv[0]=%s os.environ=%s',
+            sys.argv[0], os.environ)
+      raise RuntimeError('Cannot determine binary path')
+    sys.argv[0] = path
+  # Note that this sets the executable for *all* contexts.
+  multiprocessing.get_context().set_executable(sys.argv[0])
+
+
+if __name__ == "__main__":
+  multi_process_lib._set_spawn_exe_path = custom_set_spawn_exe_path
+  v2_compat.enable_v2_behavior()
+  multi_process_runner.test_main()
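
A note on the custom_set_spawn_exe_path override: multi_process_runner launches the cluster's ps/worker tasks as spawned subprocesses, which must re-execute the test binary. The stock path-guessing in multi_process_lib assumes TensorFlow's own source layout, so under bazel test it fails for this repository; the replacement re-derives the binary path from the tf_recommenders_addons package root plus the TEST_TARGET environment variable that Bazel sets, then installs it with multiprocessing.get_context().set_executable.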

tensorflow_recommenders_addons/dynamic_embedding/python/ops/dynamic_embedding_optimizer.py
Lines changed: 1 addition & 1 deletion

@@ -870,7 +870,7 @@ def compute_gradients_horovod_wrapper_impl(*args, **kwargs):
 def create_slots(variable, init, slot_name, op_name, bp_v2):
   """Helper function for creating a slot variable for statefull optimizers."""
   if distribute_utils.is_distributed_variable(variable):
-    strategy_devices = variable.distribute_strategy.extended.worker_devices
+    strategy_devices = ["/job:ps/replica:0/task:0/device:CPU:0"]  # variable.distribute_strategy.extended.worker_devices
     primary = variable._get_on_device_or_primary()
     params_var_ = primary.params
   else:
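
This mirrors the hard-coded device list in embedding.py above: optimizer slot variables are forced onto the same single PS device as the shadow variables they track, again as part of the debugging setup.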

tensorflow_recommenders_addons/dynamic_embedding/python/ops/dynamic_embedding_variable.py
Lines changed: 7 additions & 3 deletions

@@ -952,8 +952,10 @@ def lookup(self, keys, return_exists=False, name=None):
         Only provided if `return_exists` is True.
     """
     partition_index = self.partition_fn(keys, self.shard_num)
-    keys_partitions, keys_indices = make_partition(keys, partition_index,
-                                                   self.shard_num)
+    printop = tf.print("partition_x4_index_key: ", tf.shape(keys), keys, output_stream=tf.compat.v1.logging.error)
+    with tf.control_dependencies([printop]):
+      keys_partitions, keys_indices = make_partition(keys, partition_index,
+                                                     self.shard_num)

     _values = []
     _exists = []

@@ -983,7 +985,9 @@ def lookup(self, keys, return_exists=False, name=None):
           _stitch(_exists, keys_indices, use_fast=True))
     else:
       result = _stitch(_values, keys_indices, use_fast=True)
-    return result
+    printop2 = tf.print("partition_x4_index_key_result: ", tf.shape(keys), keys, tf.shape(result), result, output_stream=tf.compat.v1.logging.error)
+    with tf.control_dependencies([printop2]):
+      return result

   def export(self, name=None):
     """Returns tensors of all keys and values in the table.

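For orientation, the lookup being instrumented here partitions keys across shards, looks each partition up on its own shard, and stitches the per-shard results back into input order. A minimal sketch of that round trip with stock TF ops (make_partition and _stitch are TFRA internals; tf.dynamic_partition/tf.dynamic_stitch and the modulo partitioner below are only stand-ins for the idea):

import tensorflow as tf

keys = tf.constant([7, 2, 9, 4], dtype=tf.int64)
shard_num = 2
partition_index = tf.cast(keys % shard_num, tf.int32)      # shard id per key
keys_partitions = tf.dynamic_partition(keys, partition_index, shard_num)
keys_indices = tf.dynamic_partition(
    tf.range(tf.size(keys)), partition_index, shard_num)   # original positions
# Per-shard table lookups would happen here; identity keeps the sketch runnable.
values = [tf.cast(p, tf.float32) for p in keys_partitions]
result = tf.dynamic_stitch(keys_indices, values)           # restore input order
print(result.numpy())  # [7. 2. 9. 4.]
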
tensorflow_recommenders_addons/dynamic_embedding/python/ops/shadow_embedding_ops.py
Lines changed: 6 additions & 2 deletions

@@ -313,8 +313,12 @@ def embedding_lookup_unique_base(ids,
   ids_flat = tf.reshape(ids, (-1,))
   if with_unique:
     with ops.name_scope(name, "EmbeddingWithUnique"):
-      unique_ids, idx = tf.unique(ids_flat)
-      unique_embeddings = lookup_function(unique_ids)
+      printop = tf.print("partition_x4_index_key_before_unique: ", tf.shape(ids_flat), ids_flat, output_stream=tf.compat.v1.logging.error)
+      with tf.control_dependencies([printop]):
+        unique_ids, idx = tf.unique(ids_flat)
+      printop = tf.print("partition_x4_index_key_after_unique: ", tf.shape(unique_ids), unique_ids, output_stream=tf.compat.v1.logging.error)
+      with tf.control_dependencies([printop]):
+        unique_embeddings = lookup_function(unique_ids)
       embeddings_flat = tf.gather(unique_embeddings, idx)
   else:
     embeddings_flat = lookup_function(ids_flat)
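
The with_unique path being traced here deduplicates ids before the table lookup and then expands the result back. A minimal sketch of that pattern (lookup_function is only a stand-in for the real shadow-variable lookup):

import tensorflow as tf

ids_flat = tf.constant([5, 3, 5, 8], dtype=tf.int64)
unique_ids, idx = tf.unique(ids_flat)        # unique_ids=[5,3,8], idx=[0,1,0,2]
# Stand-in lookup: one embedding row per unique id.
unique_embeddings = tf.one_hot(tf.cast(unique_ids, tf.int32), depth=10)
embeddings_flat = tf.gather(unique_embeddings, idx)  # rows back in input order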
