# RFC: Multihead Attention and EinsumDense on Keras

| Status        | (Proposed / Accepted / Implemented / Obsolete)           |
| :------------ | :------------------------------------------------------- |
| **RFC #**     | [NNN](https://github.com/tensorflow/community/pull/NNN) (update when you have community PR #) |
| **Author(s)** | Hongkun Yu ([email protected])                            |
| **Sponsor**   | Francois Chollet ([email protected])                      |
| **Updated**   | 2020-06-16                                                |

## Objective

Introduce the MultiHeadAttention layer and EinsumDense layer to tf.keras.

## Motivation

MultiHeadAttention is very popular and has become standard for deep learning
libraries. We propose to contribute a flexible, well-defined implementation
inside Keras that absorbs common best practices from reference libraries.

## User Benefit

We can standardize the implementation of Transformer layers and adopt the best
practices. We offer a rich set of functionalities for different use cases,
e.g. different projection spaces, outputting multi-head attention scores for
analysis, etc. We also modularize computations to make the MultiHeadAttention
layer extensible to variants.

## Design Proposal

### Key Features

*   Returns multi-headed attention scores, which is commonly useful for
    attention visualization and analysis.
*   Supports query (Q), key (K), value (V) tensors as individual inputs and
    supports projecting Q, K, V to different dimensions.
*   Projects the final outputs to user-specified dimensions.
*   Uses tf.einsum to express high-dimensional computations and adopts the
    [tf.keras.layers.experimental.EinsumDense](https://www.tensorflow.org/api_docs/python/tf/keras/layers/experimental/EinsumDense)
    layer.
*   Supports high-dimensional attention when the target and source are 2D, 3D,
    etc.

### Code Examples

*   How to write a TransformerBlock for an encoder.

```python
class TransformerBlock(tf.keras.layers.Layer):

  def __init__(self, embed_dim, num_heads, ff_dim):
    super(TransformerBlock, self).__init__()
    # Keyword arguments follow the proposed interface below
    # (num_heads first, then the per-head key size).
    self.att = MultiHeadAttention(num_heads=num_heads, key_size=embed_dim)
    self.ffn = tf.keras.Sequential(
        [tf.keras.layers.Dense(ff_dim, activation="relu"),
         tf.keras.layers.Dense(embed_dim),]
    )
    self.layernorm1 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
    self.layernorm2 = tf.keras.layers.LayerNormalization(epsilon=1e-6)

  def call(self, inputs, attention_mask=None):
    attn_output = self.att([inputs, inputs], attention_mask=attention_mask)
    out1 = self.layernorm1(inputs + attn_output)
    ffn_output = self.ffn(out1)
    return self.layernorm2(out1 + ffn_output)
```

*   Use an attention mask to avoid performing attention on padding token
    indices.

```python
test_layer = TransformerBlock(
    embed_dim=2,
    num_heads=2,
    ff_dim=4)
query = np.array([[[0.1, 0.2], [0.0, 0.0]]])
mask = np.array([[[1, 0], [1, 0]]], dtype='bool')
output = test_layer(query, mask)
```

*   Inside a Transformer decoder, we often want to output the cross-attention
    scores to analyze how the target sequence attends to the source sequence.
    We are able to visualize the alignment according to the attention scores.

```python
test_layer = MultiHeadAttention(
    num_heads=2, key_size=2, return_attention_scores=True)
target = np.array([[[0.1, 0.2], [0.0, 0.0]]])
source = np.array([[[0.1, 0.2], [3.0, 1.0]]])
output, scores = test_layer([target, source])
scores = tf.math.reduce_sum(scores, axis=1)  # shape = (1, 2, 2)
```

*   Attention beyond sequences: taking 2D, 3D targets and sources.

```python
# Assumed layer configuration for this example: attention is applied over both
# target/source spatial axes (axes 1 and 2).
test_layer = MultiHeadAttention(num_heads=2, key_size=2, attention_axes=(1, 2))
query_shape = [2, 3, 4, 4]  # batch, target dim 1, target dim 2, embedding.
value_shape = [2, 3, 2, 4]  # batch, source dim 1, source dim 2, embedding.
mask_shape = [2, 3, 4, 3, 2]  # batch, target dims, source dims.
query = 10 * np.random.random_sample(query_shape)
value = 10 * np.random.random_sample(value_shape)
mask_data = np.random.randint(2, size=mask_shape).astype("bool")
output = test_layer([query, value], mask_data)
```

### Interface

```python
class MultiHeadAttention(tf.keras.layers.Layer):
  """MultiHeadAttention layer.

  This is an implementation of multi-headed attention based on "Attention
  Is All You Need". If `query`, `key`, `value` are the same, then
  this is self-attention. Each timestep in `query` attends to the
  corresponding sequence in `key`, and returns a fixed-width vector.

  This layer first projects `query`, `key` and `value`. These are
  (effectively) a list of tensors of length `num_attention_heads`, where the
  corresponding shapes are [batch_size, <query dimensions>, key_size],
  [batch_size, <key/value dimensions>, key_size],
  [batch_size, <key/value dimensions>, value_size].

  Then, the query and key tensors are dot-producted and scaled. These are
  softmaxed to obtain attention probabilities. The value tensors are then
  interpolated by these probabilities, then concatenated back to a single
  tensor.

  Finally, the result tensor, whose last dimension is value_size, can take a
  linear projection and be returned.

  Examples:

  Performs 1D cross-attention over two sequence inputs with an attention mask.
  Returns the additional attention weights over heads.

  >>> layer = MultiHeadAttention(num_heads=2, key_size=2,
  ...                            return_attention_scores=True)
  >>> target = tf.keras.Input(shape=[8, 16])
  >>> source = tf.keras.Input(shape=[4, 16])
  >>> mask_tensor = tf.keras.Input(shape=[8, 4])
  >>> output_tensor, weights = layer([target, source],
  ...                                attention_mask=mask_tensor)
  >>> print(output_tensor.shape), print(weights.shape)
  (None, 8, 16) (None, 2, 8, 4)

  Performs 2D self-attention over a 5D input tensor on axes 2 and 3.

  >>> layer = MultiHeadAttention(num_heads=2, key_size=2, attention_axes=(2, 3))
  >>> input_tensor = tf.keras.Input(shape=[5, 3, 4, 16])
  >>> output_tensor = layer([input_tensor, input_tensor])
  >>> print(output_tensor.shape)
  (None, 5, 3, 4, 16)

  Arguments:
    num_heads: Number of attention heads.
    key_size: Size of each attention head for query and key.
    value_size: Size of each attention head for value.
    dropout: Dropout probability for a Dropout layer on attention_scores.
    use_bias: Boolean, whether the dense layers use bias vectors/matrices.
    output_shape: The expected shape of an output tensor, besides the batch
      and sequence dims. If not specified, projects back to the key feature
      dim.
    attention_axes: Axes over which the attention is applied. `None` means
      attention over all axes, but batch, heads, and features.
    return_attention_scores: bool, if `True`, returns the multi-head attention
      scores as an additional output argument.
    kernel_initializer: Initializer for dense layer kernels.
    bias_initializer: Initializer for dense layer biases.
    kernel_regularizer: Regularizer for dense layer kernels.
    bias_regularizer: Regularizer for dense layer biases.
    activity_regularizer: Regularizer for dense layer activity.
    kernel_constraint: Constraint for dense layer kernels.
    bias_constraint: Constraint for dense layer biases.
  """

  def call(self, inputs, attention_mask=None):
    """Implements the forward pass.

    Size glossary:
      * Number of heads (H): the number of attention heads.
      * Value size (V): the size of each value embedding per head.
      * Key size (K): the size of each key embedding per head. Equally, the
        size of each query embedding per head. Typically K <= V.
      * Batch dimensions (B).
      * Query (target) attention axes shape (T).
      * Value (source) attention axes shape (S), whose rank must match the
        target's.

    Args:
      inputs: List of the following tensors:
        * query: Query `Tensor` of shape `[B, T, dim]`.
        * value: Value `Tensor` of shape `[B, S, dim]`.
        * key: Optional key `Tensor` of shape `[B, S, dim]`. If not given,
          will use `value` for both `key` and `value`, which is the most
          common case.
      attention_mask: A boolean mask of shape `[B, T, S]`, that prevents
        attention to certain positions.

    Returns:
      attention_output: The result of the computation, of shape [B, T, E],
        where `T` is for target sequence shapes and `E` is the query input
        last dimension if `output_shape` is `None`. Otherwise, the multi-head
        outputs are projected to the shape specified by `output_shape`.
      attention_scores: [Optional] multi-head attention coefficients over
        attention axes.
    """
```

### Auxiliary Layers and Changes

*   EinsumDense layer

We use `tf.einsum` to implement a dense layer that can perform einsum
calculations of arbitrary dimensionality. This example shows how to instantiate
a layer that applies the same dense operation to every element in a sequence.
Here, the `output_shape` has two values (since there are two non-batch
dimensions in the output); the first dimension in the `output_shape` is `None`,
because the sequence dimension `b` has an unknown shape.

```python
layer = EinsumDense("abc,cd->abd", output_shape=(None, 64), bias_axes="d")
input_tensor = tf.keras.Input(shape=[32, 128])
output_tensor = layer(input_tensor)  # output shape is (None, 32, 64)
```

*   Masked Softmax

Inside the attention computation, we need to mask logits before the softmax,
which has become a common treatment in many applications. We propose to add an
optional `mask` argument to `tf.nn.softmax`. The downstream Keras `Softmax`
layer will also take an optional `mask` tensor. This `mask` tensor should have
the same rank as the input tensor and should mask elements on the axes over
which the softmax is performed.

Inside the `MultiHeadAttention` Keras layer, we will use the Keras `Softmax`
layer with the mask and adjust the attention mask shape to match the inputs.
The dimension expansion logic and multi-axis softmax will be handled locally in
the `MultiHeadAttention` layer.
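
As a reference for the intended semantics, the sketch below shows the common
additive-bias formulation of a masked softmax that the proposed `mask` argument
would encapsulate. The helper name `masked_softmax` and the `-1e9` bias
constant are illustrative assumptions, not part of the proposed API.

```python
import numpy as np
import tensorflow as tf


def masked_softmax(logits, mask, axis=-1):
  """Illustrative masked softmax: mask out positions before normalizing.

  `mask` is boolean with the same rank as `logits`; True marks positions that
  may be attended to. Masked positions receive a large negative additive bias
  so that their softmax probability is (numerically) zero.
  """
  adder = (1.0 - tf.cast(mask, logits.dtype)) * -1e9
  return tf.nn.softmax(logits + adder, axis=axis)


# Example: a [batch, target, source] score tensor with the last source
# position masked out for every target position.
scores = tf.constant(np.random.rand(1, 2, 3), dtype=tf.float32)
mask = tf.constant([[[True, True, False],
                     [True, True, False]]])
probs = masked_softmax(scores, mask)  # masked column gets ~0 probability
```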

*   Keras Dense Attention

The [tf.keras.layers.Attention](https://www.tensorflow.org/api_docs/python/tf/keras/layers/Attention)
layer's call method takes an optional argument, `mask`, which requires two
tensors, `q_mask` and `v_mask`. They follow Keras framework requirements, with
(batch_size, target_length) and (batch_size, source_length) as shapes. This
limits the flexibility of masking, whereas the `MultiHeadAttention` layer
generalizes the attention mask to (batch dims, target dims, source dims). To be
consistent, we would like to introduce an optional argument `attention_mask`
for `tf.keras.layers.Attention`. In the reduced case of
`tf.keras.layers.Attention`, the shape is (batch_size, target_length,
source_length). Whenever `attention_mask` is specified, the `mask` argument may
be omitted. The sketch below illustrates how the two masking schemes relate.
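
For intuition, here is a minimal sketch of how the existing per-sequence masks
relate to the proposed combined `attention_mask`. The broadcasting shown is an
assumption about the intended semantics rather than a committed API.

```python
import tensorflow as tf

# Existing tf.keras.layers.Attention masks: one per input sequence.
q_mask = tf.constant([[True, True, False]])        # (batch_size, target_length)
v_mask = tf.constant([[True, False, True, True]])  # (batch_size, source_length)

# The combined mask has shape (batch_size, target_length, source_length):
# position (t, s) is attendable only if both the target position t and the
# source position s are valid.
attention_mask = tf.logical_and(q_mask[:, :, tf.newaxis],
                                v_mask[:, tf.newaxis, :])
print(attention_mask.shape)  # (1, 3, 4)
```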

### Alternatives Considered

We examined the multi-head attention layers implemented in various libraries.
There are a few features that we do not include inside this Keras layer, and we
feel it is better to subclass the `MultiHeadAttention` layer to fulfill those
needs.

*   Attention caching for decoding. Implemented in
    [Flax](https://github.com/google/flax/blob/master/flax/nn/attention.py#L301).
    The caching is a special treatment for inference, and we noticed that
    different treatments are required for dynamic- or static-shape programs.
    Thus, subclassing as a
    [CachedAttention](https://github.com/tensorflow/models/blob/master/official/nlp/modeling/layers/attention.py)
    layer is the solution inside the Model Garden.
*   A [MultiHeadAttention](https://github.com/tensorflow/addons/blob/master/tensorflow_addons/layers/multihead_attention.py)
    Keras layer is also implemented in TF-Addons. The design in this doc covers
    the features in the TF-Addons implementation but generalizes to more use
    cases.

### Performance Implications

*   We will add microbenchmarks following the common practices of Keras layers.
*   We have end-to-end integration/regression tests for models using this
    layer, e.g. BERT.

### Dependencies

No dependencies.

### Engineering Impact

*   The Keras layer can be tested inside the package.
*   The TensorFlow team will maintain the code.

### Platforms and Environments

*   Works on all platforms and environments.

### Best Practices

*   No change to TensorFlow best practices.

### Tutorials and Examples

*   Code examples can be found inside the TensorFlow Model Garden. For example,
    an encoder
    [Transformer](https://github.com/tensorflow/models/blob/master/official/nlp/modeling/layers/transformer.py).

*   A 2D attention example is in the
    [unit test](https://github.com/tensorflow/models/blob/master/official/nlp/modeling/layers/attention_test.py#L135).

### Compatibility

*   This is a new layer without compatibility concerns.
*   The proposal works with TFLite, distribution strategy, tf.function, and
    GPU/TPU, and is serializable to SavedModel. These are tested inside
    TensorFlow Model Garden applications.

### User Impact

*   We will first introduce the layer as
    `tf.keras.layers.experimental.MultiHeadAttention` and
    `tf.keras.layers.experimental.MaskedSoftmax`.

## Detailed Design

The layer has been implemented as
[MultiHeadAttention](https://github.com/tensorflow/models/blob/master/official/nlp/modeling/layers/attention.py#L116)
inside the TensorFlow Model Garden.

First, as we rely on `tf.einsum` to define the projections and the attention
computation, we need to figure out the einsum notation of each computation.
Furthermore, to make the layer generalize to high-dimensional cases, i.e. where
there is more than one batch dimension and the attention softmax can be
performed over multiple axes, we need to track the batch axes and attention
axes inside the einsum notations. We use a vector of chars and two local
methods to generate the einsum notations for projections and attention, as
illustrated below.
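
For the common 1D case (single batch dimension, single attention axis), the
generated notations might look like the following. The equation strings here
are illustrative of the dot-product attention pattern under these assumed
shapes, not quoted from the implementation.

```python
import numpy as np
import tensorflow as tf

B, T, S, D = 2, 8, 4, 16   # batch, target length, source length, model dim
N, K = 2, 32               # num_heads, key_size

# Per-head projections expressed as einsum: [B, T, D] x [D, N, K] -> [B, T, N, K].
query_kernel = np.random.rand(D, N, K).astype("float32")
key_kernel = np.random.rand(D, N, K).astype("float32")
query = tf.einsum("abc,cde->abde",
                  np.random.rand(B, T, D).astype("float32"), query_kernel)
key = tf.einsum("abc,cde->abde",
                np.random.rand(B, S, D).astype("float32"), key_kernel)

# Scaled dot-product scores over the attention axes:
# [B, S, N, K] x [B, T, N, K] -> [B, N, T, S].
scores = tf.einsum("aecd,abcd->acbe", key, query) / (K ** 0.5)
print(scores.shape)  # (2, 2, 8, 4)
```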

Second, the layer by default implements the most common dot-product attention.
There are various ways to implement the attention computation, so we modularize
it into two methods, `_build_attention` and `_compute_attention`. Thus, users
may be able to just override them to get a new Keras layer with a novel
attention method, as sketched below. For example, we implemented
[TalkingHeadAttention](https://github.com/tensorflow/models/blob/master/official/nlp/modeling/layers/talking_heads_attention.py)
introduced by the ["Talking-Heads Attention"](https://arxiv.org/abs/2003.02436)
paper. As another example, since the Keras Attention layer supports the basic
single-head 1-D attention case, we can use it inside `_build_attention` and
`_compute_attention`.
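
A hypothetical subclass might look like the sketch below. The hook signatures
shown (a `rank` argument for `_build_attention`, and
`(query, key, value, attention_mask)` for `_compute_attention`) are assumptions
made for illustration; the actual names and arguments are defined by the Model
Garden implementation linked above.

```python
class MyAttention(MultiHeadAttention):
  """Sketch of a variant attention layer.

  Only the attention computation changes; the projections, masking plumbing,
  and output handling are inherited from the base layer.
  """

  def _build_attention(self, rank):
    # Create any extra weights or sub-layers the variant needs (e.g. the
    # per-head mixing kernels used by Talking-Heads Attention), then fall
    # back to the base setup.
    super()._build_attention(rank)

  def _compute_attention(self, query, key, value, attention_mask=None):
    # Replace the default scaled dot-product + masked softmax computation
    # with a custom one, returning (attention_output, attention_scores) so
    # that `return_attention_scores=True` keeps working.
    return super()._compute_attention(query, key, value, attention_mask)
```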

## Questions and Discussion Topics

-   cuDNN has the
    [multi-head attention](https://docs.nvidia.com/deeplearning/sdk/cudnn-api/index.html#cudnnMultiHeadAttnForward)
    function. How do we incorporate it? A: we modularize the attention
    computation components in order to support new low-level functions without
    changing this layer interface. The cuDNN function supports the classic
    dot-product attention with classic input dimensions. We will be able to
    use it once TensorFlow adds an op that uses it.
