
Commit d1fca26

avinava-o authored and tensorflower-gardener committed
windowed causal performer
PiperOrigin-RevId: 463429471
1 parent 1db7588 commit d1fca26

File tree

2 files changed (+194, -1 lines)


official/nlp/modeling/layers/kernel_attention.py

Lines changed: 161 additions & 1 deletion
@@ -41,6 +41,148 @@ def call(self, inputs, mask):
     return mask


+def pad_to_chunk_length(tensor, axis, chunk_length, pad="right"):
+  """Pads a tensor so that shape[axis] is divisible by chunk_length.
+
+  Args:
+    tensor: Input tensor to pad.
+    axis: Axis to pad along.
+    chunk_length: The output tensor will have shape[axis] divisible by
+      chunk_length.
+    pad: Pad the input tensor across the axis from left if pad="left", right
+      if pad="right", or apply no padding if pad=None. In the latter case, the
+      axis dimension of the input tensor must be divisible by the chunk_length.
+
+  Returns:
+    Padded tensor with shape[axis] divisible by chunk_length.
+  """
+  shape = tf.shape(tensor)
+  rank = tf.rank(tensor)
+  if axis < 0:
+    axis += rank
+  axis_length = shape[axis]
+  pad_length = -axis_length % chunk_length
+  if pad == "right":
+    pad_width_2 = [[0, pad_length]]
+  elif pad == "left":
+    pad_width_2 = [[pad_length, 0]]
+  else:
+    if pad_length != 0:
+      raise ValueError("When padding is not set, the axis dimension "
+                       "has to be divisible by the chunk_length.")
+    return tensor
+  pad_width = tf.concat(
+      [tf.zeros([axis, 2], dtype=tf.int32), pad_width_2,
+       tf.zeros([rank - axis - 1, 2], dtype=tf.int32)], axis=0)
+  return tf.pad(tensor, pad_width)
+
+
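For illustration only (not part of the commit): a minimal shape check for the new pad_to_chunk_length helper, assuming the module is importable as official.nlp.modeling.layers.kernel_attention.

import tensorflow as tf

from official.nlp.modeling.layers import kernel_attention

# [B, T, N, dim] with T=7, not divisible by chunk_length=3.
x = tf.ones([2, 7, 4, 8])
padded = kernel_attention.pad_to_chunk_length(x, axis=-3, chunk_length=3)
print(padded.shape)  # (2, 9, 4, 8): two zero positions appended on the right of the T axis.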
+def split_tensor_into_chunks(tensor, axis, chunk_length):
+  """Reshape tensor along given axis using chunk_length.
+
+  Args:
+    tensor: Input tensor.
+    axis: Reshape tensor along this axis.
+    chunk_length: Split the axis into [axis/chunk_length, chunk_length]
+
+  Returns:
+    Reshaped tensor.
+  """
+  shape = tf.shape(tensor)
+  num_chunks = shape[axis] // chunk_length
+  new_shape = tf.concat(
+      [shape[:axis], [num_chunks, chunk_length], shape[(axis+1):]], axis=0)
+  return tf.reshape(tensor, new_shape)
+
+
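Illustration (not part of the commit): split_tensor_into_chunks only reshapes; the chosen axis becomes [T // chunk_length, chunk_length] and every other axis is left untouched. A minimal sketch under the same import assumption as above:

import tensorflow as tf

from official.nlp.modeling.layers import kernel_attention

x = tf.ones([2, 9, 4, 8])  # [B, T, N, dim] with T already divisible by 3.
chunked = kernel_attention.split_tensor_into_chunks(x, axis=-3, chunk_length=3)
print(chunked.shape)  # (2, 3, 3, 4, 8), i.e. [B, T//chunk_length, chunk_length, N, dim].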
+def windowed_causal_performer_attention(query_matrix,
+                                        key_matrix,
+                                        value_matrix,
+                                        chunk_length,
+                                        window_length,
+                                        pad="right"):
+  """Applies windowed causal kernel attention to query, key, value tensors.
+
+  We partition the T-length input sequence into chunks, each of chunk_length
+  tokens (thus T = num_chunks * chunk_length). Within each chunk, we apply the
+  bidirectional (non-causal) Performer's implicit attention, and we model
+  relationships between different chunks using the Performer's causal
+  attention. We consider a windowed causal variant of the Performer, where the
+  current chunk attends only to a window of the window_length most recent
+  chunks.
+
+  Below is an example with T=9, chunk_length=3, window_length=1. 1 indicates
+  attention is computed between the pair, while 0 indicates attention is not
+  computed between the pair:
+  111000000
+  111000000
+  111000000
+  111111000
+  111111000
+  111111000
+  000111111
+  000111111
+  000111111
+
+  Users can either ensure sequence_length is divisible by chunk_length, or use
+  pad="left"/"right" to pad the sequence at the start or end respectively so
+  that it becomes divisible by chunk_length.
+
+  Args:
+    query_matrix: Kernel query `Tensor` of shape `[B, T, N, dim]`.
+    key_matrix: Kernel key `Tensor` of shape `[B, T, N, dim]`.
+    value_matrix: Value `Tensor` of shape `[B, T, N, out_dim]`.
+    chunk_length: Length of each chunk in tokens.
+    window_length: Length of the attention window in chunks.
+    pad: Pad the query, key and value input tensors across the T dimension
+      from the left if pad="left", the right if pad="right", or apply no
+      padding if pad=None. In the latter case, the T dimension of the input
+      tensors must be divisible by the chunk_length.
+
+  Returns:
+    Windowed causal performer attention of shape `[B, T, N, out_dim]`.
+  """
+
+  old_shape = tf.shape(value_matrix)
+
+  query_matrix = pad_to_chunk_length(query_matrix, -3, chunk_length, pad)
+  key_matrix = pad_to_chunk_length(key_matrix, -3, chunk_length, pad)
+  value_matrix = pad_to_chunk_length(value_matrix, -3, chunk_length, pad)
+
+  new_shape = tf.shape(value_matrix)
+  chunked_query_matrix = split_tensor_into_chunks(
+      query_matrix, -3,
+      chunk_length)  # [-1, T//chunk_length, chunk_length, N, dim]
+  chunked_key_matrix = split_tensor_into_chunks(
+      key_matrix, -3,
+      chunk_length)  # [-1, T//chunk_length, chunk_length, N, dim]
+  chunked_value_matrix = split_tensor_into_chunks(
+      value_matrix, -3,
+      chunk_length)  # [-1, T//chunk_length, chunk_length, N, out_dim]
+
+  kp_v = tf.einsum("BNCHD,BNCHO->BNHDO", chunked_key_matrix,
+                   chunked_value_matrix)
+  kp_v_cumsum = tf.cumsum(kp_v, axis=-4)
+  kp_v_winsum = kp_v_cumsum - tf.pad(
+      kp_v_cumsum,
+      [[0, 0], [window_length, 0], [0, 0], [0, 0], [0, 0]])[:, :-window_length]
+  numerator = tf.einsum("BNCHD,BNHDO->BNCHO", chunked_query_matrix, kp_v_winsum)
+
+  k_sum = tf.reduce_sum(chunked_key_matrix, axis=-3)
+  k_cumsum = tf.cumsum(k_sum, axis=-3)
+  k_winsum = k_cumsum - tf.pad(k_cumsum, [[0, 0], [window_length, 0], [0, 0],
+                                          [0, 0]])[:, :-window_length]
+  denominator = tf.einsum("BNCHD,BNHD->BNCH", chunked_query_matrix, k_winsum)
+  denominator = tf.expand_dims(denominator, -1) + _NUMERIC_STABLER
+
+  attention = numerator / denominator
+  attention = tf.reshape(attention, new_shape)
+
+  start = tf.zeros([len(old_shape)], dtype=old_shape.dtype)
+  attention = tf.slice(attention, start, old_shape)
+
+  return attention
+
+
 def create_projection_matrix(m, d, seed=None):
   r"""Constructs the matrix of random projections.
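Illustration (not part of the commit): a hedged end-to-end sketch of windowed_causal_performer_attention on random non-negative kernel features, using a sequence length that is not divisible by chunk_length so the default pad="right" path is exercised. The import path is assumed as in the earlier sketches.

import tensorflow as tf

from official.nlp.modeling.layers import kernel_attention

query_prime = tf.random.uniform([2, 10, 2, 8])  # [B, T, N, dim], non-negative kernel features.
key_prime = tf.random.uniform([2, 10, 2, 8])    # [B, T, N, dim]
value = tf.random.normal([2, 10, 2, 16])        # [B, T, N, out_dim]
out = kernel_attention.windowed_causal_performer_attention(
    query_prime, key_prime, value, chunk_length=4, window_length=2)
print(out.shape)  # (2, 10, 2, 16): the padding is stripped, so the output matches the value shape.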
@@ -304,6 +446,9 @@ def __init__(self,
                begin_kernel=0,
                scale=None,
                scale_by_length=False,
+               use_windowed_causal=False,
+               chunk_length=1,
+               window_length=3,
                **kwargs):
     r"""Constructor of KernelAttention.
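Illustration (not part of the commit): the three new constructor arguments, documented in the next hunk, can be combined as follows. This mirrors the added test further below; the feature_transform value ("exp") and the float attention mask are assumptions borrowed from the existing tests rather than requirements of this change.

import tensorflow as tf

from official.nlp.modeling.layers import kernel_attention

layer = kernel_attention.KernelAttention(
    num_heads=12,
    key_dim=64,
    feature_transform="exp",   # assumed to be one of the supported transforms.
    num_random_features=127,
    is_short_seq=False,        # short_seq is mutually exclusive with the new mode.
    use_windowed_causal=True,  # enable the windowed causal path.
    chunk_length=8,            # tokens per chunk; 1024 is divisible by 8.
    window_length=3)           # attend to the 3 most recent chunks (incl. the current one).
query = tf.random.normal([2, 1024, 64])
mask = tf.ones([2, 1024], dtype=tf.float32)
output = layer(query=query, value=query, attention_mask=mask, training=False)
print(output.shape)  # (2, 1024, 64)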
@@ -330,9 +475,14 @@ def __init__(self,
         the dot product based on key length. Set as log_512^(n) to stablize
         attention entropy against length. Refer to
         https://kexue.fm/archives/8823 for details.
+      use_windowed_causal: If true, perform windowed causal attention. See the
+        windowed_causal_performer_attention function docstring for more details.
+      chunk_length: Length of each chunk in tokens.
+      window_length: Length of attention window in chunks.
       **kwargs: The same arguments `MultiHeadAttention` layer.
     """
-    if feature_transform not in _TRANSFORM_MAP and feature_transform != "expplus":
+    if (feature_transform not in _TRANSFORM_MAP and
+        feature_transform != "expplus"):
       raise ValueError("Unsupported feature_transform. The supported "
                        "feature_transform are %s. "
                        "Got '%s'." % (_TRANSFORM_MAP.keys(), feature_transform))
@@ -359,6 +509,12 @@ def __init__(self,
       self._projection_matrix = create_projection_matrix(
           self._num_random_features, self._key_dim,
           tf.constant([self._seed, self._seed + 1]))
+    self.use_windowed_causal = use_windowed_causal
+    self.chunk_length = chunk_length
+    self.window_length = window_length
+    if self.use_windowed_causal and self._is_short_seq:
+      raise ValueError(
+          "use_windowed_causal and short_seq methods are mutually exclusive")

   def _compute_attention(self,
                          query,
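Illustration (not part of the commit): the new check above rejects combining the windowed causal path with the short-sequence method at construction time. A small hedged sketch; the feature_transform value and the remaining defaults are assumptions.

from official.nlp.modeling.layers import kernel_attention

try:
  kernel_attention.KernelAttention(
      num_heads=2,
      key_dim=8,
      feature_transform="exp",
      num_random_features=16,
      is_short_seq=True,         # conflicts with use_windowed_causal.
      use_windowed_causal=True)
except ValueError as e:
  print(e)  # use_windowed_causal and short_seq methods are mutually exclusive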
@@ -394,6 +550,7 @@ def _compute_attention(self,
       attention_output: Multi-headed outputs of attention computation.
     """
     projection_matrix = None
+
     if self._num_random_features > 0:
       if self._redraw and training:
         projection_matrix = create_projection_matrix(self._num_random_features,
@@ -433,6 +590,9 @@ def _compute_attention(self,
       attention_scores = tf.einsum("BTNH,BSNH->BTSN", query_prime, key_prime)
       attention_scores = tf.nn.softmax(attention_scores, axis=2)
       attention_output = tf.einsum("BTSN,BSNH->BTNH", attention_scores, value)
+    elif self.use_windowed_causal:
+      attention_output = windowed_causal_performer_attention(
+          query_prime, key_prime, value, self.chunk_length, self.window_length)
     else:
       kv = tf.einsum("BSNH,BSND->BNDH", key_prime, value)
       denominator = 1.0 / (

official/nlp/modeling/layers/kernel_attention_test.py

Lines changed: 33 additions & 0 deletions
@@ -60,6 +60,39 @@ def test_attention_projection(
         training=training)
     self.assertEqual(output.shape, [batch_size, seq_length, key_dim])

+  @parameterized.parameters(
+      itertools.product(_FEATURE_TRANSFORM, [127], _TRAINING, [True, False],
+                        [0]))
+  def test_windowed_causal_attention_projection(
+      self, feature_transform, num_random_features, training, redraw,
+      begin_kernel):
+    num_heads = 12
+    key_dim = 64
+    seq_length = 1024
+    batch_size = 2
+    test_layer = attention.KernelAttention(
+        num_heads=num_heads,
+        key_dim=key_dim,
+        feature_transform=feature_transform,
+        num_random_features=num_random_features,
+        redraw=redraw,
+        is_short_seq=False,
+        begin_kernel=begin_kernel,
+        use_windowed_causal=True,
+        chunk_length=8,
+        window_length=3)
+    query = tf.random.normal(
+        shape=(batch_size, seq_length, key_dim))
+    value = query
+    encoder_inputs_mask = tf.zeros((batch_size, seq_length), dtype=tf.int32)
+    masks = tf.cast(encoder_inputs_mask, dtype=tf.float32)
+    output = test_layer(
+        query=query,
+        value=value,
+        attention_mask=masks,
+        training=training)
+    self.assertEqual(output.shape, [batch_size, seq_length, key_dim])
+
   @parameterized.parameters(itertools.product(
       _FEATURE_TRANSFORM, [0], _TRAINING, [False],
       _IS_SHORT_SEQ, _BEGIN_KERNEL))