@@ -114,6 +114,7 @@ def __init__(self,
                tgt_block_size=None,
                use_sigmoid_attn=False,
                sigmoid_attn_bias=None,
+               linformer_dim=None,
                **kwargs):
     """Initializes `TransformerEncoderBlock`.

@@ -191,6 +192,8 @@ def __init__(self,
         `block_sparse_attention.MultiHeadAttention`
       sigmoid_attn_bias: This param is only used in
         `block_sparse_attention.MultiHeadAttention`
+      linformer_dim: Applies low-rank factorization on keys/values as in
+        https://arxiv.org/pdf/2006.04768.
       **kwargs: keyword arguments.
     """
     util.filter_kwargs(kwargs)
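
The new `linformer_dim` docstring entry refers to the Linformer trick: keys and values are projected along the sequence axis down to a fixed length k, so attention cost drops from O(n^2) to roughly O(n*k). A minimal standalone sketch of that projection, with made-up shapes that are not part of this change:

import tensorflow as tf

# Illustrative sizes only: batch=2, seq_len=128, hidden=64, k=32.
key_value = tf.random.normal([2, 128, 64])  # [batch, seq_len, hidden]
e_proj = tf.random.normal([128, 32])        # learned [seq_len, k] projection

# Contract the sequence axis: [batch, seq_len, hidden] -> [batch, k, hidden].
low_rank_kv = tf.einsum("bsh,sk->bkh", key_value, e_proj)
print(low_rank_kv.shape)  # (2, 32, 64)
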
@@ -230,6 +233,7 @@ def __init__(self,
     self._tgt_block_size = tgt_block_size
     self._use_sigmoid_attn = use_sigmoid_attn
     self._sigmoid_attn_bias = sigmoid_attn_bias
+    self._linformer_dim = linformer_dim
     if self._num_kv_heads is not None and self._src_block_size is not None:
       raise ValueError(
           "Block sparse attention does not support Multi-query attention."
@@ -366,16 +370,31 @@ def build(self, input_shape):
         name="output",
         kernel_initializer=tf_utils.clone_initializer(self._kernel_initializer),
         bias_initializer=tf_utils.clone_initializer(self._bias_initializer),
-        **common_kwargs)
+        **common_kwargs,
+    )
     self._output_dropout = tf_keras.layers.Dropout(
-        rate=self._output_dropout_rate)
+        rate=self._output_dropout_rate
+    )
     # Use float32 in layernorm for numeric stability.
     self._output_layer_norm = tf_keras.layers.LayerNormalization(
         name="output_layer_norm",
         axis=-1,
         epsilon=self._norm_epsilon,
-        dtype=tf.float32)
-
+        dtype=tf.float32,
+    )
+    if self._linformer_dim is not None:
+      # Current implementation uses the same weights for keys and values.
+      # TODO(akandoor): Explore using different weights for keys and values.
+      self._lowrank_kv_projection = tf_keras.layers.EinsumDense(
+          "...bc,cd->...bd",
+          output_shape=(None, self._linformer_dim),
+          kernel_initializer=tf_utils.clone_initializer(
+              self._kernel_initializer
+          ),
+          bias_initializer=tf_utils.clone_initializer(self._bias_initializer),
+          name="lowrank_kv_projection",
+          **common_kwargs,
+      )
     super().build(input_shape)

   def get_config(self):
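
For reference, the `EinsumDense` equation `"...bc,cd->...bd"` only acts as a sequence-length reduction because `call` (next hunk) transposes `key_value` to `[batch, hidden, seq_len]` before applying it, so axis `c` is the sequence axis and `d` is `linformer_dim`. A standalone illustration of that layer in isolation, with assumed example sizes (hidden=64, seq_len=128, linformer_dim=32):

import tensorflow as tf
import tf_keras  # the Keras 2 package the model garden imports as tf_keras

proj = tf_keras.layers.EinsumDense(
    "...bc,cd->...bd",
    output_shape=(None, 32),  # b is inferred from the input, d = linformer_dim
)

x = tf.random.normal([4, 64, 128])  # transposed key/value: [batch, hidden, seq_len]
y = proj(x)
print(y.shape)  # (4, 64, 32): sequence axis shrunk to linformer_dim
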
@@ -480,6 +499,19 @@ def call(self, inputs: Any, output_range: Optional[tf.Tensor] = None) -> Any:
     if key_value is None:
       key_value = input_tensor

+    if self._linformer_dim is not None:
+      if attention_mask is not None:
+        # Applying mask before the low rank factorization so that padding is
+        # accounted for.
+        query_mask = tf.cast(attention_mask[:, :, 0], dtype=target_tensor.dtype)
+        target_tensor = target_tensor * tf.expand_dims(query_mask, axis=-1)
+        key_mask = tf.cast(attention_mask[:, 0, :], dtype=target_tensor.dtype)
+        key_value = key_value * tf.expand_dims(key_mask, axis=-1)
+        attention_mask = None
+      key_value = tf.transpose(key_value, [0, 2, 1])
+      key_value = self._lowrank_kv_projection(key_value)
+      key_value = tf.transpose(key_value, [0, 2, 1])
+
     if self._return_attention_scores:
       attention_output, attention_scores = self._attention_layer(
           query=target_tensor,
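
End to end, the new path zeroes out padded positions up front and then drops `attention_mask`: after the projection the key axis no longer corresponds to token positions, so a per-position mask cannot be applied to the projected keys/values. A hypothetical usage sketch, assuming the usual model-garden import path and purely illustrative sizes (only `linformer_dim` is specific to this change):

import tensorflow as tf
from official.nlp.modeling import layers

block = layers.TransformerEncoderBlock(
    num_attention_heads=8,
    inner_dim=2048,
    inner_activation="relu",
    linformer_dim=64,  # keys/values projected from length 512 down to 64
)

hidden = tf.random.normal([2, 512, 512])  # [batch, seq_len, width]
mask = tf.ones([2, 512, 512])             # [batch, target_len, source_len]
output = block([hidden, mask])
print(output.shape)  # (2, 512, 512)
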