
Commit 6f7ed09

Lifannrhdong authored and committed
Enable direct use of the dynamic_embedding.embedding_lookup series APIs inside tf.function scope
1 parent aa2fe90 commit 6f7ed09

File tree

4 files changed: +217 −30 lines

tensorflow_recommenders_addons/dynamic_embedding/python/kernel_tests/dynamic_embedding_ops_test.py

Lines changed: 84 additions & 1 deletion
@@ -23,6 +23,7 @@
 import math
 import numpy as np
 import os
+import tensorflow as tf

 from tensorflow_recommenders_addons import dynamic_embedding as de

@@ -47,9 +48,12 @@
 from tensorflow.python.ops import gen_array_ops
 from tensorflow.python.ops import init_ops
 from tensorflow.python.ops import math_ops
-from tensorflow.python.ops import variable_scope
+from tensorflow.python.ops import nn_ops
+from tensorflow.python.ops import nn_impl
+from tensorflow.python.ops import resource_variable_ops
 from tensorflow.python.ops import script_ops
 from tensorflow.python.ops import variables
+from tensorflow.python.ops import variable_scope
 from tensorflow.python.platform import test
 from tensorflow.python.training import device_setter
 from tensorflow.python.training import server_lib
@@ -1296,5 +1300,84 @@ def test_colocate_to_ids(self):
     self.assertAllEqual(tw_q.device, '/job:dist/task:1')


+@test_util.run_all_in_graph_and_eager_modes
+class EmbeddingLookupEagerTest(test.TestCase):
+
+  def _create_input_and_params(self,
+                               name,
+                               batch_size=4,
+                               nids=64,
+                               embedding_size=1):
+    assert nids % batch_size == 0
+    ids = math_ops.range(0, nids, dtype=dtypes.int64)
+    ids = array_ops.reshape(ids, (batch_size, -1))
+    labels = array_ops.zeros((batch_size,), dtype=dtypes.float32)
+    devar = de.get_variable(name + '/dynamic_embedding',
+                            dim=embedding_size,
+                            initializer=tf.keras.initializers.Zeros())
+    tfvar = tf.Variable(tf.keras.initializers.Zeros()((nids, embedding_size),
+                                                      dtype=tf.float32))
+    return ids, labels, devar, tfvar
+
+  def _loss_fn(self, params, ids, labels):
+
+    if isinstance(params, de.Variable):
+      embedding = de.embedding_lookup(params, ids)
+    elif isinstance(
+        params, (resource_variable_ops.ResourceVariable, variables.Variable)):
+      embedding = embedding_ops.embedding_lookup(params, ids)
+    else:
+      raise TypeError
+
+    logits = math_ops.reduce_mean(math_ops.reduce_sum(embedding, 1), 1)
+    entropy = nn_impl.sigmoid_cross_entropy_with_logits(logits=logits,
+                                                        labels=labels)
+    loss = math_ops.reduce_mean(entropy)
+    return loss
+
+  def test_run_training_eagerly(self):
+    if not context.executing_eagerly():
+      self.skipTest('Only test functional API in eager mode.')
+
+    batch_size = 4
+    ids, labels, devar, tfvar = self._create_input_and_params('vns079',
+                                                              embedding_size=1)
+    nsteps = 10
+
+    loss_fn = tf.function()(self._loss_fn)
+
+    def sorted_dynamic_embedding_value():
+      embedding_var = devar
+      optimizer = tf.keras.optimizers.Adam(1E-3)
+      optimizer = de.DynamicEmbeddingOptimizer(optimizer)
+
+      def var_fn():
+        return list(embedding_var.trainable_store.values())
+
+      for _ in range(nsteps):
+        optimizer.minimize(lambda: loss_fn(embedding_var, ids, labels), var_fn)
+
+      keys, values = embedding_var.export()
+      order = tf.argsort(keys)
+      return array_ops.gather(values, order)
+
+    def sorted_static_embedding_value():
+      embedding_var = tfvar
+      optimizer = tf.keras.optimizers.Adam(1E-3)
+      optimizer = de.DynamicEmbeddingOptimizer(optimizer)
+
+      def var_fn():
+        return [embedding_var]
+
+      for _ in range(nsteps):
+        optimizer.minimize(lambda: loss_fn(embedding_var, ids, labels), var_fn)
+
+      return embedding_var.read_value()
+
+    de_values = sorted_dynamic_embedding_value()
+    tf_values = sorted_static_embedding_value()
+    self.assertAllClose(de_values, tf_values)
+
+
 if __name__ == "__main__":
   test.main()
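Note: the pattern this test exercises, calling `de.embedding_lookup` inside a `tf.function` and feeding the cached shadow trainables to the optimizer, reduces to the sketch below. The sketch is illustrative and not part of the commit; `feature_ids` and the variable names are hypothetical.

```python
# Minimal sketch of the pattern EmbeddingLookupEagerTest validates.
# Hypothetical names: 'sketch/embedding', 'sketch_lookup', feature_ids.
import tensorflow as tf
from tensorflow_recommenders_addons import dynamic_embedding as de

params = de.get_variable('sketch/embedding', dim=8,
                         initializer=tf.keras.initializers.Zeros())
optimizer = de.DynamicEmbeddingOptimizer(tf.keras.optimizers.Adam(1E-3))
feature_ids = tf.constant([[1, 7], [3, 9]], dtype=tf.int64)

@tf.function
def loss_fn():
  # The lookup now works directly inside tf.function; a ShadowVariable
  # is created once under the given name and cached on `params`.
  emb = de.embedding_lookup(params, feature_ids, name='sketch_lookup')
  return tf.reduce_mean(tf.reduce_sum(emb, axis=-1))

for _ in range(10):
  # return_trainable is disabled inside tf.function, so the shadow
  # trainables are pulled from the store instead.
  optimizer.minimize(loss_fn, lambda: list(params.trainable_store.values()))
```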

tensorflow_recommenders_addons/dynamic_embedding/python/ops/dynamic_embedding_ops.py

Lines changed: 61 additions & 29 deletions
@@ -47,6 +47,8 @@
 from tensorflow.python.util import compat
 from tensorflow.python.util.tf_export import tf_export

+_ANONYMOUS_TRAINABLE_STORE_KEY = '_anonymous_trainable_store_key'
+

 class TrainableWrapper(resource_variable_ops.ResourceVariable):
   """
@@ -464,12 +466,6 @@ def _reset_ids(self, ids):
     s._reset_ids(ids)


-# TODO (Lifann) Introduce ShadowVariable when using tf.function.
-# Could hack the tf.function and mark whether if the Python context
-# is in the `function` scope. When inside `function` scope, create
-# ShadowVariable out of the scope, and then do the lookup. Also
-# need to keep the ShadowVariable on record of the params, without
-# breaking the compatibility of embedding_lookup API.
 def embedding_lookup(
     params,
     ids,
@@ -495,14 +491,20 @@ def embedding_lookup(
     validate_indices: Not used; kept for compatibility with nn.embedding_lookup.
     max_norm: If not `None`, each embedding is clipped if its l2-norm is larger
       than this value.
-    return_trainable: optional, If True, also return TrainableWrapper
+    return_trainable: optional. If True, also return the TrainableWrapper. In
+      eager mode the returned trainable is a `ShadowVariable`, the eager
+      derivative of TrainableWrapper. Inside a tf.function scope,
+      return_trainable is disabled; use `dynamic_embedding.Variable.get_trainable_by_name`
+      or `dynamic_embedding.Variable.trainable_store` to get the trainable
+      shadow created inside the tf.function scope.
   Returns:
     A tensor with shape [shape of ids] + [dim], where dim is the value
       dim of params, containing the values from the params tensor(s) for
       keys in ids.
     trainable_wrap:
       A TrainableWrapper object used to fill the Optimizers `var_list`
-        Only provided if `return_trainable` is True.
+        Only provided if `return_trainable` is True. In eager mode it is a
+        `ShadowVariable`, the eager derivative of TrainableWrapper.
   """
   if isinstance(params, (list, tuple)) and len(params) > 1:
     raise ValueError("Only one params is allowed.")
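Note: in eager mode `return_trainable=True` still hands the trainable back directly; only inside `tf.function` must it be fetched afterwards. A hedged sketch of both access paths (variable names are illustrative):

```python
# Two access paths for the trainable shadow; assumes eager execution.
import tensorflow as tf
from tensorflow_recommenders_addons import dynamic_embedding as de

params = de.get_variable('doc/embedding', dim=4)
ids = tf.constant([2, 5, 11], dtype=tf.int64)

# Eager: the ShadowVariable comes back alongside the embeddings.
emb, shadow = de.embedding_lookup(params, ids, name='doc_lookup',
                                  return_trainable=True)

# tf.function: return_trainable=True would raise NotImplementedError,
# so the shadow is retrieved by name after the first call instead.
@tf.function
def lookup():
  return de.embedding_lookup(params, ids, name='fn_lookup')

_ = lookup()
shadow_in_fn = params.get_trainable_by_name('fn_lookup')
```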
@@ -548,28 +550,58 @@ def initial_value():
   if params.trainable:
     collections += [ops.GraphKeys.TRAINABLE_VARIABLES]

-  def _create_trainable(trainable_name):
-    return de.TrainableWrapper(params,
-                               ids,
-                               max_norm=max_norm,
-                               initial_value=initial_value,
-                               dtype=params.value_dtype,
-                               trainable=params.trainable,
-                               collections=collections,
-                               model_mode=ModelMode.CURRENT_SETTING,
-                               name=trainable_name)
+  def _create_or_get_trainable(trainable_name):
+    if trainable_name is None:
+      if context.executing_eagerly():
+        raise ValueError(
+            'Must provide a name for embedding_lookup when using eager execution.'
+        )
+      trainable_name = ops.get_default_graph().unique_name(
+          _ANONYMOUS_TRAINABLE_STORE_KEY)
+    if not context.executing_eagerly() and not ops.inside_function():
+      wrapper = de.TrainableWrapper(params,
+                                    ids,
+                                    max_norm=max_norm,
+                                    initial_value=initial_value,
+                                    dtype=params.value_dtype,
+                                    trainable=params.trainable,
+                                    collections=collections,
+                                    model_mode=ModelMode.CURRENT_SETTING,
+                                    name=trainable_name)
+      params._trainable_store[trainable_name] = wrapper
+      return wrapper
+    else:
+      with ops.init_scope():
+        shadow = params._trainable_store.get(trainable_name, None)
+        if shadow is None:
+          shadow = de.shadow_ops.ShadowVariable(
+              params,
+              name=trainable_name,
+              max_norm=max_norm,
+              trainable=params.trainable,
+              model_mode=ModelMode.CURRENT_SETTING)
+          params._trainable_store[trainable_name] = shadow
+      return shadow

   with ops.colocate_with(ids, ignore_existing=True):
-    if context.executing_eagerly():
-      trainable_ = params._trainable_store.get(name, None)
-      if trainable_ is None:
-        trainable_ = _create_trainable(name)
-        params._trainable_store[name] = trainable_
-      else:
-        trainable_._reset_ids(ids)
-    else:
-      trainable_ = _create_trainable(name)
-      params._trainable_store[name] = trainable_
+    trainable_ = _create_or_get_trainable(name)
+
+    if isinstance(trainable_, de.shadow_ops.ShadowVariable):
+      embeddings = de.shadow_ops.embedding_lookup(
+          trainable_,
+          ids,
+          partition_strategy=partition_strategy,
+          name=name,
+          validate_indices=validate_indices)
+      if return_trainable:
+        if not context.executing_eagerly():
+          raise NotImplementedError(
+              'return_trainable currently is not implemented when using tf.function.'
+              ' Please use `Variable.trainable_store` or `Variable.get_trainable_by_name`'
+              ' to access the shadow trainable variable if call `embedding_lookup` series'
+              ' APIs inside tf.function scope.')
+        return embeddings, trainable_
+      return embeddings

     embeddings = array_ops.identity(trainable_)
     embeddings = array_ops.reshape(embeddings, shape=embeddings_shape)
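Note: the rewritten `_create_or_get_trainable` dispatches on two predicates. This illustrative probe (not from the commit) shows which branch each calling context selects:

```python
# Illustrative probe of the dispatch predicates used above: legacy graph
# mode keeps TrainableWrapper; eager and tf.function get ShadowVariable.
import tensorflow as tf
from tensorflow.python.eager import context
from tensorflow.python.framework import ops

def which_branch():
  if not context.executing_eagerly() and not ops.inside_function():
    return 'TrainableWrapper (legacy graph mode)'
  return 'ShadowVariable (eager or tf.function)'

print('eager:', which_branch())

@tf.function
def traced():
  print('tf.function:', which_branch())  # runs once, at trace time
  return tf.constant(0)

traced()

with tf.Graph().as_default():
  print('graph:', which_branch())
```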
@@ -738,7 +770,7 @@ def embedding_lookup_sparse(
     embeddings, trainable_ = embedding_lookup(
         params,
         ids,
-        name=name + "/embedding_lookup",
+        name=name + '/embedding_lookup',
         partition_strategy=partition_strategy,
         max_norm=max_norm,
         return_trainable=True,

tensorflow_recommenders_addons/dynamic_embedding/python/ops/dynamic_embedding_variable.py

Lines changed: 41 additions & 0 deletions
@@ -659,6 +659,43 @@ def get_slot_variables(self, optimizer):
         continue
     return slots

+  def get_trainable_by_name(self, name):
+    """
+    Get the trainable shadow variable created under `name`, for use in
+    eager execution or inside tf.function scope.
+
+    Example:
+    ```python
+    from tensorflow_recommenders_addons import dynamic_embedding as de
+    init = tf.keras.initializers.RandomNormal()
+    params = de.get_variable('foo', dim=4, initializer=init)
+    optimizer = tf.keras.optimizers.Adam(1E-3)
+    optimizer = de.DynamicEmbeddingOptimizer(optimizer)
+
+    @tf.function
+    def loss_fn(ids):
+      emb = de.embedding_lookup(params, ids, name='user_embedding')
+      emb = tf.math.reduce_sum(emb, axis=1)
+      loss = tf.reduce_mean(emb)
+      return loss
+
+    for i in range(10):
+      optimizer.minimize(lambda: loss_fn(ids),
+                         var_list=[params.get_trainable_by_name('user_embedding')])
+    ```
+
+    Args:
+      name: str. The name used to create the trainable shadow on this Variable.
+
+    Returns:
+      The ShadowVariable object registered under `name`, or None if it
+        does not exist.
+
+    Raises:
+      TypeError: if name is not a string.
+    """
+    if not isinstance(name, str):
+      raise TypeError('name should be a string')
+    return self._trainable_store.get(name, None)
+
   def _gather_saveables_for_checkpoint(self):
     g = ops.get_default_graph()
     if context.executing_eagerly() or g._functions:
@@ -678,6 +715,10 @@ def _gather_saveables_for_checkpoint(self):
       saveables[saveable.keywords["name"]] = saveable
     return saveables

+  @property
+  def trainable_store(self):
+    return self._trainable_store
+

 @tf_export("dynamic_embedding.get_variable")
 def get_variable(
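Note: together, the new `trainable_store` property and `get_trainable_by_name` accessor expose the per-name shadow registry. A short hedged sketch of inspecting it after named lookups (all names are illustrative):

```python
# Hedged sketch: one shadow per distinct lookup name, reused across traces.
import tensorflow as tf
from tensorflow_recommenders_addons import dynamic_embedding as de

params = de.get_variable('inspect/embedding', dim=2)

@tf.function
def lookups(ids):
  a = de.embedding_lookup(params, ids, name='user_side')
  b = de.embedding_lookup(params, ids, name='item_side')
  return a, b

lookups(tf.constant([1, 2, 3], dtype=tf.int64))

print(sorted(params.trainable_store.keys()))  # ['item_side', 'user_side']
assert params.get_trainable_by_name('user_side') is params.trainable_store['user_side']
```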

tensorflow_recommenders_addons/dynamic_embedding/python/ops/shadow_embedding_ops.py

Lines changed: 31 additions & 0 deletions
@@ -48,6 +48,7 @@
 from tensorflow.python.framework import tensor_shape
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import gen_resource_variable_ops
 from tensorflow.python.ops import resource_variable_ops
 from tensorflow.python.training.tracking import base as trackable

@@ -174,6 +175,36 @@ def value(self, do_prefetch=False):
     with ops.colocate_with(None, ignore_existing=True):
       return self._read_variable_op(do_prefetch=do_prefetch)

+  def assign(self, value, use_locking=None, name=None, read_value=True):
+    """
+    Assigns a new value to this variable.
+    Unlike ResourceVariable, the shadow always uses a variant space to
+    hold the temporary embedding lookup buffer.
+
+    Args:
+      value: A `Tensor`. The new value for this variable.
+      use_locking: If `True`, use locking during the assignment.
+      name: The name to use for the assignment.
+      read_value: A `bool`. Whether to read and return the new value of the
+        variable or not.
+
+    Returns:
+      If `read_value` is `True`, this method will return the new value of
+      the variable after the assignment has completed. Otherwise, when in
+      graph mode it will return the `Operation` that does the assignment,
+      and when in eager mode it will return `None`.
+    """
+    # Note: not depending on the cached value here since this can be used to
+    # initialize the variable.
+    with resource_variable_ops._handle_graph(self.handle):
+      value_tensor = ops.convert_to_tensor(value, dtype=self.dtype)
+      assign_op = gen_resource_variable_ops.assign_variable_op(self.handle,
+                                                               value_tensor,
+                                                               name=name)
+      if read_value:
+        return self._lazy_read(assign_op)
+    return assign_op
+
   def _reset_ids(self, ids):
     return self.ids.assign(ids, use_locking=True)
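Note: the overridden `assign` writes through `gen_resource_variable_ops.assign_variable_op` directly, bypassing ResourceVariable's cached-value path so it can also initialize the variable. A rough sketch of direct assignment, assuming the assigned tensor matches the shadow's dtype (names are illustrative and not from the commit):

```python
# Rough sketch: direct assignment to a ShadowVariable buffer.
# Assumes de.shadow_ops.ShadowVariable(params, name=...) as used above.
import tensorflow as tf
from tensorflow_recommenders_addons import dynamic_embedding as de

params = de.get_variable('assign/embedding', dim=2)
shadow = de.shadow_ops.ShadowVariable(params, name='assign_demo')

new_buffer = tf.zeros((3, 2), dtype=tf.float32)
# read_value=True returns the lazily read new value; with
# read_value=False, graph mode returns the assign Operation.
value = shadow.assign(new_buffer, read_value=True)
```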