Skip to content

Commit 5fea53a

Browse files
saberkun authored and tensorflower-gardener committed
Internal change
PiperOrigin-RevId: 418724903
1 parent 77d9fd6 commit 5fea53a

File tree

2 files changed

+160
-0
lines changed

2 files changed

+160
-0
lines changed

official/modeling/tf_utils.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -201,3 +201,74 @@ def safe_mean(losses):
201201
total = tf.reduce_sum(losses)
202202
num_elements = tf.cast(tf.size(losses), dtype=losses.dtype)
203203
return tf.math.divide_no_nan(total, num_elements)
204+
205+
206+
def get_replica_id():
  """Gets replica id depending on the environment."""
  replica_context = tf.distribute.get_replica_context()
  if replica_context is None:
    # Outside of `tf.distribute` there is no replica to identify.
    raise RuntimeError("Unknown replica context. The `get_replica_id` method "
                       "relies on TF 2.x tf.distribute API.")
  return replica_context.replica_id_in_sync_group
214+
215+
216+
def cross_replica_concat(value, axis, name="cross_replica_concat"):
  """Concatenates the given `value` across (GPU/TPU) cores, along `axis`.

  In general, each core ("replica") will pass a
  replica-specific value as `value` (corresponding to some element of a
  data-parallel computation taking place across replicas).

  The resulting concatenated `Tensor` will have the same shape as `value` for
  all dimensions except `axis`, where it will be larger by a factor of the
  number of replicas. It will also have the same `dtype` as `value`.

  The position of a given replica's `value` within the resulting concatenation
  is determined by that replica's replica ID. For
  example:

  With `value` for replica 0 given as

      0 0 0
      0 0 0

  and `value` for replica 1 given as

      1 1 1
      1 1 1

  the resulting concatenation along axis 0 will be

      0 0 0
      0 0 0
      1 1 1
      1 1 1

  and this result will be identical across all replicas.

  Note that this API only works in TF2 with `tf.distribute`.

  Args:
    value: The `Tensor` to concatenate across replicas. Each replica will have a
      different value for this `Tensor`, and these replica-specific values will
      be concatenated.
    axis: The axis along which to perform the concatenation as a Python integer
      (not a `Tensor`). E.g., `axis=0` to concatenate along the batch dimension.
    name: A name for the operation (used to create a name scope).

  Returns:
    The result of concatenating `value` along `axis` across replicas.

  Raises:
    RuntimeError: when the batch (0-th) dimension is None, or when called
      outside a replica context provided by TF 2.x `tf.distribute`.
  """
  with tf.name_scope(name):
    context = tf.distribute.get_replica_context()
    # Fail fast with a clear message when used outside `tf.distribute`
    # (consistent with `get_replica_id`); otherwise `context` is None and the
    # `all_gather` call below would raise an opaque AttributeError.
    if context is None:
      raise RuntimeError("Unknown replica context. The `cross_replica_concat` "
                         "method relies on TF 2.x tf.distribute API.")
    # Typically this could be hit only if the tensor is derived from a
    # dataset with finite epochs and drop_remainder=False, where the last
    # batch could be of different batch size and then the dim-0 is of dynamic
    # shape.
    if value.shape.as_list()[0] is None:
      raise RuntimeError(f"{value} has unknown batch.")
    return context.all_gather(value, axis=axis)

official/modeling/tf_utils_test.py

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Tests for tf_utils."""
16+
from absl.testing import parameterized
17+
import numpy as np
18+
import tensorflow as tf
19+
20+
from tensorflow.python.distribute import combinations
21+
from tensorflow.python.distribute import strategy_combinations
22+
from official.modeling import tf_utils
23+
24+
25+
def all_strategy_combinations():
  """Returns TPU + 2-GPU mirrored strategy combinations, in eager mode."""
  strategies = [
      strategy_combinations.cloud_tpu_strategy,
      strategy_combinations.mirrored_strategy_with_two_gpus,
  ]
  return combinations.combine(strategy=strategies, mode='eager')
33+
34+
35+
class TFUtilsTest(tf.test.TestCase, parameterized.TestCase):
  """Tests for `tf_utils.cross_replica_concat` under TPU/GPU strategies."""

  @combinations.generate(all_strategy_combinations())
  def test_cross_replica_concat(self, strategy):
    """Checks the concat result against a numpy reference, on every axis."""
    num_replicas = strategy.num_replicas_in_sync
    value_shape = (2, 3, 4)

    def make_replica_fn(axis):
      """Builds a tf.function concatenating replica-id-filled tensors."""

      @tf.function
      def replica_fn():
        local_value = tf.fill(value_shape, tf_utils.get_replica_id())
        return tf_utils.cross_replica_concat(local_value, axis=axis)

      return replica_fn

    def reference(axis):
      """Numpy equivalent of the expected cross-replica concatenation."""
      per_replica = [np.full(value_shape, i) for i in range(num_replicas)]
      return np.concatenate(per_replica, axis=axis)

    per_replica_results = strategy.run(make_replica_fn(axis=0))
    replica_0_result = per_replica_results.values[0].numpy()
    # Every replica must observe the identical concatenated result.
    for other_result in per_replica_results.values[1:]:
      self.assertAllClose(other_result.numpy(), replica_0_result)
    self.assertAllClose(replica_0_result, reference(axis=0))

    for axis in (1, 2):
      replica_0_result = strategy.run(
          make_replica_fn(axis=axis)).values[0].numpy()
      self.assertAllClose(replica_0_result, reference(axis=axis))

  @combinations.generate(all_strategy_combinations())
  def test_cross_replica_concat_gradient(self, strategy):
    """Gradient of a summed concat is `num_replicas` per input element."""
    num_replicas = strategy.num_replicas_in_sync
    value_shape = (10, 5)

    @tf.function
    def replica_fn():
      local_value = tf.random.normal(value_shape)
      with tf.GradientTape() as tape:
        tape.watch(local_value)
        concatenated = tf_utils.cross_replica_concat(local_value, axis=0)
        summed = tf.reduce_sum(concatenated)
      return tape.gradient(summed, local_value)

    # Each input element appears once in the concat, and the sum's gradient
    # accumulates a contribution of 1 from every replica's sum.
    for gradient in strategy.run(replica_fn).values:
      self.assertAllClose(gradient, num_replicas * tf.ones(value_shape))
86+
87+
88+
# Run the tests under the TF test runner when invoked as a script.
if __name__ == '__main__':
  tf.test.main()

0 commit comments

Comments
 (0)