@@ -41,6 +41,148 @@ def call(self, inputs, mask):
    return mask


+ def pad_to_chunk_length(tensor, axis, chunk_length, pad="right"):
+   """Pads a tensor so that shape[axis] is divisible by chunk_length.
+
+   Args:
+     tensor: Input tensor to pad.
+     axis: Axis to pad along.
+     chunk_length: The output tensor will have shape[axis] divisible by
+       chunk_length.
+     pad: Pad the input tensor along the axis on the left if pad="left", on the
+       right if pad="right", or apply no padding if pad=None. In the latter
+       case, the axis dimension of the input tensor must be divisible by
+       chunk_length.
+
+   Returns:
+     Padded tensor with shape[axis] divisible by chunk_length.
+   """
+   shape = tf.shape(tensor)
+   rank = tf.rank(tensor)
+   if axis < 0:
+     axis += rank
+   axis_length = shape[axis]
+   pad_length = -axis_length % chunk_length
+   if pad == "right":
+     pad_width_2 = [[0, pad_length]]
+   elif pad == "left":
+     pad_width_2 = [[pad_length, 0]]
+   else:
+     if pad_length != 0:
+       raise ValueError("When padding is not set, the axis dimension "
+                        "has to be divisible by the chunk_length.")
+     return tensor
+   pad_width = tf.concat(
+       [tf.zeros([axis, 2], dtype=tf.int32), pad_width_2,
+        tf.zeros([rank - axis - 1, 2], dtype=tf.int32)], axis=0)
+   return tf.pad(tensor, pad_width)
+
+
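As a quick illustration of the helper above (not part of the diff; a minimal sketch assuming eager-mode TensorFlow 2.x and hypothetical toy shapes):

```python
import tensorflow as tf

# Hypothetical input: batch of 2 sequences of length 7 with 4 features.
x = tf.ones([2, 7, 4])

# 7 is not divisible by chunk_length=3, so two rows of zeros are appended on
# the right of axis 1 and the result has shape [2, 9, 4].
y = pad_to_chunk_length(x, axis=1, chunk_length=3, pad="right")
print(y.shape)  # (2, 9, 4)

# With pad=None the same call raises ValueError, since 7 % 3 != 0.
```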
+ def split_tensor_into_chunks(tensor, axis, chunk_length):
+   """Reshape tensor along given axis using chunk_length.
+
+   Args:
+     tensor: Input tensor.
+     axis: Reshape tensor along this axis.
+     chunk_length: Split the axis into [axis/chunk_length, chunk_length].
+
+   Returns:
+     Reshaped tensor.
+   """
+   shape = tf.shape(tensor)
+   num_chunks = shape[axis] // chunk_length
+   new_shape = tf.concat(
+       [shape[:axis], [num_chunks, chunk_length], shape[(axis + 1):]], axis=0)
+   return tf.reshape(tensor, new_shape)
+
+
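A short sketch of how the two helpers compose (hypothetical shapes, assuming the `[B, T, N, dim]` layout used below, where T sits at axis -3):

```python
import tensorflow as tf

# Hypothetical kernel-feature tensor: [batch, T, heads, dim] = [2, 7, 4, 8].
q = tf.random.normal([2, 7, 4, 8])

# Pad T (axis -3) up to a multiple of chunk_length=3 ...
q = pad_to_chunk_length(q, axis=-3, chunk_length=3)       # [2, 9, 4, 8]
# ... then split T into [num_chunks, chunk_length].
q = split_tensor_into_chunks(q, axis=-3, chunk_length=3)  # [2, 3, 3, 4, 8]
print(q.shape)  # (2, 3, 3, 4, 8)
```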
+ def windowed_causal_performer_attention(query_matrix,
+                                         key_matrix,
+                                         value_matrix,
+                                         chunk_length,
+                                         window_length,
+                                         pad="right"):
+   """Applies windowed causal kernel attention to query, key and value tensors.
+
+   We partition the T-length input sequence into chunks of chunk_length tokens
+   each (so there are T / chunk_length chunks). Within each chunk we apply
+   bidirectional (non-causal) Performer implicit attention, and relationships
+   between different chunks are modeled with Performer causal attention. We
+   consider a windowed causal variant of Performer, in which each chunk attends
+   only to the window of the window_length most recent chunks (the current
+   chunk included).
+
+   Below is an example with T=9, chunk_length=3 and a window covering the
+   current chunk plus the one before it (window_length=2). A 1 indicates that
+   attention is computed between the pair of positions, while a 0 indicates
+   that it is not:
+   111000000
+   111000000
+   111000000
+   111111000
+   111111000
+   111111000
+   000111111
+   000111111
+   000111111
+
+   The user can either make sure that the sequence length is divisible by
+   chunk_length, or use pad="left"/"right" to pad the sequence at the beginning
+   or at the end, respectively, so that it becomes divisible by chunk_length.
+
+   Args:
+     query_matrix: Kernel query `Tensor` of shape `[B, T, N, dim]`.
+     key_matrix: Kernel key `Tensor` of shape `[B, T, N, dim]`.
+     value_matrix: Value `Tensor` of shape `[B, T, N, out_dim]`.
+     chunk_length: Length of each chunk in tokens.
+     window_length: Length of attention window in chunks.
+     pad: Pad the query, key and value input tensors across the T dimension on
+       the left if pad="left", on the right if pad="right", or apply no padding
+       if pad=None. In the latter case, the T dimension of the input tensors
+       must be divisible by the chunk_length.
+
+   Returns:
+     Windowed causal Performer attention of shape `[B, T, N, out_dim]`.
+   """
+
+   old_shape = tf.shape(value_matrix)
+
+   query_matrix = pad_to_chunk_length(query_matrix, -3, chunk_length, pad)
+   key_matrix = pad_to_chunk_length(key_matrix, -3, chunk_length, pad)
+   value_matrix = pad_to_chunk_length(value_matrix, -3, chunk_length, pad)
+
+   new_shape = tf.shape(value_matrix)
+   chunked_query_matrix = split_tensor_into_chunks(
+       query_matrix, -3,
+       chunk_length)  # [-1, T//chunk_length, chunk_length, N, dim]
+   chunked_key_matrix = split_tensor_into_chunks(
+       key_matrix, -3,
+       chunk_length)  # [-1, T//chunk_length, chunk_length, N, dim]
+   chunked_value_matrix = split_tensor_into_chunks(
+       value_matrix, -3,
+       chunk_length)  # [-1, T//chunk_length, chunk_length, N, out_dim]
+
+   kp_v = tf.einsum("BNCHD,BNCHO->BNHDO", chunked_key_matrix,
+                    chunked_value_matrix)
+   kp_v_cumsum = tf.cumsum(kp_v, axis=-4)
+   kp_v_winsum = kp_v_cumsum - tf.pad(
+       kp_v_cumsum,
+       [[0, 0], [window_length, 0], [0, 0], [0, 0], [0, 0]])[:, :-window_length]
+   numerator = tf.einsum("BNCHD,BNHDO->BNCHO", chunked_query_matrix, kp_v_winsum)
+
+   k_sum = tf.reduce_sum(chunked_key_matrix, axis=-3)
+   k_cumsum = tf.cumsum(k_sum, axis=-3)
+   k_winsum = k_cumsum - tf.pad(k_cumsum, [[0, 0], [window_length, 0], [0, 0],
+                                           [0, 0]])[:, :-window_length]
+   denominator = tf.einsum("BNCHD,BNHD->BNCH", chunked_query_matrix, k_winsum)
+   denominator = tf.expand_dims(denominator, -1) + _NUMERIC_STABLER
+
+   attention = numerator / denominator
+   attention = tf.reshape(attention, new_shape)
+
+   start = tf.zeros([len(old_shape)], dtype=old_shape.dtype)
+   attention = tf.slice(attention, start, old_shape)
+
+   return attention
+
+
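The cross-chunk windowing in the body above comes entirely from a cumulative-sum difference: subtracting the cumsum shifted by window_length chunks leaves the sum over the window_length most recent chunks, the current chunk included. A tiny 1-D sketch of that identity (scalars standing in for the per-chunk `kp_v` blocks):

```python
import tensorflow as tf

window_length = 2
# Hypothetical per-chunk values, one scalar per chunk.
per_chunk = tf.constant([1., 2., 3., 4., 5.])

cumsum = tf.cumsum(per_chunk)                                    # [1, 3, 6, 10, 15]
shifted = tf.pad(cumsum, [[window_length, 0]])[:-window_length]  # [0, 0, 1, 3, 6]
window_sum = cumsum - shifted                                    # [1, 3, 5, 7, 9]

# window_sum[n] sums chunks n-1 and n (clipped at the start), i.e. the
# window_length most recent chunks ending at chunk n.
```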
def create_projection_matrix(m, d, seed=None):
  r"""Constructs the matrix of random projections.
@@ -304,6 +446,9 @@ def __init__(self,
               begin_kernel=0,
               scale=None,
               scale_by_length=False,
+              use_windowed_causal=False,
+              chunk_length=1,
+              window_length=3,
               **kwargs):
    r"""Constructor of KernelAttention.
@@ -330,9 +475,14 @@ def __init__(self,
        the dot product based on key length. Set as log_512^(n) to stablize
        attention entropy against length. Refer to
        https://kexue.fm/archives/8823 for details.
+       use_windowed_causal: If true, perform windowed causal attention. See
+         the windowed_causal_performer_attention docstring for details.
+       chunk_length: Length of each chunk in tokens.
+       window_length: Length of attention window in chunks.
        **kwargs: The same arguments `MultiHeadAttention` layer.
    """
-    if feature_transform not in _TRANSFORM_MAP and feature_transform != "expplus":
+    if (feature_transform not in _TRANSFORM_MAP and
+        feature_transform != "expplus"):
      raise ValueError("Unsupported feature_transform. The supported "
                       "feature_transform are %s. "
                       "Got '%s'." % (_TRANSFORM_MAP.keys(), feature_transform))
@@ -359,6 +509,12 @@ def __init__(self,
      self._projection_matrix = create_projection_matrix(
          self._num_random_features, self._key_dim,
          tf.constant([self._seed, self._seed + 1]))
+     self.use_windowed_causal = use_windowed_causal
+     self.chunk_length = chunk_length
+     self.window_length = window_length
+     if self.use_windowed_causal and self._is_short_seq:
+       raise ValueError(
+           "use_windowed_causal and short_seq methods are mutually exclusive")

  def _compute_attention(self,
                         query,
@@ -394,6 +550,7 @@ def _compute_attention(self,
      attention_output: Multi-headed outputs of attention computation.
    """
    projection_matrix = None
+
    if self._num_random_features > 0:
      if self._redraw and training:
        projection_matrix = create_projection_matrix(self._num_random_features,
@@ -433,6 +590,9 @@ def _compute_attention(self,
      attention_scores = tf.einsum("BTNH,BSNH->BTSN", query_prime, key_prime)
      attention_scores = tf.nn.softmax(attention_scores, axis=2)
      attention_output = tf.einsum("BTSN,BSNH->BTNH", attention_scores, value)
+   elif self.use_windowed_causal:
+     attention_output = windowed_causal_performer_attention(
+         query_prime, key_prime, value, self.chunk_length, self.window_length)
    else:
      kv = tf.einsum("BSNH,BSND->BNDH", key_prime, value)
      denominator = 1.0 / (
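Finally, a hedged end-to-end shape check of the new code path; random tensors stand in for `query_prime`/`key_prime`/`value`, all sizes are illustrative, and the sketch assumes the module-level helpers and `_NUMERIC_STABLER` defined in this file:

```python
import tensorflow as tf

B, T, N, dim, out_dim = 2, 10, 4, 16, 8  # hypothetical sizes
query_prime = tf.random.uniform([B, T, N, dim])  # kernel features are non-negative
key_prime = tf.random.uniform([B, T, N, dim])
value = tf.random.normal([B, T, N, out_dim])

# T=10 is right-padded to 12 (chunk_length=4), chunked windowed attention is
# applied over a window of 3 chunks, and the output is sliced back to T=10.
out = windowed_causal_performer_attention(
    query_prime, key_prime, value, chunk_length=4, window_length=3)
print(out.shape)  # (2, 10, 4, 8)
```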