@@ -112,6 +112,8 @@ def __init__(self,
112
112
num_kv_heads = None ,
113
113
src_block_size = None ,
114
114
tgt_block_size = None ,
115
+ use_sigmoid_attn = False ,
116
+ sigmoid_attn_bias = None ,
115
117
** kwargs ):
116
118
"""Initializes `TransformerEncoderBlock`.
117
119
@@ -185,6 +187,10 @@ def __init__(self,
185
187
`block_sparse_attention.MultiHeadAttention` for more details.
186
188
tgt_block_size: Target block size. Refer to
187
189
`block_sparse_attention.MultiHeadAttention` for more details.
190
+ use_sigmoid_attn: Only used by block sparse attention. Refer to
191
+ `block_sparse_attention.MultiHeadAttention` for more details.
192
+ sigmoid_attn_bias: Only used by block sparse attention. Refer to
193
+ `block_sparse_attention.MultiHeadAttention` for more details.
188
194
**kwargs: keyword arguments.
189
195
"""
190
196
util .filter_kwargs (kwargs )
@@ -222,6 +228,8 @@ def __init__(self,
222
228
self ._num_kv_heads = num_kv_heads
223
229
self ._src_block_size = src_block_size
224
230
self ._tgt_block_size = tgt_block_size
231
+ self ._use_sigmoid_attn = use_sigmoid_attn
232
+ self ._sigmoid_attn_bias = sigmoid_attn_bias
225
233
if self ._num_kv_heads is not None and self ._src_block_size is not None :
226
234
raise ValueError (
227
235
"Block sparse attention does not support Multi-query attention."
@@ -285,6 +293,8 @@ def build(self, input_shape):
285
293
attention_layer_kwargs .update (
286
294
src_block_size = self ._src_block_size ,
287
295
tgt_block_size = self ._tgt_block_size ,
296
+ use_sigmoid_attn = self ._use_sigmoid_attn ,
297
+ sigmoid_attn_bias = self ._sigmoid_attn_bias ,
288
298
name = "block_sparse_attention" ,
289
299
)
290
300
attention_fn = block_sparse_attention .MultiHeadAttention
@@ -413,6 +423,8 @@ def get_config(self):
413
423
"num_kv_heads" : self ._num_kv_heads ,
414
424
"src_block_size" : self ._src_block_size ,
415
425
"tgt_block_size" : self ._tgt_block_size ,
426
+ "use_sigmoid_attn" : self ._use_sigmoid_attn ,
427
+ "sigmoid_attn_bias" : self ._sigmoid_attn_bias ,
416
428
}
417
429
base_config = super ().get_config ()
418
430
return dict (list (base_config .items ()) + list (config .items ()))
0 commit comments