
Commit 81fb5b0

Internal change
PiperOrigin-RevId: 463764367
1 parent 8a521b8 commit 81fb5b0

2 files changed: 74 additions & 57 deletions


official/nlp/modeling/layers/kernel_attention.py

Lines changed: 70 additions & 53 deletions
@@ -41,17 +41,19 @@ def call(self, inputs, mask):
     return mask
 
 
-def pad_to_chunk_length(tensor, axis, chunk_length, pad="right"):
+def pad_to_chunk_length(tensor, axis, chunk_length, padding=None):
   """Pads a tensor so that shape[axis] is divisible by chunk_length.
 
   Args:
     tensor: Input tensor to pad.
     axis: Axis to pad along.
     chunk_length: The output tensor will have shape[axis] divisible by
       chunk_length.
-    pad: Pad the input tensor across the axis from left if pad="left", right if
-      pad="right", or apply no padding if pad=None. In the latter case, the axis
-      dimension of the input tensor must be divisible by the chunk_length.
+    padding: Pad the input tensor across the axis from either left or
+      right if padding is set to "left" or "right"; applies no padding
+      if padding is set to None. In the latter case, the axis
+      dimension of the input tensor must be divisible by the
+      chunk_length.
 
   Returns:
     Padded tensor with shape[axis] divisible by chunk_length.
@@ -62,19 +64,23 @@ def pad_to_chunk_length(tensor, axis, chunk_length, pad="right"):
     axis += rank
   axis_length = shape[axis]
   pad_length = -axis_length % chunk_length
-  if pad == "right":
-    pad_width_2 = [[0, pad_length]]
-  elif pad == "left":
-    pad_width_2 = [[pad_length, 0]]
-  else:
+  if padding == "right":
+    axis_paddings = [[0, pad_length]]
+  elif padding == "left":
+    axis_paddings = [[pad_length, 0]]
+  elif padding is None:
     if pad_length != 0:
-      raise ValueError("When padding is not set, the axis dimension"
+      raise ValueError("When padding is None, the axis dimension"
                        "has to be divisible by the chunk_length.")
     return tensor
-  pad_width = tf.concat(
-      [tf.zeros([axis, 2], dtype=tf.int32), pad_width_2,
+  else:
+    raise ValueError("Illegal padding value; must be one of \"left\""
+                     "\"right\" or None.")
+  paddings = tf.concat(
+      [tf.zeros([axis, 2], dtype=tf.int32),
+       axis_paddings,
        tf.zeros([rank - axis - 1, 2], dtype=tf.int32)], axis=0)
-  return tf.pad(tensor, pad_width)
+  return tf.pad(tensor, paddings)
 
 
 def split_tensor_into_chunks(tensor, axis, chunk_length):
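
A minimal usage sketch of the renamed padding argument (not part of the commit; it assumes the TF Model Garden package is importable and the tensor shape is arbitrary):

import tensorflow as tf

from official.nlp.modeling.layers import kernel_attention

# A [batch, seq_len, heads, dim] tensor whose sequence axis (-3) has
# length 7, which is not divisible by chunk_length=3.
x = tf.zeros([2, 7, 4, 8])

# padding="right" appends 2 zero rows at the end of the sequence axis,
# giving shape [2, 9, 4, 8]; padding="left" would prepend them instead.
padded = kernel_attention.pad_to_chunk_length(
    x, axis=-3, chunk_length=3, padding="right")
print(padded.shape)  # (2, 9, 4, 8)

# padding=None requires the axis length to already be divisible by
# chunk_length; otherwise a ValueError is raised.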
@@ -95,12 +101,12 @@ def split_tensor_into_chunks(tensor, axis, chunk_length):
   return tf.reshape(tensor, new_shape)
 
 
-def windowed_causal_performer_attention(query_matrix,
+def causal_windowed_performer_attention(query_matrix,
                                         key_matrix,
                                         value_matrix,
                                         chunk_length,
                                         window_length,
-                                        pad="right"):
+                                        padding=None):
   """Applies windowed causal kernel attention with query, key, value tensors.
 
   We partition the T-length input sequence into N chunks, each of chunk_length
@@ -113,40 +119,40 @@ def windowed_causal_performer_attention(query_matrix,
   Below is an example with T=9, chunk_length=3, window_length=1. 1 indicates
   attention is computed between the pair while 0 indicates attention is not
   computed between the pairs:
-  111000000
-  111000000
-  111000000
-  111111000
-  111111000
-  111111000
-  000111111
-  000111111
-  000111111
+  111000000
+  111000000
+  111000000
+  111111000
+  111111000
+  111111000
+  000111111
+  000111111
+  000111111
 
   User can ensure sequence_length is divisible by chunk_length or use
-  pad="left"/"right" to pad the sequence length either at the top or bottom
-  respectively and make it divisible by chunk_length.
+  padding="left"/"right" to pad the sequence length either at the left
+  or right respectively and make it divisible by chunk_length.
 
   Args:
     query_matrix: Kernel query `Tensor` of shape `[B, T, N, dim]`.
     key_matrix: Kernel key `Tensor` of shape `[B, T, N, dim]`.
     value_matrix: Value `Tensor` of shape `[B, T, N, out_dim]`.
     chunk_length: Length of each chunk in tokens.
     window_length: Length of attention window in chunks.
-    pad: Pad the query, value and key input tensors across the T dimension from
-      left if pad="left", right if pad="right", or apply no padding if pad=None.
-      In the latter case, the T dimension of the input tensors must be divisible
-      by the chunk_length.
+    padding: Pad the query, value and key input tensors across the
+      axis from either left or right if padding is set to "left" or
+      "right"; apply no padding if padding is set to None. In the
+      latter case, the axis dimension of the query, value and key
+      input tensors must be divisible by the chunk_length.
 
   Returns:
     Window causal performer attention of shape `[B, T, N, out_dim]`.
   """
-
   old_shape = tf.shape(value_matrix)
 
-  query_matrix = pad_to_chunk_length(query_matrix, -3, chunk_length, pad)
-  key_matrix = pad_to_chunk_length(key_matrix, -3, chunk_length, pad)
-  value_matrix = pad_to_chunk_length(value_matrix, -3, chunk_length, pad)
+  query_matrix = pad_to_chunk_length(query_matrix, -3, chunk_length, padding)
+  key_matrix = pad_to_chunk_length(key_matrix, -3, chunk_length, padding)
+  value_matrix = pad_to_chunk_length(value_matrix, -3, chunk_length, padding)
 
   new_shape = tf.shape(value_matrix)
   chunked_query_matrix = split_tensor_into_chunks(
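
A rough call sketch for the renamed function (not part of the commit): query_matrix and key_matrix are assumed to already be kernel feature maps, and the shapes and hyper-parameters below are illustrative only.

import tensorflow as tf

from official.nlp.modeling.layers import kernel_attention

batch, seq_len, num_heads, dim = 2, 10, 4, 8

# query/key are assumed to be non-negative kernel features (e.g. the
# output of the feature transform); value is the raw value tensor.
query_prime = tf.random.uniform([batch, seq_len, num_heads, dim])
key_prime = tf.random.uniform([batch, seq_len, num_heads, dim])
value = tf.random.normal([batch, seq_len, num_heads, dim])

# seq_len=10 is not divisible by chunk_length=4, so padding is needed;
# padding="right" pads the sequence internally, and per the docstring the
# result comes back with shape [B, T, N, out_dim].
out = kernel_attention.causal_windowed_performer_attention(
    query_prime, key_prime, value,
    chunk_length=4, window_length=2, padding="right")
print(out.shape)  # expected: (2, 10, 4, 8)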
@@ -446,16 +452,17 @@ def __init__(self,
                begin_kernel=0,
                scale=None,
                scale_by_length=False,
-               use_windowed_causal=False,
-               chunk_length=1,
-               window_length=3,
+               use_causal_windowed=False,
+               causal_chunk_length=1,
+               causal_window_length=1,
+               causal_padding=None,
                **kwargs):
     r"""Constructor of KernelAttention.
 
     Args:
-      feature_transform: A non-linear transform of the keys and quries. Possible
-        transforms are "elu", "relu", "square", "exp", "expplus", "expmod",
-        "identity".
+      feature_transform: A non-linear transform of the keys and queries.
+        Possible transforms are "elu", "relu", "square", "exp", "expplus",
+        "expmod", "identity".
       num_random_features: Number of random features to be used for projection.
         if num_random_features <= 0, no production is used before transform.
       seed: The seed to begin drawing random features. Once the seed is set, the
@@ -475,11 +482,17 @@ def __init__(self,
         the dot product based on key length. Set as log_512^(n) to stablize
         attention entropy against length. Refer to
         https://kexue.fm/archives/8823 for details.
-      use_windowed_causal: If true perform windowed causal attention. See
-        windowed_causal_performer_attention function docstring for more details.
-      chunk_length: Length of each chunk in tokens.
-      window_length: Length of attention window in chunks.
-      **kwargs: The same arguments `MultiHeadAttention` layer.
+      use_causal_windowed: If true perform windowed causal attention. See
+        causal_windowed_performer_attention function docstring for more details.
+      causal_chunk_length: Length of each chunk in tokens.
+      causal_window_length: Length of attention window in chunks.
+      causal_padding: Pad the query, value and key input tensors
+        across the axis from either left or right if padding is set to
+        "left" or "right"; apply no padding if padding is set to None.
+        In the latter case, the axis dimension of the query, value and
+        key input tensors must be divisible by the chunk_length.
+      **kwargs:
+        The same arguments `MultiHeadAttention` layer.
     """
     if (feature_transform not in _TRANSFORM_MAP and
         feature_transform != "expplus"):
@@ -509,12 +522,13 @@ def __init__(self,
       self._projection_matrix = create_projection_matrix(
           self._num_random_features, self._key_dim,
           tf.constant([self._seed, self._seed + 1]))
-    self.use_windowed_causal = use_windowed_causal
-    self.chunk_length = chunk_length
-    self.window_length = window_length
-    if self.use_windowed_causal and self._is_short_seq:
+    self.use_causal_windowed = use_causal_windowed
+    self.causal_chunk_length = causal_chunk_length
+    self.causal_window_length = causal_window_length
+    self.causal_padding = causal_padding
+    if self.use_causal_windowed and self._is_short_seq:
       raise ValueError(
-          "use_windowed_causal and short_seq methods are mutually exclusive")
+          "use_causal_windowed and short_seq methods are mutually exclusive")
 
   def _compute_attention(self,
                          query,
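
For reference, a hedged sketch of migrating a constructor call to the renamed arguments; the remaining arguments below (num_heads, key_dim, feature_transform, num_random_features) are illustrative placeholders, not taken from this commit.

from official.nlp.modeling.layers import kernel_attention

# Old argument names (before this commit):
#   kernel_attention.KernelAttention(
#       num_heads=8, key_dim=64, feature_transform="exp",
#       num_random_features=128,
#       use_windowed_causal=True, chunk_length=8, window_length=3)

# New argument names, plus the new causal_padding option:
layer = kernel_attention.KernelAttention(
    num_heads=8,
    key_dim=64,
    feature_transform="exp",
    num_random_features=128,
    use_causal_windowed=True,
    causal_chunk_length=8,
    causal_window_length=3,
    causal_padding="right")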
@@ -590,9 +604,12 @@ def _compute_attention(self,
       attention_scores = tf.einsum("BTNH,BSNH->BTSN", query_prime, key_prime)
       attention_scores = tf.nn.softmax(attention_scores, axis=2)
       attention_output = tf.einsum("BTSN,BSNH->BTNH", attention_scores, value)
-    elif self.use_windowed_causal:
-      attention_output = windowed_causal_performer_attention(
-          query_prime, key_prime, value, self.chunk_length, self.window_length)
+    elif self.use_causal_windowed:
+      attention_output = causal_windowed_performer_attention(
+          query_prime, key_prime, value,
+          chunk_length=self.causal_chunk_length,
+          window_length=self.causal_window_length,
+          padding=self.causal_padding)
     else:
       kv = tf.einsum("BSNH,BSND->BNDH", key_prime, value)
       denominator = 1.0 / (

official/nlp/modeling/layers/kernel_attention_test.py

Lines changed: 4 additions & 4 deletions
@@ -63,7 +63,7 @@ def test_attention_projection(
   @parameterized.parameters(
       itertools.product(_FEATURE_TRANSFORM, [127], _TRAINING, [True, False],
                         [0]))
-  def test_windowed_causal_attention_projection(
+  def test_causal_windowed_attention_projection(
       self, feature_transform, num_random_features, training, redraw,
       begin_kernel):
     num_heads = 12
@@ -78,9 +78,9 @@ def test_windowed_causal_attention_projection(
         redraw=redraw,
         is_short_seq=False,
         begin_kernel=begin_kernel,
-        use_windowed_causal=True,
-        chunk_length=8,
-        window_length=3)
+        use_causal_windowed=True,
+        causal_chunk_length=8,
+        causal_window_length=3)
     query = tf.random.normal(
         shape=(batch_size, seq_length, key_dim))
     value = query
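
A minimal end-to-end sketch mirroring the updated test setup (shapes, the feature transform and the call signature below are assumptions chosen for illustration; the layer follows the Keras MultiHeadAttention calling convention it inherits):

import tensorflow as tf

from official.nlp.modeling.layers import kernel_attention

batch_size, seq_length, key_dim, num_heads = 2, 64, 64, 12

layer = kernel_attention.KernelAttention(
    num_heads=num_heads,
    key_dim=key_dim,
    feature_transform="exp",
    num_random_features=127,
    is_short_seq=False,
    use_causal_windowed=True,
    causal_chunk_length=8,
    causal_window_length=3)

query = tf.random.normal(shape=(batch_size, seq_length, key_dim))
value = query

# Self-attention call; seq_length=64 is divisible by causal_chunk_length=8,
# so the default causal_padding=None is sufficient here.
output = layer(query=query, value=value)
print(output.shape)  # expected: (2, 64, 64)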

0 commit comments