Commit 764091c

No public description
PiperOrigin-RevId: 675335745
1 parent bc113f0 commit 764091c

File tree: 2 files changed (+72, −4 lines)

official/nlp/modeling/layers/transformer_encoder_block.py

Lines changed: 36 additions & 4 deletions
@@ -114,6 +114,7 @@ def __init__(self,
                tgt_block_size=None,
                use_sigmoid_attn=False,
                sigmoid_attn_bias=None,
+               linformer_dim=None,
                **kwargs):
     """Initializes `TransformerEncoderBlock`.
@@ -191,6 +192,8 @@ def __init__(self,
         `block_sparse_attention.MultiHeadAttention`
       sigmoid_attn_bias: This param is only used in
         `block_sparse_attention.MultiHeadAttention`
+      linformer_dim: Applies low-rank factorization on keys/values as in
+        https://arxiv.org/pdf/2006.04768.
       **kwargs: keyword arguments.
     """
     util.filter_kwargs(kwargs)
@@ -230,6 +233,7 @@ def __init__(self,
     self._tgt_block_size = tgt_block_size
     self._use_sigmoid_attn = use_sigmoid_attn
     self._sigmoid_attn_bias = sigmoid_attn_bias
+    self._linformer_dim = linformer_dim
     if self._num_kv_heads is not None and self._src_block_size is not None:
       raise ValueError(
           "Block sparse attention does not support Multi-query attention."
@@ -366,16 +370,31 @@ def build(self, input_shape):
         name="output",
         kernel_initializer=tf_utils.clone_initializer(self._kernel_initializer),
         bias_initializer=tf_utils.clone_initializer(self._bias_initializer),
-        **common_kwargs)
+        **common_kwargs,
+    )
     self._output_dropout = tf_keras.layers.Dropout(
-        rate=self._output_dropout_rate)
+        rate=self._output_dropout_rate
+    )
     # Use float32 in layernorm for numeric stability.
     self._output_layer_norm = tf_keras.layers.LayerNormalization(
         name="output_layer_norm",
         axis=-1,
         epsilon=self._norm_epsilon,
-        dtype=tf.float32)
-
+        dtype=tf.float32,
+    )
+    if self._linformer_dim is not None:
+      # Current implementation uses the same weights for keys and values.
+      # TODO(akandoor): Explore using different weights for keys and values.
+      self._lowrank_kv_projection = tf_keras.layers.EinsumDense(
+          "...bc,cd->...bd",
+          output_shape=(None, self._linformer_dim),
+          kernel_initializer=tf_utils.clone_initializer(
+              self._kernel_initializer
+          ),
+          bias_initializer=tf_utils.clone_initializer(self._bias_initializer),
+          name="lowrank_kv_projection",
+          **common_kwargs,
+      )
     super().build(input_shape)
 
   def get_config(self):
@@ -480,6 +499,19 @@ def call(self, inputs: Any, output_range: Optional[tf.Tensor] = None) -> Any:
     if key_value is None:
       key_value = input_tensor
 
+    if self._linformer_dim is not None:
+      if attention_mask is not None:
+        # Applying mask before the low rank factorization so that padding is
+        # accounted for.
+        query_mask = tf.cast(attention_mask[:, :, 0], dtype=target_tensor.dtype)
+        target_tensor = target_tensor * tf.expand_dims(query_mask, axis=-1)
+        key_mask = tf.cast(attention_mask[:, 0, :], dtype=target_tensor.dtype)
+        key_value = key_value * tf.expand_dims(key_mask, axis=-1)
+        attention_mask = None
+      key_value = tf.transpose(key_value, [0, 2, 1])
+      key_value = self._lowrank_kv_projection(key_value)
+      key_value = tf.transpose(key_value, [0, 2, 1])
+
     if self._return_attention_scores:
       attention_output, attention_scores = self._attention_layer(
           query=target_tensor,

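For orientation: the two tf.transpose calls in call() move the sequence axis into the last position so the EinsumDense can project the key/value length down to linformer_dim, which is why the attention scores end with linformer_dim rather than the key sequence length. Below is a standalone sketch of just that projection, using the same einsum equation as the lowrank_kv_projection layer above; the concrete sizes are illustrative only and not part of this commit, and it assumes the tf_keras (Keras 2) package is available.

import tensorflow as tf
import tf_keras

batch, seq_len, width, linformer_dim = 2, 21, 80, 7
key_value = tf.random.normal((batch, seq_len, width))

# Same equation and output_shape as the diff's lowrank_kv_projection.
proj = tf_keras.layers.EinsumDense(
    "...bc,cd->...bd", output_shape=(None, linformer_dim))

kv = tf.transpose(key_value, [0, 2, 1])  # (batch, width, seq_len)
kv = proj(kv)                            # (batch, width, linformer_dim)
kv = tf.transpose(kv, [0, 2, 1])         # (batch, linformer_dim, width)
print(kv.shape)                          # (2, 7, 80)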
official/nlp/modeling/layers/transformer_encoder_block_test.py

Lines changed: 36 additions & 0 deletions
@@ -800,6 +800,42 @@ def test_block_sparse_attention(self, use_sigmoid_attn):
       output_tensor[1].shape.as_list(), expected_attention_scores_shape
     )
 
+  def test_low_rank_attention(self):
+    num_attention_heads = 8
+    sequence_length = 21
+    linformer_dim = 7
+    width = 80
+
+    test_layer = TransformerEncoderBlock(
+        num_attention_heads=num_attention_heads,
+        inner_dim=2048,
+        inner_activation='relu',
+        return_attention_scores=True,
+        linformer_dim=linformer_dim,
+    )
+    # Create a 3-dimensional input (the first dimension is implicit).
+    data_tensor = tf_keras.Input(shape=(sequence_length, width))
+    output_tensor = test_layer(data_tensor)
+
+    expected_layer_output_shape = [None, sequence_length, width]
+    expected_attention_scores_shape = [
+        None,
+        num_attention_heads,
+        sequence_length,
+        linformer_dim,
+    ]
+
+    self.assertIsInstance(output_tensor, tuple)
+    self.assertLen(output_tensor, 2)
+    # First is the standard output.
+    self.assertEqual(
+        output_tensor[0].shape.as_list(), expected_layer_output_shape
+    )
+    # Second is the attention scores.
+    self.assertEqual(
+        output_tensor[1].shape.as_list(), expected_attention_scores_shape
+    )
+
 
 if __name__ == '__main__':
   tf.test.main()
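For completeness, a minimal eager-mode usage sketch of the new option, complementing the symbolic-shape test above. It assumes tf-models-official is installed so that official.nlp.modeling.layers exposes TransformerEncoderBlock; the sizes mirror the test but are otherwise arbitrary.

import tensorflow as tf
from official.nlp.modeling.layers import TransformerEncoderBlock

# Encoder block with the shared low-rank key/value projection enabled.
layer = TransformerEncoderBlock(
    num_attention_heads=8,
    inner_dim=2048,
    inner_activation='relu',
    linformer_dim=7,  # keys/values are projected from length 21 down to 7
)

x = tf.ones((2, 21, 80))  # (batch, sequence_length, width)
y = layer(x)              # output keeps the query sequence length
print(y.shape)            # (2, 21, 80)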
