Commit 270ed2a

Internal change
PiperOrigin-RevId: 519857140
1 parent 482ec55 commit 270ed2a

4 files changed: +230 -43 lines


official/vision/modeling/layers/nn_blocks.py

Lines changed: 85 additions & 32 deletions
@@ -1555,15 +1555,32 @@ def call(self, inputs, inputs_positions=None):
 class TransformerEncoderBlock(nlp_modeling.layers.TransformerEncoderBlock):
   """TransformerEncoderBlock layer with stochastic depth and layerscale."""

-  def __init__(self,
-               *args,
-               stochastic_depth_drop_rate=0.0,
-               layer_scale_init_value=0.0,
-               **kwargs):
-    """Initializes TransformerEncoderBlock."""
+  def __init__(
+      self,
+      *args,
+      stochastic_depth_drop_rate=0.0,
+      layer_scale_init_value=0.0,
+      max_attention_inference_parallelism=None,
+      **kwargs
+  ):
+    """Initializes TransformerEncoderBlock.
+
+    Args:
+      *args: positional arguments passed to super().__init__.
+      stochastic_depth_drop_rate: the drop rate for the stochastic depth layer.
+      layer_scale_init_value: the initial value for layerscale.
+      max_attention_inference_parallelism: the number of examples to run in
+        parallel in the attention blocks during inference. Set this limit to
+        reduce the peak memory usage. If None, use vectorized operations to run
+        the whole batch in parallel.
+      **kwargs: keyword arguments passed to super().__init__.
+    """
     super().__init__(*args, **kwargs)
     self._stochastic_depth_drop_rate = stochastic_depth_drop_rate
     self._layer_scale_init_value = layer_scale_init_value
+    self._max_attention_inference_parallelism = (
+        max_attention_inference_parallelism
+    )

   def build(self, input_shape):
     if self._stochastic_depth_drop_rate:
@@ -1582,10 +1599,25 @@ def build(self, input_shape):
       self._layer_scale_mlp = lambda x, *args, **kwargs: tf.identity(x)
     super().build(input_shape)

+    if self._max_attention_inference_parallelism is not None:
+      attention_layer_config = self._attention_layer.get_config()
+      self._attention_layer = nn_layers.MultiHeadAttention.from_config({
+          **attention_layer_config,
+          'max_inference_parallelism': (
+              self._max_attention_inference_parallelism
+          ),
+      })
+
   def get_config(self):
-    config = {'stochastic_depth_drop_rate': self._stochastic_depth_drop_rate}
-    base_config = super().get_config()
-    return dict(list(base_config.items()) + list(config.items()))
+    config = super().get_config()
+    config.update({
+        'stochastic_depth_drop_rate': self._stochastic_depth_drop_rate,
+        'layer_scale_init_value': self._layer_scale_init_value,
+        'max_attention_inference_parallelism': (
+            self._max_attention_inference_parallelism
+        ),
+    })
+    return config

   def call(self, inputs, output_range=None, training=None):
     """Transformer self-attention encoder block call."""
@@ -1675,29 +1707,39 @@ def call(self, inputs, output_range=None, training=None):

 @tf.keras.utils.register_keras_serializable(package='Vision')
 class TransformerScaffold(nlp_modeling.layers.TransformerScaffold):
-  """TransformerScaffold layer for vision applications.
-
-  This layer is a subclass of NLP TransformerScaffold:
+  """TransformerScaffold layer for vision applications."""

-  Attributes:
-    stochastic_depth_drop_rate: Drop rate for the residual connections.
-    return_attention_scores: Optionally return the attention output.
-    ffn_has_residual_connection: Whether the feedforward network has internal
-      residual connection and layer norm. If False, the residual connection and
-      the layer norm op are called inside TransformerScaffold.
-  """
+  def __init__(
+      self,
+      *args,
+      stochastic_depth_drop_rate: float = 0.0,
+      return_attention_scores: bool = False,
+      ffn_has_residual_connection: bool = False,
+      max_attention_inference_parallelism: Optional[int] = None,
+      **kwargs
+  ):
+    """Initializes TransformerScaffold.

-  def __init__(self,
-               *args,
-               stochastic_depth_drop_rate: float = 0.0,
-               return_attention_scores: bool = False,
-               ffn_has_residual_connection: bool = False,
-               **kwargs):
-    """Initializes TransformerEncoderBlock."""
+    Args:
+      *args: positional arguments passed to super().__init__.
+      stochastic_depth_drop_rate: the drop rate for the stochastic depth layer.
+      return_attention_scores: whether to return the attention output.
+      ffn_has_residual_connection: whether the feedforward network has internal
+        residual connection and layer norm. If False, the residual connection
+        and the layer norm op are called inside TransformerScaffold.
+      max_attention_inference_parallelism: the number of examples to run in
+        parallel in the attention blocks during inference. Set this limit to
+        reduce the peak memory usage. If None, use vectorized operations to run
+        the whole batch in parallel.
+      **kwargs: keyword arguments passed to super().__init__.
+    """
     super().__init__(*args, **kwargs)
     self._stochastic_depth_drop_rate = stochastic_depth_drop_rate
     self._return_attention_scores = return_attention_scores
     self._ffn_has_residual_connection = ffn_has_residual_connection
+    self._max_attention_inference_parallelism = (
+        max_attention_inference_parallelism
+    )

   def build(self, input_shape: Union[tf.TensorShape, List[int]]):
     if self._stochastic_depth_drop_rate:
@@ -1708,15 +1750,26 @@ def build(self, input_shape: Union[tf.TensorShape, List[int]]):

     super().build(input_shape)

+    if self._max_attention_inference_parallelism is not None:
+      attention_layer_config = self._attention_layer.get_config()
+      self._attention_layer = self._attention_cls.from_config({
+          **attention_layer_config,
+          'max_inference_parallelism': (
+              self._max_attention_inference_parallelism
+          ),
+      })
+
   def get_config(self):
-    config = {
+    config = super().get_config()
+    config.update({
         'stochastic_depth_drop_rate': self._stochastic_depth_drop_rate,
         'return_attention_scores': self._return_attention_scores,
-        'ffn_has_residual_connection': self._ffn_has_residual_connection
-    }
-    base_config = super().get_config()
-    base_config.update(config)
-    return base_config
+        'ffn_has_residual_connection': self._ffn_has_residual_connection,
+        'max_attention_inference_parallelism': (
+            self._max_attention_inference_parallelism
+        ),
+    })
+    return config

   def call(
       self,
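
A corresponding sketch for TransformerScaffold, again with illustrative sizes. Because build() recreates the attention layer through self._attention_cls.from_config with the extra 'max_inference_parallelism' key, the chosen attention_cls has to accept that argument; nn_layers.MultiHeadAttention, added in this commit, does, and the tests below exercise exactly this path through a validating subclass of it.

import tensorflow as tf
from official.vision.modeling.layers import nn_blocks, nn_layers

layer = nn_blocks.TransformerScaffold(
    attention_cls=nn_layers.MultiHeadAttention,  # must understand the new key
    attention_cfg={'num_heads': 10, 'key_dim': 8},
    num_attention_heads=10,
    inner_dim=2048,
    inner_activation='relu',
    max_attention_inference_parallelism=2,
)

inputs = tf.keras.Input(shape=(21, 80))  # (sequence_length, width), as in the tests
outputs = layer(inputs)                  # same shape as the input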

official/vision/modeling/layers/nn_blocks_test.py

Lines changed: 17 additions & 11 deletions
@@ -24,8 +24,8 @@

 from tensorflow.python.distribute import combinations
 from tensorflow.python.distribute import strategy_combinations
-from official.nlp import modeling as nlp_modeling
 from official.vision.modeling.layers import nn_blocks
+from official.vision.modeling.layers import nn_layers


 def distribution_strategy_combinations() -> Iterable[Tuple[Any, ...]]:
@@ -392,7 +392,7 @@ def auto_fn():
 # boolean 'True'. We register this class as a Keras serializable so we can
 # test serialization below.
 @tf.keras.utils.register_keras_serializable(package='TestOnlyAttention')
-class ValidatedAttentionLayer(nlp_modeling.layers.attention.MultiHeadAttention):
+class ValidatedAttentionLayer(nn_layers.MultiHeadAttention):

   def __init__(self, call_list, **kwargs):
     super(ValidatedAttentionLayer, self).__init__(**kwargs)
@@ -414,7 +414,7 @@ def call(

   def get_config(self):
     config = super(ValidatedAttentionLayer, self).get_config()
-    config['call_list'] = []
+    config['call_list'] = self.list
     return config


@@ -456,29 +456,32 @@ def tearDown(self):
     super(TransformerLayerTest, self).tearDown()
     tf.keras.mixed_precision.set_global_policy('float32')

-  def test_layer_creation(self):
+  @parameterized.parameters(None, 2)
+  def test_layer_creation(self, max_attention_inference_parallelism):
     sequence_length = 21
     width = 80

-    call_list = []
     attention_layer_cfg = {
         'num_heads': 10,
         'key_dim': 8,
-        'call_list': call_list,
+        'call_list': []
     }
     test_layer = nn_blocks.TransformerScaffold(
         attention_cls=ValidatedAttentionLayer,
         attention_cfg=attention_layer_cfg,
         num_attention_heads=10,
         inner_dim=2048,
-        inner_activation='relu')
+        inner_activation='relu',
+        max_attention_inference_parallelism=max_attention_inference_parallelism,
+    )

     # Create a 3-dimensional input (the first dimension is implicit).
     data_tensor = tf.keras.Input(shape=(sequence_length, width))
     output_tensor = test_layer(data_tensor)
     # The default output of a transformer layer should be the same as the input.
     self.assertEqual(data_tensor.shape.as_list(), output_tensor.shape.as_list())

+    call_list = test_layer._attention_layer.get_config()['call_list']
     # If call_list[0] exists and is True, the passed layer class was
     # instantiated from the given config properly.
     self.assertNotEmpty(call_list)
@@ -551,22 +554,23 @@ def test_layer_creation_with_mask(self):
     self.assertNotEmpty(call_list)
     self.assertTrue(call_list[0], "The passed layer class wasn't instantiated.")

-  def test_layer_invocation(self):
+  @parameterized.parameters(None, 2)
+  def test_layer_invocation(self, max_attention_inference_parallelism):
     sequence_length = 21
     width = 80

-    call_list = []
     attention_layer_cfg = {
         'num_heads': 10,
         'key_dim': 8,
-        'call_list': call_list,
+        'call_list': [],
     }
     test_layer = nn_blocks.TransformerScaffold(
         attention_cls=ValidatedAttentionLayer,
         attention_cfg=attention_layer_cfg,
         num_attention_heads=10,
         inner_dim=2048,
-        inner_activation='relu')
+        inner_activation='relu',
+        max_attention_inference_parallelism=max_attention_inference_parallelism)

     # Create a 3-dimensional input (the first dimension is implicit).
     data_tensor = tf.keras.Input(shape=(sequence_length, width))
@@ -581,6 +585,8 @@ def test_layer_invocation(self):
     input_data = 10 * np.random.random_sample(
         (batch_size, sequence_length, width))
     _ = model.predict(input_data)
+
+    call_list = test_layer._attention_layer.get_config()['call_list']
     # If call_list[0] exists and is True, the passed layer class was
     # instantiated from the given config properly.
     self.assertNotEmpty(call_list)
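
A note on the call_list plumbing: when max_attention_inference_parallelism is set, build() replaces the attention layer with a copy created via from_config, so the tests can no longer rely on a locally held list. ValidatedAttentionLayer.get_config() therefore returns the live list (config['call_list'] = self.list), and the tests read it back from the built layer. A further assertion, hypothetical and not part of this commit, could also confirm that the cap was propagated to the rebuilt attention layer:

    # Hypothetical addition inside the parameterized tests above.
    attn_config = test_layer._attention_layer.get_config()
    self.assertEqual(
        attn_config['max_inference_parallelism'],
        max_attention_inference_parallelism)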

official/vision/modeling/layers/nn_layers.py

Lines changed: 114 additions & 0 deletions
@@ -13,6 +13,7 @@
 # limitations under the License.

 """Contains common building blocks for neural networks."""
+
 from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple, Union

 from absl import logging
@@ -1279,3 +1280,116 @@ def get_config(self):
     }
     base_config = super().get_config()
     return dict(list(base_config.items()) + list(config.items()))
+
+
+@tf.keras.utils.register_keras_serializable(package='Vision')
+class MultiHeadAttention(tf.keras.layers.MultiHeadAttention):
+  """MultiHeadAttention layer.
+
+  This is an implementation of multi-headed attention as described in the paper
+  "Attention is all you Need" (Vaswani et al., 2017).
+  """
+
+  def __init__(
+      self, *args, max_inference_parallelism: Optional[int] = None, **kwargs
+  ):
+    """Initializes MultiHeadAttention.
+
+    Args:
+      *args: Positional arguments passed to super().__init__.
+      max_inference_parallelism: The number of examples to run in parallel
+        during inference. Set this limit to reduce the peak memory usage. If
+        None, use vectorized operations to run the whole batch in parallel.
+      **kwargs: Keyword arguments passed to super().__init__.
+    """
+    super().__init__(*args, **kwargs)
+    self._max_inference_parallelism = max_inference_parallelism
+
+  def get_config(self):
+    config = super().get_config()
+    config.update({
+        'max_inference_parallelism': self._max_inference_parallelism,
+    })
+    return config
+
+  def _compute_attention(
+      self,
+      query: tf.Tensor,
+      key: tf.Tensor,
+      value: tf.Tensor,
+      attention_mask: Optional[tf.Tensor] = None,
+      training: Optional[bool] = None,
+  ):
+    """Applies dot-product attention with query, key, value tensors.
+
+    Args:
+      query: Projected query `Tensor` of shape `(B, T, N, key_dim)`.
+      key: Projected key `Tensor` of shape `(B, S, N, key_dim)`.
+      value: Projected value `Tensor` of shape `(B, S, N, value_dim)`.
+      attention_mask: a boolean mask of shape `(B, T, S)`, that prevents
+        attention to certain positions. It is generally not needed if the
+        `query` and `value` (and/or `key`) are masked.
+      training: Python boolean indicating whether the layer should behave in
+        training mode (adding dropout) or in inference mode (doing nothing).
+
+    Returns:
+      attention_output: Multi-headed outputs of attention computation.
+      attention_scores: Multi-headed attention weights.
+    """
+    batch_size = query.get_shape().as_list()[0]  # None if dynamic.
+
+    if (
+        training
+        or self._max_inference_parallelism is None
+        or self._max_inference_parallelism <= 0
+        or (
+            # If the whole batch is allowed to be run in parallel, use fully
+            # vectorized computation instead of tf.map_fn to make things more
+            # efficient.
+            batch_size is not None
+            and batch_size <= self._max_inference_parallelism
+        )
+    ):
+      return self._compute_attention_delegate(
+          query, key, value, attention_mask, training
+      )
+    else:
+      # Sequentialize the inference execution with limited parallelism.
+      def _compute_fn(x):
+        attention_output, attention_scores = self._compute_attention_delegate(
+            query=x[0][tf.newaxis, ...],
+            key=x[1][tf.newaxis, ...],
+            value=x[2][tf.newaxis, ...],
+            attention_mask=x[3][tf.newaxis, ...] if len(x) >= 4 else None,
+            training=training,
+        )
+        attention_output = tf.squeeze(attention_output, axis=0)
+        attention_scores = tf.squeeze(attention_scores, axis=0)
+        return attention_output, attention_scores
+
+      if attention_mask is not None:
+        elems = [query, key, value, attention_mask]
+      else:
+        elems = [query, key, value]
+
+      return tf.map_fn(
+          fn=_compute_fn,
+          elems=elems,
+          fn_output_signature=(value.dtype, value.dtype),
+          parallel_iterations=self._max_inference_parallelism,
+      )
+
+  def _compute_attention_delegate(
+      self,
+      query: tf.Tensor,
+      key: tf.Tensor,
+      value: tf.Tensor,
+      attention_mask: Optional[tf.Tensor] = None,
+      training: Optional[bool] = None,
+  ):
+    """Implements dot-product attention with query, key, value tensors."""
+    # Simply calls the implementation of the super class here, while the users
+    # can override this function for customizing attention computation.
+    return super()._compute_attention(
+        query, key, value, attention_mask, training
+    )
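
To make the control flow concrete, here is a self-contained sketch with made-up shapes and a made-up cap. At inference time, a batch larger than max_inference_parallelism is routed through tf.map_fn with parallel_iterations equal to the cap, which bounds how many examples the attention computation materializes at once; during training, or when the cap is unset, non-positive, or covers the whole batch, the usual fully vectorized path runs.

import tensorflow as tf
from official.vision.modeling.layers import nn_layers

mha = nn_layers.MultiHeadAttention(
    num_heads=8, key_dim=64, max_inference_parallelism=4)

x = tf.random.normal([32, 196, 512])  # (batch, tokens, hidden)

# Inference: batch of 32 > cap of 4, so attention runs per example through
# tf.map_fn with parallel_iterations=4, trading speed for lower peak memory.
y_infer = mha(query=x, value=x, training=False)

# Training: the standard fully vectorized computation is used.
y_train = mha(query=x, value=x, training=True)

assert y_infer.shape == x.shape

The _compute_attention_delegate indirection keeps the per-example and vectorized paths identical and gives subclasses a single override point for custom attention math.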
