
Commit bc113f0

No public description
PiperOrigin-RevId: 674517247
1 parent 1d9468e commit bc113f0

File tree

2 files changed: +23 additions, −1 deletion

official/nlp/modeling/layers/transformer_encoder_block.py

Lines changed: 12 additions & 0 deletions
@@ -112,6 +112,8 @@ def __init__(self,
                num_kv_heads=None,
                src_block_size=None,
                tgt_block_size=None,
+               use_sigmoid_attn=False,
+               sigmoid_attn_bias=None,
                **kwargs):
     """Initializes `TransformerEncoderBlock`.
 
@@ -185,6 +187,10 @@ def __init__(self,
         `block_sparse_attention.MultiHeadAttention` for more details.
       tgt_block_size: Target block size. Refer to
         `block_sparse_attention.MultiHeadAttention` for more details.
+      use_sigmoid_attn: This param is only used in
+        `block_sparse_attention.MultiHeadAttention`
+      sigmoid_attn_bias: This param is only used in
+        `block_sparse_attention.MultiHeadAttention`
       **kwargs: keyword arguments.
     """
     util.filter_kwargs(kwargs)
@@ -222,6 +228,8 @@ def __init__(self,
     self._num_kv_heads = num_kv_heads
     self._src_block_size = src_block_size
     self._tgt_block_size = tgt_block_size
+    self._use_sigmoid_attn = use_sigmoid_attn
+    self._sigmoid_attn_bias = sigmoid_attn_bias
     if self._num_kv_heads is not None and self._src_block_size is not None:
       raise ValueError(
           "Block sparse attention does not support Multi-query attention."
@@ -285,6 +293,8 @@ def build(self, input_shape):
       attention_layer_kwargs.update(
           src_block_size=self._src_block_size,
           tgt_block_size=self._tgt_block_size,
+          use_sigmoid_attn=self._use_sigmoid_attn,
+          sigmoid_attn_bias=self._sigmoid_attn_bias,
           name="block_sparse_attention",
       )
       attention_fn = block_sparse_attention.MultiHeadAttention
@@ -413,6 +423,8 @@ def get_config(self):
         "num_kv_heads": self._num_kv_heads,
         "src_block_size": self._src_block_size,
         "tgt_block_size": self._tgt_block_size,
+        "use_sigmoid_attn": self._use_sigmoid_attn,
+        "sigmoid_attn_bias": self._sigmoid_attn_bias,
     }
     base_config = super().get_config()
     return dict(list(base_config.items()) + list(config.items()))
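For context, a minimal usage sketch of the new options (not part of the commit): only `use_sigmoid_attn` and `sigmoid_attn_bias` come from this diff, while `num_attention_heads`, `inner_dim`, `inner_activation`, and the block sizes are illustrative values taken from or modeled on the existing test. Note that the new arguments are only forwarded when the block sparse attention path is active, i.e. when `src_block_size` is set.

import math

import tensorflow as tf, tf_keras
from official.nlp.modeling import layers

sequence_length = 21
block = layers.TransformerEncoderBlock(
    num_attention_heads=8,
    inner_dim=2048,
    inner_activation='relu',
    src_block_size=7,   # enables block_sparse_attention.MultiHeadAttention
    tgt_block_size=7,
    use_sigmoid_attn=True,
    sigmoid_attn_bias=-math.log(sequence_length),  # bias value used in the test below
)

inputs = tf_keras.Input(shape=(sequence_length, 80))
outputs = block(inputs)

# The new fields are included in get_config(), so they should survive a
# config round-trip when the layer is serialized and rebuilt.
restored = layers.TransformerEncoderBlock.from_config(block.get_config())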

official/nlp/modeling/layers/transformer_encoder_block_test.py

Lines changed: 11 additions & 1 deletion
@@ -14,6 +14,8 @@
 
 """Tests for Keras-based transformer block layer."""
 
+import math
+
 from absl.testing import parameterized
 import numpy as np
 import tensorflow as tf, tf_keras
@@ -751,7 +753,11 @@ def test_attention_with_kv_heads(self, num_kv_heads):
         output_tensor[1].shape.as_list(), expected_attention_scores_shape
     )
 
-  def test_block_sparse_attention(self):
+  @parameterized.named_parameters(
+      ('use_softmax_attn', False),
+      ('use_sigmoid_attn', True),
+  )
+  def test_block_sparse_attention(self, use_sigmoid_attn):
     num_attention_heads = 8
     sequence_length = 21
     width = 80
@@ -765,6 +771,10 @@ def test_block_sparse_attention(self):
         return_attention_scores=True,
         src_block_size=src_block_size,
         tgt_block_size=tgt_block_size,
+        use_sigmoid_attn=use_sigmoid_attn,
+        sigmoid_attn_bias=-math.log(sequence_length)
+        if use_sigmoid_attn
+        else None,
     )
     # Create a 3-dimensional input (the first dimension is implicit).
     data_tensor = tf_keras.Input(shape=(sequence_length, width))
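A short note on why the test picks `sigmoid_attn_bias = -math.log(sequence_length)` (my reading; the commit itself does not explain it): with pre-sigmoid logits near zero at initialization, a bias of -log(n) puts each key's sigmoid attention weight at roughly 1/n, close to the uniform 1/n weights that softmax attention assigns over n keys. A quick numeric check under that assumption:

import math

n = 21  # sequence_length used in the test
bias = -math.log(n)

# sigmoid(0 + bias) = 1 / (1 + n), versus exactly 1/n for uniform softmax.
sigmoid_weight = 1.0 / (1.0 + math.exp(-bias))
softmax_weight = 1.0 / n
print(f"{sigmoid_weight:.4f} vs {softmax_weight:.4f}")  # 0.0455 vs 0.0476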
