Skip to content

Commit 2d373e1

Browse files
committed
resource var
1 parent f325fbb commit 2d373e1

File tree

5 files changed

+75
-16
lines changed

5 files changed

+75
-16
lines changed

demo/dynamic_embedding/movielens-1m-keras-ps/movielens-1m-keras-ps.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -72,13 +72,13 @@ def __init__(self,
7272
user_embedding_size,
7373
initializer=embedding_initializer,
7474
devices=self.devices,
75-
with_unique=False,
75+
# with_unique=False,
7676
name='user_embedding')
7777
self.movie_embedding = de.keras.layers.SquashedEmbedding(
7878
movie_embedding_size,
7979
initializer=embedding_initializer,
8080
devices=self.devices,
81-
with_unique=False,
81+
# with_unique=False,
8282
name='movie_embedding')
8383

8484
self.dnn1 = tf.keras.layers.Dense(
@@ -105,10 +105,11 @@ def __init__(self,
105105
@tf.function
106106
def call(self, features):
107107
user_id = tf.reshape(features['user_id'], (-1, 1))
108-
movie_id = tf.reshape(features['movie_id'], (-1, 1))
108+
# movie_id = tf.reshape(features['movie_id'], (-1, 1))
109109
user_latent = self.user_embedding(user_id)
110-
movie_latent = self.movie_embedding(movie_id)
111-
latent = tf.concat([user_latent, movie_latent], axis=1)
110+
# movie_latent = self.movie_embedding(movie_id)
111+
# latent = tf.concat([user_latent, movie_latent], axis=1)
112+
latent = user_latent
112113
x = self.dnn1(latent)
113114
x = self.dnn2(x)
114115
x = self.dnn3(x)
@@ -129,7 +130,7 @@ def __init__(self, strategy, train_bs, test_bs, epochs, steps_per_epoch,
129130
"/job:ps/replica:0/task:{}/device:CPU:0".format(idx)
130131
for idx in range(self.num_ps)
131132
]
132-
self.embedding_size = 32
133+
self.embedding_size = 4
133134
self.train_bs = train_bs
134135
self.test_bs = test_bs
135136
self.epochs = epochs
@@ -254,10 +255,10 @@ def start_chief(config):
254255
cluster_spec, task_type="chief", task_id=0)
255256
strategy = tf_dist.experimental.ParameterServerStrategy(cluster_resolver)
256257
runner = Runner(strategy=strategy,
257-
train_bs=64,
258+
train_bs=4,
258259
test_bs=1,
259260
epochs=1,
260-
steps_per_epoch=1000,
261+
steps_per_epoch=4,
261262
model_dir=None,
262263
export_dir=None)
263264
runner.train()

tensorflow_recommenders_addons/dynamic_embedding/python/keras/layers/embedding.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -303,6 +303,9 @@ def call(self, ids):
303303
Returns:
304304
A embedding output with shape (shape(ids), embedding_size).
305305
"""
306+
tfprint = tf.print("ids_8a:", ids, output_stream=tf.compat.v1.logging.error)
307+
with tf.control_dependencies([tfprint]):
308+
pass
306309
return de.shadow_ops.embedding_lookup_unique(self.shadow, ids,
307310
self.embedding_size,
308311
self.with_unique, self.name)

tensorflow_recommenders_addons/dynamic_embedding/python/ops/dynamic_embedding_optimizer.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616
"""patch on optimizers"""
1717

1818
import functools
19+
20+
import tensorflow as tf
1921
from packaging import version
2022
import six
2123

@@ -184,6 +186,9 @@ def apply_grad_to_update_var(var, grad):
184186
"Cannot use a constraint function on a sparse variable.")
185187
if "apply_state" in self._sparse_apply_args:
186188
apply_kwargs["apply_state"] = apply_state
189+
# printop = tf.print("ids_8d:", output_stream=tf.compat.v1.logging.error)
190+
# with tf.control_dependencies([printop]):
191+
# pass
187192
with ops.control_dependencies(_before):
188193
_apply_op = self._resource_apply_sparse_duplicate_indices(
189194
grad.values, var, grad.indices, **apply_kwargs)
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
from tensorflow.python.distribute import ps_values, distribute_lib
2+
from tensorflow.python.distribute.distribute_lib import _get_per_thread_mode
3+
from tensorflow.python.distribute.parameter_server_strategy_v2 import ParameterServerStrategyV2, \
4+
ParameterServerStrategyV2Extended
5+
from tensorflow.python.ops import variables
6+
import tensorflow as tf
7+
8+
9+
class DEPerWorkerVariable(ps_values.PerWorkerVariable):
  """`ps_values.PerWorkerVariable` subclass used for dynamic-embedding variables.

  The class adds no behavior of its own; it gives dynamic-embedding code a
  distinct type to construct and to test for with `isinstance`.
  """

  def __init__(self, *args, **kwargs):
    # Forward every argument untouched to `PerWorkerVariable`.
    super().__init__(*args, **kwargs)
12+
13+
def create_per_worker_de_variable(strategy, name, dtype, shape):
  """Creates a variable under `strategy.scope()` flagged as a per-worker DE variable.

  Args:
    strategy: distribution strategy whose scope the variable is created in.
    name: name passed to the variable.
    dtype: dtype passed to the variable.
    shape: shape passed to the variable.

  Returns:
    The object produced by `variables.Variable(...)` inside the strategy
    scope. The `per_worker_de_variable=True` keyword is the flag that
    `patched_create_variable` pops to route creation to
    `DEPerWorkerVariable`.
  """
  creation_kwargs = dict(
      initial_value=(),
      shape=shape,
      dtype=dtype,
      name=name,
      per_worker_de_variable=True,
  )
  with strategy.scope():
    return variables.Variable(**creation_kwargs)
22+
23+
# The `_create_variable` implementation that was installed before this patch.
# If this module body is executed again in the same namespace (for example
# through `importlib.reload`), the class attribute already holds an earlier
# `patched_create_variable`. Unwrap it via the attribute recorded below, so
# that the patch never ends up delegating to itself and recursing forever.
original_create_variable = getattr(
    ParameterServerStrategyV2Extended._create_variable,
    "_de_original_create_variable",
    ParameterServerStrategyV2Extended._create_variable)


def patched_create_variable(self, next_creator, **kwargs):
  """Replacement for `ParameterServerStrategyV2Extended._create_variable`.

  Args:
    self: the `ParameterServerStrategyV2Extended` instance the method is
      called on.
    next_creator: the next variable creator; forwarded unchanged.
    **kwargs: variable-creation keyword arguments. A truthy
      `per_worker_de_variable` entry is removed from them and selects the
      dynamic-embedding path.

  Returns:
    A `DEPerWorkerVariable` when `per_worker_de_variable` was requested,
    otherwise whatever the previously installed `_create_variable` returns.
  """
  # `pop` strips the private flag so it is never forwarded downstream.
  if kwargs.pop("per_worker_de_variable", False):
    return _create_per_worker_de_variable(self, next_creator, **kwargs)
  return original_create_variable(self, next_creator, **kwargs)


def _create_per_worker_de_variable(strategy_extended, next_creator, **kwargs):
  """Builds a `DEPerWorkerVariable` for the strategy that owns `strategy_extended`.

  Args:
    strategy_extended: strategy-extended object; its `_container_strategy()`
      result is passed as the variable's strategy.
    next_creator: the next variable creator; forwarded unchanged.
    **kwargs: remaining variable-creation keyword arguments.

  Returns:
    The newly constructed `DEPerWorkerVariable`.
  """
  return DEPerWorkerVariable(strategy_extended._container_strategy(),
                             next_creator, **kwargs)


# Record what the patch wraps on the patch itself, so that a later
# re-execution of this module can recover the unpatched callable.
patched_create_variable._de_original_create_variable = original_create_variable
ParameterServerStrategyV2Extended._create_variable = patched_create_variable
34+
35+
class DEParameterServerStrategy(ParameterServerStrategyV2):
  """`ParameterServerStrategyV2` subclass for dynamic-embedding training.

  The constructor adds nothing of its own and hands both arguments straight
  to the base class.
  """

  def __init__(self, cluster_resolver, variable_partitioner=None):
    """Initializes the strategy.

    Args:
      cluster_resolver: forwarded to `ParameterServerStrategyV2.__init__`.
      variable_partitioner: forwarded to `ParameterServerStrategyV2.__init__`;
        defaults to `None`.
    """
    super().__init__(cluster_resolver, variable_partitioner)

tensorflow_recommenders_addons/dynamic_embedding/python/ops/shadow_embedding_ops.py

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@
3737

3838
import tensorflow as tf
3939

40-
from tensorflow.python.distribute import distribute_lib
40+
from tensorflow.python.distribute import distribute_lib, ps_values
4141
from tensorflow.python.eager import context
4242
from tensorflow.python.framework import dtypes
4343
from tensorflow.python.framework import ops
@@ -49,6 +49,10 @@
4949
from tensorflow_recommenders_addons import dynamic_embedding as de
5050
from tensorflow_recommenders_addons.dynamic_embedding.python.ops.embedding_weights import EmbeddingWeights, \
5151
TrainableWrapper
52+
from tensorflow_recommenders_addons.dynamic_embedding.python.ops.parameter_server import create_per_worker_de_variable, \
53+
DEPerWorkerVariable
54+
from tensorflow_recommenders_addons.dynamic_embedding.python.train.utils import is_parameter_server_strategy
55+
from tensorflow.python.ops import variables
5256

5357
if version.parse(tf.__version__) >= version.parse("2.10"):
5458
from tensorflow.python.trackable import base as trackable
@@ -114,15 +118,18 @@ def __init__(self,
114118
ids_name = self._name + '-ids'
115119
if ids is None:
116120
self.ids = get_de_resource_variable(
117-
trainable=False,
118121
collections=collections,
119122
name=ids_name,
120123
dtype=self.params.key_dtype,
121124
distribute_strategy=distribute_strategy,
122125
shape=tensor_shape.TensorShape(None))
123126
else:
124127
if not isinstance(ids, resource_variable_ops.ResourceVariable):
125-
raise TypeError('If ids is set, it needs to be a ResourceVariable')
128+
tfprint = tf.print("ids_8c:", ids, type(ids), ids.__class__.__name__, output_stream=tf.compat.v1.logging.error)
129+
with tf.control_dependencies([tfprint]):
130+
pass
131+
# not isinstance(ids, variables.Variable)):
132+
# raise TypeError('If ids is set, it needs to be a ResourceVariable or ps_values.PerWorkerVariable')
126133
self.ids = ids
127134

128135
model_mode = kwargs.get('model_mode', None)
@@ -152,7 +159,6 @@ def __init__(self,
152159
exists_name = self._name + '-exists'
153160
if exists is None:
154161
self.exists = get_de_resource_variable(
155-
trainable=False,
156162
collections=collections,
157163
name=exists_name,
158164
dtype=dtypes.bool,
@@ -272,10 +278,14 @@ def embedding_lookup(
272278
with ops.name_scope(name, "shadow_embedding_lookup"):
273279
with ops.colocate_with(None, ignore_existing=True):
274280
if de.ModelMode.CURRENT_SETTING == de.ModelMode.TRAIN:
281+
tfprint = tf.print("ids_8b:", shadow_.ids, ids, output_stream=tf.compat.v1.logging.error)
282+
with tf.control_dependencies([tfprint]):
283+
pass
275284
with ops.control_dependencies([shadow_._reset_ids(ids)]):
276285
result = shadow_.read_value(do_prefetch=True)
277286
else:
278287
result = shadow_.params.lookup(ids)
288+
279289
return result
280290

281291

@@ -360,14 +370,17 @@ def __init__(self, *args, **kwargs):
360370
super(DEResourceVariable, self).__init__(*args, **kwargs)
361371

362372

363-
def get_de_resource_variable(trainable,
373+
def get_de_resource_variable(
364374
collections,
365375
name,
366376
dtype,
367377
distribute_strategy,
368378
shape=tensor_shape.TensorShape(None)):
369-
return DEResourceVariable((),
370-
trainable=trainable,
379+
if is_parameter_server_strategy(distribute_strategy):
380+
return create_per_worker_de_variable(distribute_strategy, name, dtype, shape)
381+
else:
382+
return DEResourceVariable((),
383+
trainable=False,
371384
collections=collections,
372385
name=name,
373386
dtype=dtype,
@@ -377,7 +390,7 @@ def get_de_resource_variable(trainable,
377390

378391
def is_de_resource_variable(var):
379392
return isinstance(var, DEResourceVariable) or isinstance(
380-
var, TrainableWrapper)
393+
var, TrainableWrapper) or isinstance(var, DEPerWorkerVariable)
381394

382395

383396
class HvdVariable(EmbeddingWeights):

0 commit comments

Comments
 (0)