@@ -14,6 +14,8 @@
 
 """Tests for block sparse attention layer."""
 
+import math
+
 from absl.testing import parameterized
 import numpy as np
 import tensorflow as tf, tf_keras
@@ -53,12 +55,29 @@ def test_non_masked_self_attention(self):
     output = test_layer(query, query)
     self.assertEqual(output.shape.as_list(), [None, 40, 80])
 
-  @parameterized.named_parameters(("with_bias", True), ("no_bias", False))
-  def test_masked_attention(self, use_bias):
+  @parameterized.named_parameters(
+      ("with_bias", True),
+      ("no_bias", False),
+      ("with_sigmoid_attn", True, True),
+  )
+  def test_masked_attention(
+      self,
+      use_bias,
+      use_sigmoid_attn=False,
+  ):
     """Test with a mask tensor."""
+    if use_sigmoid_attn:
+      sigmoid_attn_bias = -math.log(2)
+    else:
+      sigmoid_attn_bias = None
     test_layer = block_sparse_attention.MultiHeadAttention(
-        num_heads=4, key_dim=2, use_bias=use_bias, src_block_size=2,
+        num_heads=4,
+        key_dim=2,
+        use_bias=use_bias,
+        src_block_size=2,
         tgt_block_size=1,
+        use_sigmoid_attn=use_sigmoid_attn,
+        sigmoid_attn_bias=sigmoid_attn_bias,
     )
     # Create a 3-dimensional input (the first dimension is implicit).
     batch_size = 3
@@ -112,6 +131,77 @@ def test_masked_attention(self, use_bias):
     self.assertLen(test_layer._query_dense.trainable_variables, 1)
     self.assertLen(test_layer._output_dense.trainable_variables, 1)
 
+  @parameterized.named_parameters(
+      ("default_with_softmax", False),
+      ("default_with_sigmoid", True),
+  )
+  def test_default_masked_attention(
+      self,
+      use_sigmoid_attn=False,
+  ):
+    """Test with a mask tensor."""
+    seq_len = 8
+    if use_sigmoid_attn:
+      sigmoid_attn_bias = -math.log(seq_len)
+    else:
+      sigmoid_attn_bias = None
+    test_layer = block_sparse_attention.MultiHeadAttention(
+        num_heads=4,
+        key_dim=2,
+        use_bias=True,
+        src_block_size=seq_len,
+        tgt_block_size=seq_len,
+        use_sigmoid_attn=use_sigmoid_attn,
+        sigmoid_attn_bias=sigmoid_attn_bias,
+    )
+    # Create a 3-dimensional input (the first dimension is implicit).
+    batch_size = 3
+    query = tf_keras.Input(shape=(seq_len, 8))
+    value = tf_keras.Input(shape=(seq_len, 8))
+    mask_tensor = tf_keras.Input(shape=(seq_len, seq_len))
+    output = test_layer(query=query, value=value, attention_mask=mask_tensor)
+
+    # Create a model containing the test layer.
+    model = tf_keras.Model([query, value, mask_tensor], output)
+
+    # Generate data for the input (non-mask) tensors.
+    from_data = 10 * np.random.random_sample((batch_size, seq_len, 8))
+    to_data = 10 * np.random.random_sample((batch_size, seq_len, 8))
+
+    # Invoke the data with a random set of mask data. This should mask at
+    # least one element.
+    mask_data = np.random.randint(2, size=(batch_size, seq_len, seq_len))
+    masked_output_data = model.predict([from_data, to_data, mask_data])
+
+    # Invoke the same data, but with a null mask (where no elements are
+    # masked).
+    null_mask_data = np.ones((batch_size, seq_len, seq_len))
+    unmasked_output_data = model.predict([from_data, to_data, null_mask_data])
+
+    # Because one data is masked and one is not, the outputs should not be
+    # the same.
+    self.assertNotAllClose(masked_output_data, unmasked_output_data)
+
+    # Tests the layer with three inputs: Q, K, V.
+    key = tf_keras.Input(shape=(seq_len, 8))
+    output = test_layer(
+        query, value=value, key=key, attention_mask=mask_tensor
+    )
+    model = tf_keras.Model([query, value, key, mask_tensor], output)
+
+    masked_output_data = model.predict(
+        [from_data, to_data, to_data, mask_data]
+    )
+    unmasked_output_data = model.predict(
+        [from_data, to_data, to_data, null_mask_data]
+    )
+    # Because one data is masked and one is not, the outputs should not be
+    # the same.
+    self.assertNotAllClose(masked_output_data, unmasked_output_data)
+
+    self.assertLen(test_layer._query_dense.trainable_variables, 2)
+    self.assertLen(test_layer._output_dense.trainable_variables, 2)
+
   def test_masked_attention_with_scores(self):
     """Test with a mask tensor."""
     test_layer = block_sparse_attention.MultiHeadAttention(
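Aside, not part of the diff above: the new tests pass sigmoid_attn_bias = -math.log(block_size). Assuming use_sigmoid_attn=True replaces the softmax with an element-wise sigmoid over the scaled scores plus this bias, each attention weight starts near 1/block_size, roughly matching a uniform softmax at initialization. A minimal NumPy sketch of that assumed behavior (the names below are illustrative, not the layer's API):

import math
import numpy as np

def sigmoid_attention(query, key, value, attn_bias):
  # Scaled dot-product scores, as in standard attention.
  scores = query @ key.T / math.sqrt(query.shape[-1])
  # Element-wise sigmoid with an additive bias instead of a softmax;
  # with attn_bias = -log(seq_len), each weight starts on the order of
  # 1 / seq_len (exactly 1 / (seq_len + 1) when the raw score is zero).
  weights = 1.0 / (1.0 + np.exp(-(scores + attn_bias)))
  return weights @ value

seq_len, key_dim = 8, 2
q = np.random.randn(seq_len, key_dim)
k = np.random.randn(seq_len, key_dim)
v = np.random.randn(seq_len, key_dim)
out = sigmoid_attention(q, k, v, attn_bias=-math.log(seq_len))  # shape (8, 2)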