
Commit b3daa1d: Update 20200616-keras-multihead-attention.md
1 parent 33b20f1

rfcs/20200616-keras-multihead-attention.md

Lines changed: 29 additions & 30 deletions
@@ -3,7 +3,7 @@
 | Status        | (Proposed / Accepted / Implemented / Obsolete)          |
 | :------------ | :------------------------------------------------------ |
 | **RFC #**     | [260](https://github.com/tensorflow/community/pull/260)  |
-| **Author(s)** | Hongkun Yu ([email protected]), Mark Omernick ([email protected]) |
+| **Author(s)** | Hongkun Yu ([email protected]), Mark Omernick ([email protected]) |
 | **Sponsor**   | Francois Chollet ([email protected])                     |
 | **Updated**   | 2020-06-16                                               |

@@ -83,7 +83,7 @@ test_layer = MultiHeadAttention(
     num_heads=2, key_size=2, return_attention_scores=True)
 target = np.array([[[0.1, 0.2], [0.0, 0.0]]])
 source = np.array([[[0.1, 0.2], [3.0, 1.0]]])
-output, scores = test_layer([target, source])
+output, scores = test_layer(query=target, value=source)
 scores = tf.math.reduce_sum(scores, axis=1)  # shape = (1, 2, 2)
 ```
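
For context, a minimal self-contained version of the updated snippet could look as follows. The import path and the `key_size`/`return_attention_scores` arguments are assumptions taken from this RFC's proposed interface, not from a released TensorFlow API:

```python
import numpy as np
import tensorflow as tf

# Assumption: the proposed layer is importable; the RFC suggests it will first
# land as tf.keras.layers.experimental.MultiHeadAttention.
from tensorflow.keras.layers.experimental import MultiHeadAttention

test_layer = MultiHeadAttention(
    num_heads=2, key_size=2, return_attention_scores=True)
target = np.array([[[0.1, 0.2], [0.0, 0.0]]], dtype="float32")
source = np.array([[[0.1, 0.2], [3.0, 1.0]]], dtype="float32")

# New calling convention: explicit keyword arguments instead of a list input.
output, scores = test_layer(query=target, value=source)
# scores has shape (batch, num_heads, target_len, source_len) = (1, 2, 2, 2);
# summing over the heads axis gives the (1, 2, 2) tensor from the RFC snippet.
scores = tf.math.reduce_sum(scores, axis=1)
```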

@@ -96,7 +96,7 @@ mask_shape = [2, 3, 4, 3, 2]
 query = 10 * np.random.random_sample(query_shape)
 value = 10 * np.random.random_sample(value_shape)
 mask_data = np.random.randint(2, size=mask_shape).astype("bool")
-output = test_layer([query, value], mask_data)
+output = test_layer(query=query, value=value, attention_mask=mask_data)
 ```

 ### Interface
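
As a hedged illustration of the `attention_mask` keyword in a simpler 3D setting (the shapes below are illustrative and are not the RFC's `query_shape`/`value_shape` example):

```python
import numpy as np
from tensorflow.keras.layers.experimental import MultiHeadAttention  # assumed path, see above

test_layer = MultiHeadAttention(num_heads=2, key_size=8)

# batch=2, target_len=4, source_len=6, feature dim=8.
query = 10 * np.random.random_sample([2, 4, 8]).astype("float32")
value = 10 * np.random.random_sample([2, 6, 8]).astype("float32")
# Boolean mask of shape [batch, target_len, source_len]; True marks the source
# positions each target position is allowed to attend to.
mask_data = np.random.randint(2, size=[2, 4, 6]).astype("bool")

output = test_layer(query=query, value=value, attention_mask=mask_data)
print(output.shape)  # (2, 4, 8)
```
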
@@ -134,15 +134,16 @@ class MultiHeadAttention(tf.keras.layers.Layer):
   >>> target = tf.keras.Input(shape=[8, 16])
   >>> source = tf.keras.Input(shape=[4, 16])
   >>> mask_tensor = tf.keras.Input(shape=[8, 4])
-  >>> output_tensor, weights = layer([target, source], attention_mask=mask_tensor)
+  >>> output_tensor, weights = layer(query=target, value=source,
+  ...                                attention_mask=mask_tensor)
   >>> print(output_tensor.shape), print(weights.shape)
   (None, 8, 16) (None, 2, 8, 4)

   Performs 2D self-attention over a 5D input tensor on axes 2 and 3.

   >>> layer = MultiHeadAttention(num_heads=2, key_size=2, attention_axes=(2, 3))
   >>> input_tensor = tf.keras.Input(shape=[5, 3, 4, 16])
-  >>> output_tensor = layer([input_tensor, input_tensor])
+  >>> output_tensor = layer(query=input_tensor, value=input_tensor)
   >>> print(output_tensor.shape)
   (None, 5, 3, 4, 16)
@@ -167,7 +168,7 @@ class MultiHeadAttention(tf.keras.layers.Layer):
     bias_constraint: Constraint for dense layer kernels.
   """

-  def call(self, inputs, attention_mask=None):
+  def call(self, query, value, key=None, attention_mask=None):
     """Implements the forward pass.

     Size glossary:
@@ -180,10 +181,9 @@ class MultiHeadAttention(tf.keras.layers.Layer):
       * Value (source) attention axes shape (S), the rank must match the target.

     Args:
-      inputs: List of the following tensors:
-        * query: Query `Tensor` of shape `[B, T, dim]`.
-        * value: Value `Tensor` of shape `[B, S, dim]`.
-        * key: Optional key `Tensor` of shape `[B, S, dim]`. If not given, will
+      query: Query `Tensor` of shape `[B, T, dim]`.
+      value: Value `Tensor` of shape `[B, S, dim]`.
+      key: Optional key `Tensor` of shape `[B, S, dim]`. If not given, will
         use `value` for both `key` and `value`, which is the most common case.
       attention_mask: a boolean mask of shape `[B, T, S]`, that prevents
         attention to certain positions.
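
To make the new `call` signature concrete, here is a hedged usage sketch; the import path and the `key_size` argument follow this RFC's proposed interface and are assumptions, not a released API:

```python
import numpy as np
from tensorflow.keras.layers.experimental import MultiHeadAttention  # assumed path

B, T, S, DIM = 2, 4, 6, 8
query = np.random.random_sample([B, T, DIM]).astype("float32")
value = np.random.random_sample([B, S, DIM]).astype("float32")
key = np.random.random_sample([B, S, DIM]).astype("float32")

layer = MultiHeadAttention(num_heads=2, key_size=4)

# Most common case: `key` defaults to `value`.
out = layer(query=query, value=value)

# Explicit `key`, for setups where keys and values differ.
out_with_key = layer(query=query, value=value, key=key)

# Optional boolean `attention_mask` of shape [B, T, S].
mask = np.ones([B, T, S], dtype="bool")
out_masked = layer(query=query, value=value, attention_mask=mask)

print(out.shape, out_with_key.shape, out_masked.shape)  # each (2, 4, 8)
```
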
@@ -242,14 +242,15 @@ we would like to introduce an optional argument `attention_mask` for
 the shape is (batch_size, target_length, source_length). Whenever
 `attention_mask` is specified, the `mask` argument is OK to be skipped.

-*   TFA `MultiHeadAttention` Deprecation and Re-mapping
-
-    [MultiHeadAttention](https://github.com/tensorflow/addons/blob/master/tensorflow_addons/layers/multihead_attention.py) has been released. The proposed `MultiHeadAttention` has similar `__init__` arguments
-    and `call` interface, where the minor differences are argument names and the attention `mask` shape.
-    We expect the new `MultiHeadAttention` keras layer will
-    cover the functionalities. Once the implementation are merged as experimental layers,
-    we will work with TF Addons team to design the deprecation and re-mapping procedure.
+*   TFA `MultiHeadAttention` Deprecation and Re-mapping

+    [MultiHeadAttention](https://github.com/tensorflow/addons/blob/master/tensorflow_addons/layers/multihead_attention.py)
+    has been released. The proposed `MultiHeadAttention` has similar `__init__`
+    arguments and `call` interface, where the minor differences are argument names
+    and the attention `mask` shape. We expect the new `MultiHeadAttention` keras
+    layer will cover the functionalities. Once the implementation are merged as
+    experimental layers, we will work with TF Addons team to design the deprecation
+    and re-mapping procedure.

 ### Alternatives Considered

@@ -307,33 +308,31 @@ No dependencies.
   and serializable to SavedModel. These are tested inside TensorFlow Model
   Garden applications.

-### User Impact
+### User Impact

 *   We will first introduce the layer as
     `tf.keras.layers.experimental.MultiHeadAttention` and
     `tf.keras.layers.experimental.EinsumDense`. When the APIs are stable and
-    functionalities are fully verified, the next step is to
-    graduate as core keras layers by removing `experimental` scope.
-
-
+    functionalities are fully verified, the next step is to graduate as core
+    keras layers by removing `experimental` scope.

 ## Detailed Design

 The layer has been implemented as the
 [MultiHeadAttention](https://github.com/tensorflow/models/blob/master/official/nlp/modeling/layers/attention.py#L116)
 inside TensorFlow Model Garden.

-First, as we rely on `tf.eisum` to define projections and attention computation,
-we need to figure out the einsum notation of each computation. Furthermore, to
-make the layer generalize to high-dimension cases, i.e. there are more than one
-batch dimensions and attention softmax can be performed on multiple axes, we
-need to track the batch axes and attention axes inside einsum notations. We use
-a vector of chars and use two local methods to generate einsum notations for
-projections and attentions.
+First, as we rely on `tf.einsum` to define projections and attention
+computation, we need to figure out the einsum notation of each computation.
+Furthermore, to make the layer generalize to high-dimension cases, i.e. there
+are more than one batch dimensions and attention softmax can be performed on
+multiple axes, we need to track the batch axes and attention axes inside einsum
+notations. We use a vector of chars and use two local methods to generate einsum
+notations for projections and attentions.

 Second, the layer by default implements the most common dot-product attention.
 There are various ways to implement the attention computation, so we modulize it
-as two methods `_build_attention` and `_compute_attention`. Thus, users may be
+as two methods `build_attention` and `compute_attention`. Thus, users will be
 able to just override them to get a new keras layer with a novel attention
 method. For example, we implemented
 [TalkingHeadAttention](https://github.com/tensorflow/models/blob/master/official/nlp/modeling/layers/talking_heads_attention.py)
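
To make the einsum bookkeeping concrete, below is a small standalone sketch of how score and combination equations could be assembled from tracked batch and attention axes. It only illustrates the idea; the axis layout `[batch axes..., attention axes..., num_heads, head_size]` and the helper name are assumptions, not the Model Garden code:

```python
import string

def attention_equations(num_batch_axes, num_attention_axes):
    """Builds einsum equations for multi-axis dot-product attention.

    Projected tensors are assumed to be laid out as
    [batch axes..., attention axes..., num_heads, head_size].
    Returns (scores_equation, combine_equation).
    """
    letters = iter(string.ascii_lowercase)
    batch = "".join(next(letters) for _ in range(num_batch_axes))
    target = "".join(next(letters) for _ in range(num_attention_axes))
    source = "".join(next(letters) for _ in range(num_attention_axes))
    heads = next(letters)
    depth = next(letters)

    query = batch + target + heads + depth
    key = batch + source + heads + depth
    value = batch + source + heads + depth
    scores = batch + heads + target + source
    output = batch + target + heads + depth

    # Dot product over head_size, then a weighted sum over the source axes.
    return f"{query},{key}->{scores}", f"{scores},{value}->{output}"

# One batch axis and one attention axis (the standard Transformer case):
print(attention_equations(1, 1))  # ('abde,acde->adbc', 'adbc,acde->abde')
# One batch axis and two attention axes (e.g. attention over a 2D grid):
print(attention_equations(1, 2))  # ('abcfg,adefg->afbcde', 'afbcde,adefg->abcfg')
```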
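
A hedged sketch of the override pattern described above follows: a subclass that swaps in a different attention computation. The hook name (`_compute_attention`, as in the Model Garden implementation linked earlier; the updated text calls it `compute_attention`), its signature, and the import path are assumptions for illustration, not a documented contract:

```python
import tensorflow as tf

# Assumption: the proposed layer is importable under this experimental path and
# exposes an overridable `_compute_attention(query, key, value, attention_mask)`
# hook returning (attention_output, attention_scores), for the common case of a
# single attention axis.
from tensorflow.keras.layers.experimental import MultiHeadAttention


class CosineAttention(MultiHeadAttention):
    """Example of plugging in a novel attention method by overriding the hook."""

    def _compute_attention(self, query, key, value, attention_mask=None):
        # query/key/value are already projected to [B, T_or_S, num_heads, head_size].
        query = tf.math.l2_normalize(query, axis=-1)
        key = tf.math.l2_normalize(key, axis=-1)
        # Cosine-similarity logits instead of the scaled dot product.
        scores = tf.einsum("btnh,bsnh->bnts", query, key)
        if attention_mask is not None:
            # attention_mask: [B, T, S] booleans; add a large negative bias to
            # masked positions before the softmax.
            bias = (1.0 - tf.cast(attention_mask[:, tf.newaxis, :, :],
                                  scores.dtype)) * -1e9
            scores += bias
        weights = tf.nn.softmax(scores, axis=-1)
        output = tf.einsum("bnts,bsnh->btnh", weights, value)
        return output, weights
```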
