@@ -41,17 +41,19 @@ def call(self, inputs, mask):
    return mask


-def pad_to_chunk_length(tensor, axis, chunk_length, pad="right"):
+def pad_to_chunk_length(tensor, axis, chunk_length, padding=None):
  """Pads a tensor so that shape[axis] is divisible by chunk_length.

  Args:
    tensor: Input tensor to pad.
    axis: Axis to pad along.
    chunk_length: The output tensor will have shape[axis] divisible by
      chunk_length.
-    pad: Pad the input tensor across the axis from left if pad="left", right if
-      pad="right", or apply no padding if pad=None. In the latter case, the axis
-      dimension of the input tensor must be divisible by the chunk_length.
+    padding: Pad the input tensor across the axis from either left or
+      right if padding is set to "left" or "right"; apply no padding
+      if padding is set to None. In the latter case, the axis
+      dimension of the input tensor must be divisible by the
+      chunk_length.

  Returns:
    Padded tensor with shape[axis] divisible by chunk_length.
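# Editor's note (not part of the diff): a minimal sketch of the padding rule
# documented above, using tf.pad directly. With an axis length of 7 and
# chunk_length=3, pad_length = -7 % 3 = 2, so "right" appends two zero slices,
# "left" prepends them, and None would require 7 % 3 == 0 and therefore fail.
import tensorflow as tf

x = tf.ones([2, 7, 4])                                # [batch, time, features]
pad_length = -7 % 3                                   # == 2
right = tf.pad(x, [[0, 0], [0, pad_length], [0, 0]])  # shape [2, 9, 4]
left = tf.pad(x, [[0, 0], [pad_length, 0], [0, 0]])   # shape [2, 9, 4]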
@@ -62,19 +64,23 @@ def pad_to_chunk_length(tensor, axis, chunk_length, pad="right"):
    axis += rank
  axis_length = shape[axis]
  pad_length = -axis_length % chunk_length
-  if pad == "right":
-    pad_width_2 = [[0, pad_length]]
-  elif pad == "left":
-    pad_width_2 = [[pad_length, 0]]
-  else:
+  if padding == "right":
+    axis_paddings = [[0, pad_length]]
+  elif padding == "left":
+    axis_paddings = [[pad_length, 0]]
+  elif padding is None:
    if pad_length != 0:
-      raise ValueError("When padding is not set, the axis dimension"
+      raise ValueError("When padding is None, the axis dimension "
                       "has to be divisible by the chunk_length.")
    return tensor
-  pad_width = tf.concat(
-      [tf.zeros([axis, 2], dtype=tf.int32), pad_width_2,
+  else:
+    raise ValueError("Illegal padding value; must be one of \"left\", "
+                     "\"right\" or None.")
+  paddings = tf.concat(
+      [tf.zeros([axis, 2], dtype=tf.int32),
+       axis_paddings,
       tf.zeros([rank - axis - 1, 2], dtype=tf.int32)], axis=0)
-  return tf.pad(tensor, pad_width)
+  return tf.pad(tensor, paddings)


def split_tensor_into_chunks(tensor, axis, chunk_length):
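# Editor's note (not part of the diff): expected behavior of the updated
# pad_to_chunk_length, assuming it is importable from this module. Padding a
# [1, 5, 2, 3] tensor along axis -3 up to a multiple of chunk_length=4 adds
# three timesteps on the requested side; padding=None would raise ValueError
# here because 5 is not divisible by 4.
import tensorflow as tf

x = tf.zeros([1, 5, 2, 3])
padded_right = pad_to_chunk_length(x, axis=-3, chunk_length=4, padding="right")
padded_left = pad_to_chunk_length(x, axis=-3, chunk_length=4, padding="left")
# Both results have shape [1, 8, 2, 3].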
@@ -95,12 +101,12 @@ def split_tensor_into_chunks(tensor, axis, chunk_length):
  return tf.reshape(tensor, new_shape)


-def windowed_causal_performer_attention(query_matrix,
+def causal_windowed_performer_attention(query_matrix,
                                        key_matrix,
                                        value_matrix,
                                        chunk_length,
                                        window_length,
-                                        pad="right"):
+                                        padding=None):
  """Applies windowed causal kernel attention with query, key, value tensors.

  We partition the T-length input sequence into N chunks, each of chunk_length
@@ -113,40 +119,40 @@ def windowed_causal_performer_attention(query_matrix,
  Below is an example with T=9, chunk_length=3, window_length=1. 1 indicates
  attention is computed between the pair while 0 indicates attention is not
  computed between the pairs:
-  111000000
-  111000000
-  111000000
-  111111000
-  111111000
-  111111000
-  000111111
-  000111111
-  000111111
+  111000000
+  111000000
+  111000000
+  111111000
+  111111000
+  111111000
+  000111111
+  000111111
+  000111111

  Users can ensure sequence_length is divisible by chunk_length or use
-  pad="left"/"right" to pad the sequence length either at the top or bottom
-  respectively and make it divisible by chunk_length.
+  padding="left"/"right" to pad the sequence length either at the left
+  or right respectively and make it divisible by chunk_length.

  Args:
    query_matrix: Kernel query `Tensor` of shape `[B, T, N, dim]`.
    key_matrix: Kernel key `Tensor` of shape `[B, T, N, dim]`.
    value_matrix: Value `Tensor` of shape `[B, T, N, out_dim]`.
    chunk_length: Length of each chunk in tokens.
    window_length: Length of attention window in chunks.
-    pad: Pad the query, value and key input tensors across the T dimension from
-      left if pad="left", right if pad="right", or apply no padding if pad=None.
-      In the latter case, the T dimension of the input tensors must be divisible
-      by the chunk_length.
+    padding: Pad the query, value and key input tensors across the
+      axis from either left or right if padding is set to "left" or
+      "right"; apply no padding if padding is set to None. In the
+      latter case, the axis dimension of the query, value and key
+      input tensors must be divisible by the chunk_length.

  Returns:
    Windowed causal performer attention of shape `[B, T, N, out_dim]`.
  """
-
  old_shape = tf.shape(value_matrix)

-  query_matrix = pad_to_chunk_length(query_matrix, -3, chunk_length, pad)
-  key_matrix = pad_to_chunk_length(key_matrix, -3, chunk_length, pad)
-  value_matrix = pad_to_chunk_length(value_matrix, -3, chunk_length, pad)
+  query_matrix = pad_to_chunk_length(query_matrix, -3, chunk_length, padding)
+  key_matrix = pad_to_chunk_length(key_matrix, -3, chunk_length, padding)
+  value_matrix = pad_to_chunk_length(value_matrix, -3, chunk_length, padding)

  new_shape = tf.shape(value_matrix)
  chunked_query_matrix = split_tensor_into_chunks(
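# Editor's note (not part of the diff): calling the renamed function on toy
# shapes, assuming it is importable from this module. B=2, T=9, N=1 head,
# dim=4, out_dim=5; chunk_length=3 and window_length=1 match the 9x9 mask
# shown in the docstring, and padding=None is valid because 9 % 3 == 0.
import tensorflow as tf

query_prime = tf.random.uniform([2, 9, 1, 4])  # non-negative kernel features
key_prime = tf.random.uniform([2, 9, 1, 4])
value = tf.random.uniform([2, 9, 1, 5])
out = causal_windowed_performer_attention(
    query_prime, key_prime, value,
    chunk_length=3, window_length=1, padding=None)
# out has shape [2, 9, 1, 5], matching the value tensor.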
@@ -446,16 +452,17 @@ def __init__(self,
               begin_kernel=0,
               scale=None,
               scale_by_length=False,
-               use_windowed_causal=False,
-               chunk_length=1,
-               window_length=3,
+               use_causal_windowed=False,
+               causal_chunk_length=1,
+               causal_window_length=1,
+               causal_padding=None,
               **kwargs):
    r"""Constructor of KernelAttention.

    Args:
-      feature_transform: A non-linear transform of the keys and quries. Possible
-        transforms are "elu", "relu", "square", "exp", "expplus", "expmod",
-        "identity".
+      feature_transform: A non-linear transform of the keys and queries.
+        Possible transforms are "elu", "relu", "square", "exp", "expplus",
+        "expmod", "identity".
      num_random_features: Number of random features to be used for projection.
        If num_random_features <= 0, no projection is used before transform.
      seed: The seed to begin drawing random features. Once the seed is set, the
@@ -475,11 +482,17 @@ def __init__(self,
        the dot product based on key length. Set as log_512^(n) to stabilize
        attention entropy against length. Refer to
        https://kexue.fm/archives/8823 for details.
-      use_windowed_causal: If true perform windowed causal attention. See
-        windowed_causal_performer_attention function docstring for more details.
-      chunk_length: Length of each chunk in tokens.
-      window_length: Length of attention window in chunks.
-      **kwargs: The same arguments `MultiHeadAttention` layer.
+      use_causal_windowed: If true, perform windowed causal attention. See the
+        causal_windowed_performer_attention function docstring for more details.
+      causal_chunk_length: Length of each chunk in tokens.
+      causal_window_length: Length of attention window in chunks.
+      causal_padding: Pad the query, value and key input tensors
+        across the axis from either left or right if padding is set to
+        "left" or "right"; apply no padding if padding is set to None.
+        In the latter case, the axis dimension of the query, value and
+        key input tensors must be divisible by the chunk_length.
+      **kwargs:
+        The same arguments as the `MultiHeadAttention` layer.
    """
    if (feature_transform not in _TRANSFORM_MAP and
        feature_transform != "expplus"):
@@ -509,12 +522,13 @@ def __init__(self,
      self._projection_matrix = create_projection_matrix(
          self._num_random_features, self._key_dim,
          tf.constant([self._seed, self._seed + 1]))
-    self.use_windowed_causal = use_windowed_causal
-    self.chunk_length = chunk_length
-    self.window_length = window_length
-    if self.use_windowed_causal and self._is_short_seq:
+    self.use_causal_windowed = use_causal_windowed
+    self.causal_chunk_length = causal_chunk_length
+    self.causal_window_length = causal_window_length
+    self.causal_padding = causal_padding
+    if self.use_causal_windowed and self._is_short_seq:
      raise ValueError(
-          "use_windowed_causal and short_seq methods are mutually exclusive")
+          "use_causal_windowed and short_seq methods are mutually exclusive")

  def _compute_attention(self,
                         query,
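# Editor's note (not part of the diff): constructing the layer with the
# renamed flags. The remaining constructor arguments shown here (num_heads
# and key_dim from MultiHeadAttention, plus illustrative values for the
# kernel options) are assumptions for the sake of the example.
layer = KernelAttention(
    num_heads=2,
    key_dim=32,
    feature_transform="exp",
    num_random_features=128,
    use_causal_windowed=True,
    causal_chunk_length=4,
    causal_window_length=2,
    causal_padding="right")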
@@ -590,9 +604,12 @@ def _compute_attention(self,
      attention_scores = tf.einsum("BTNH,BSNH->BTSN", query_prime, key_prime)
      attention_scores = tf.nn.softmax(attention_scores, axis=2)
      attention_output = tf.einsum("BTSN,BSNH->BTNH", attention_scores, value)
-    elif self.use_windowed_causal:
-      attention_output = windowed_causal_performer_attention(
-          query_prime, key_prime, value, self.chunk_length, self.window_length)
+    elif self.use_causal_windowed:
+      attention_output = causal_windowed_performer_attention(
+          query_prime, key_prime, value,
+          chunk_length=self.causal_chunk_length,
+          window_length=self.causal_window_length,
+          padding=self.causal_padding)
    else:
      kv = tf.einsum("BSNH,BSND->BNDH", key_prime, value)
      denominator = 1.0 / (
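# Editor's note (not part of the diff): a self-contained sketch of the
# kernel-attention normalization that the non-causal else branch above begins
# to compute. The first two einsum strings follow the context lines shown; the
# final recombination einsum and the small stabilizer are assumptions.
import tensorflow as tf

B, T, S, N, H, D = 2, 5, 5, 1, 4, 3
query_prime = tf.random.uniform([B, T, N, H])  # non-negative kernel features
key_prime = tf.random.uniform([B, S, N, H])
value = tf.random.uniform([B, S, N, D])
kv = tf.einsum("BSNH,BSND->BNDH", key_prime, value)
denominator = 1.0 / (
    tf.einsum("BTNH,BNH->BTN", query_prime, tf.reduce_sum(key_prime, axis=1))
    + 1e-6)
attention_output = tf.einsum(
    "BTNH,BNDH,BTN->BTND", query_prime, kv, denominator)
# attention_output has shape [B, T, N, D].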