
Commit 1d9468e

No public description
PiperOrigin-RevId: 674413113
1 parent d817d87 commit 1d9468e

2 files changed: +166 additions, -5 deletions

official/nlp/modeling/layers/block_sparse_attention.py

Lines changed: 73 additions & 2 deletions
@@ -14,14 +14,42 @@
 
 """Block sparse attention converts query/key/value into blocks and performs diagonal block sparse attention."""
 import collections
+import logging
 
 import tensorflow as tf, tf_keras
 
 
+def _large_compatible_negative(tensor_type):
+  """Large negative number as Tensor.
+
+  This function is necessary because the standard value for epsilon
+  in this module (-1e9) cannot be represented using tf.float16
+
+  Args:
+    tensor_type: a dtype to determine the type.
+
+  Returns:
+    a large negative number.
+  """
+  # In case of dtype=float16 (e.g., for mixed-precision), the largest
+  # negative number (dtypes.float16.min) is divided by 2, in order to
+  # avoid overflows when summing negative inputs.
+  if tensor_type == tf.float16:
+    return tf.float16.min / 2.0
+  return -1e9
+
+
 class MultiHeadAttention(tf_keras.layers.MultiHeadAttention):
   """Multi-head block sparse attention layer."""
 
-  def __init__(self, src_block_size=None, tgt_block_size=None, **kwargs):
+  def __init__(
+      self,
+      src_block_size=None,
+      tgt_block_size=None,
+      use_sigmoid_attn=False,
+      sigmoid_attn_bias=None,
+      **kwargs
+  ):
     """Initializes the block sparse attention layer.
 
     Args:
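
Aside (not part of the commit): the `_large_compatible_negative` helper exists because -1e9 overflows float16, so the usual masking constant cannot be used under mixed precision, and `tf.float16.min` is halved so that summing two masked contributions still stays finite. A minimal numeric check of that rationale:

    import tensorflow as tf

    # -1e9 is outside the float16 range (max magnitude ~65504) and becomes -inf.
    print(tf.cast(tf.constant(-1e9), tf.float16))    # -inf in float16
    # Halving float16.min leaves headroom: the sum of two masked terms is
    # still finite instead of overflowing.
    half_min = tf.cast(tf.float16.min / 2.0, tf.float16)
    print(half_min)                                  # -32752.0
    print(half_min + half_min)                       # -65504.0, still finite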
@@ -30,18 +58,34 @@ def __init__(self, src_block_size=None, tgt_block_size=None, **kwargs):
       tgt_block_size: The block size of the key/value. An integer that divides
         the sequence length into blocks. The number of blocks in the source and
         target must be the same.
+      use_sigmoid_attn: If enabled, uses sigmoid instead of softmax to compute
+        attn probs. https://arxiv.org/pdf/2409.04431
+      sigmoid_attn_bias: Bias for sigmoid attn. Suggested value -ln(seq_len).
       **kwargs: Args passed to the base class.
     """
     super().__init__(**kwargs)
     if src_block_size is None or src_block_size <= 0:
       raise ValueError("src_block_size must be specified.")
     self._src_block_size = src_block_size
     self._tgt_block_size = tgt_block_size or self._src_block_size
+    self._use_sigmoid_attn = use_sigmoid_attn
+    self._sigmoid_attn_bias = sigmoid_attn_bias
+    if self._use_sigmoid_attn:
+      if self._sigmoid_attn_bias is None:
+        raise ValueError(
+            "sigmoid_attn_bias must be specified for sigmoid attn."
+        )
 
   def _build_from_signature(self, query, value, key=None):
     # pytype: disable=attribute-error
     super()._build_from_signature(query, value, key)
     # pytype: enable=attribute-error
+    # If block sizes are same as sequence lengths, we defer to default attn.
+    if (
+        self._query_shape[-2] == self._src_block_size
+        and self._key_shape[-2] == self._tgt_block_size
+    ):
+      return
     # The following capital letters are used to denote the tensor dimension
     # parameters:
     # B = batch size
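
For orientation (not part of the commit), a minimal construction sketch showing how the new arguments fit together. The head count, dimensions, and sequence length below are made up for illustration, and the bias follows the docstring's suggested -ln(seq_len):

    import math

    import tensorflow as tf
    from official.nlp.modeling.layers import block_sparse_attention

    seq_len = 128  # hypothetical sequence length for this example
    layer = block_sparse_attention.MultiHeadAttention(
        num_heads=8,
        key_dim=64,
        src_block_size=16,   # 128 / 16 = 8 query blocks
        tgt_block_size=16,   # 128 / 16 = 8 key/value blocks (counts must match)
        use_sigmoid_attn=True,
        sigmoid_attn_bias=-math.log(seq_len),  # suggested value from the docstring
    )
    x = tf.random.normal([2, seq_len, 512])
    y = layer(query=x, value=x)   # self-attention; y.shape == (2, 128, 512)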
@@ -127,11 +171,38 @@ def _masked_softmax(self, attention_scores, attention_mask=None):
     if attention_mask is not None:
       # `attention_mask` = [B, 1, L, T, S]
       attention_mask = tf.expand_dims(attention_mask, axis=1)
-    return self._softmax(attention_scores, attention_mask)
+    if self._use_sigmoid_attn:
+      if attention_mask is not None:
+        adder = (1.0 - tf.cast(attention_mask, attention_scores.dtype)) * (
+            _large_compatible_negative(attention_scores.dtype)
+        )
+        attention_scores += adder
+      attention_scores += self._sigmoid_attn_bias
+      return tf_keras.activations.sigmoid(attention_scores)
+    else:
+      return self._softmax(attention_scores, attention_mask)
 
   def _compute_attention(
       self, query, key, value, attention_mask=None, training=None
   ):
+    # If block sizes are same as sequence lengths, we defer to default attn.
+    if (
+        self._query_shape[-2] == self._src_block_size
+        and self._key_shape[-2] == self._tgt_block_size
+    ):
+      logging.info(
+          "Computing default attention as block sizes are equal to sequence"
+          " lengths."
+      )
+      # pytype: disable=attribute-error
+      return super()._compute_attention(
+          query,
+          key,
+          value,
+          attention_mask=attention_mask,
+          training=training,
+      )
+      # pytype: enable=attribute-error
     # src_num_blocks and tgt_num_blocks are the number of blocks in the source
     # and target. Care should be taken to ensure that the number of blocks in
     # the source and target are the same.
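
The sigmoid branch in `_masked_softmax` above keeps the masking-by-large-negative trick but swaps the normalizing softmax for an element-wise sigmoid plus a constant bias. As a standalone illustration of the same arithmetic (toy float32 scores, not the layer's internal [B, N, L, T, S] tensors, and not part of the commit):

    import math

    import tensorflow as tf

    scores = tf.constant([[1.2, -0.3, 0.5, 2.0]])   # toy attention scores
    mask = tf.constant([[1.0, 1.0, 0.0, 1.0]])      # 0 marks a masked position
    bias = -math.log(4.0)                           # -ln(seq_len) with seq_len = 4

    # Masked positions receive a large negative adder, so their sigmoid is ~0.
    adder = (1.0 - mask) * -1e9                     # float32, so -1e9 is fine here
    probs = tf.sigmoid(scores + adder + bias)
    print(probs)  # third entry ~0; the rest lie in (0, 1) and need not sum to 1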

official/nlp/modeling/layers/block_sparse_attention_test.py

Lines changed: 93 additions & 3 deletions
@@ -14,6 +14,8 @@
 
 """Tests for block sparse attention layer."""
 
+import math
+
 from absl.testing import parameterized
 import numpy as np
 import tensorflow as tf, tf_keras
@@ -53,12 +55,29 @@ def test_non_masked_self_attention(self):
     output = test_layer(query, query)
     self.assertEqual(output.shape.as_list(), [None, 40, 80])
 
-  @parameterized.named_parameters(("with_bias", True), ("no_bias", False))
-  def test_masked_attention(self, use_bias):
+  @parameterized.named_parameters(
+      ("with_bias", True),
+      ("no_bias", False),
+      ("with_sigmoid_attn", True, True),
+  )
+  def test_masked_attention(
+      self,
+      use_bias,
+      use_sigmoid_attn=False,
+  ):
     """Test with a mask tensor."""
+    if use_sigmoid_attn:
+      sigmoid_attn_bias = -math.log(2)
+    else:
+      sigmoid_attn_bias = None
     test_layer = block_sparse_attention.MultiHeadAttention(
-        num_heads=4, key_dim=2, use_bias=use_bias, src_block_size=2,
+        num_heads=4,
+        key_dim=2,
+        use_bias=use_bias,
+        src_block_size=2,
         tgt_block_size=1,
+        use_sigmoid_attn=use_sigmoid_attn,
+        sigmoid_attn_bias=sigmoid_attn_bias,
     )
     # Create a 3-dimensional input (the first dimension is implicit).
     batch_size = 3
@@ -112,6 +131,77 @@ def test_masked_attention(self, use_bias):
     self.assertLen(test_layer._query_dense.trainable_variables, 1)
     self.assertLen(test_layer._output_dense.trainable_variables, 1)
 
+  @parameterized.named_parameters(
+      ("default_with_softmax", False),
+      ("default_with_sigmoid", True),
+  )
+  def test_default_masked_attention(
+      self,
+      use_sigmoid_attn=False,
+  ):
+    """Test with a mask tensor."""
+    seq_len = 8
+    if use_sigmoid_attn:
+      sigmoid_attn_bias = -math.log(seq_len)
+    else:
+      sigmoid_attn_bias = None
+    test_layer = block_sparse_attention.MultiHeadAttention(
+        num_heads=4,
+        key_dim=2,
+        use_bias=True,
+        src_block_size=seq_len,
+        tgt_block_size=seq_len,
+        use_sigmoid_attn=use_sigmoid_attn,
+        sigmoid_attn_bias=sigmoid_attn_bias,
+    )
+    # Create a 3-dimensional input (the first dimension is implicit).
+    batch_size = 3
+    query = tf_keras.Input(shape=(seq_len, 8))
+    value = tf_keras.Input(shape=(seq_len, 8))
+    mask_tensor = tf_keras.Input(shape=(seq_len, seq_len))
+    output = test_layer(query=query, value=value, attention_mask=mask_tensor)
+
+    # Create a model containing the test layer.
+    model = tf_keras.Model([query, value, mask_tensor], output)
+
+    # Generate data for the input (non-mask) tensors.
+    from_data = 10 * np.random.random_sample((batch_size, seq_len, 8))
+    to_data = 10 * np.random.random_sample((batch_size, seq_len, 8))
+
+    # Invoke the data with a random set of mask data. This should mask at
+    # least one element.
+    mask_data = np.random.randint(2, size=(batch_size, seq_len, seq_len))
+    masked_output_data = model.predict([from_data, to_data, mask_data])
+
+    # Invoke the same data, but with a null mask (where no elements are
+    # masked).
+    null_mask_data = np.ones((batch_size, seq_len, seq_len))
+    unmasked_output_data = model.predict([from_data, to_data, null_mask_data])
+
+    # Because one data is masked and one is not, the outputs should not be
+    # the same.
+    self.assertNotAllClose(masked_output_data, unmasked_output_data)
+
+    # Tests the layer with three inputs: Q, K, V.
+    key = tf_keras.Input(shape=(seq_len, 8))
+    output = test_layer(
+        query, value=value, key=key, attention_mask=mask_tensor
+    )
+    model = tf_keras.Model([query, value, key, mask_tensor], output)
+
+    masked_output_data = model.predict(
+        [from_data, to_data, to_data, mask_data]
+    )
+    unmasked_output_data = model.predict(
+        [from_data, to_data, to_data, null_mask_data]
+    )
+    # Because one data is masked and one is not, the outputs should not be
+    # the same.
+    self.assertNotAllClose(masked_output_data, unmasked_output_data)
+
+    self.assertLen(test_layer._query_dense.trainable_variables, 2)
+    self.assertLen(test_layer._output_dense.trainable_variables, 2)
+
   def test_masked_attention_with_scores(self):
     """Test with a mask tensor."""
     test_layer = block_sparse_attention.MultiHeadAttention(
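
One way to read the -math.log(seq_len) bias used in test_default_masked_attention (and the docstring's "Suggested value -ln(seq_len)", following the sigmoid-attention paper it links): with all scores at zero, a sigmoid with bias -ln(n) gives each of the n keys weight 1/(n+1), close to the uniform 1/n a softmax would assign, so the two start from comparable attention mass. A quick check for the seq_len = 8 used in the test; this is an intuition aid, not part of the commit:

    import math

    n = 8                                  # seq_len in test_default_masked_attention
    softmax_uniform = 1.0 / n              # softmax weight over n equal scores
    sigmoid_with_bias = 1.0 / (1.0 + math.exp(math.log(n)))  # sigmoid(0 - ln(n))
    print(softmax_uniform)                 # 0.125
    print(sigmoid_with_bias)               # 0.111..., approaches 1/n as n grows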
