
Commit 9c209e1

add layer_additive_attention()
1 parent e85cc81 commit 9c209e1

File tree

5 files changed: +95 -6 lines changed


NAMESPACE

Lines changed: 1 addition & 0 deletions
@@ -362,6 +362,7 @@ export(layer_activation_softmax)
 export(layer_activation_thresholded_relu)
 export(layer_activity_regularization)
 export(layer_add)
+export(layer_additive_attention)
 export(layer_alpha_dropout)
 export(layer_attention)
 export(layer_average)

NEWS.md

Lines changed: 5 additions & 2 deletions
@@ -14,10 +14,13 @@
 To learn more, including how to make a custom cell layer, see the new vignette:
 "Working with RNNs".
 
+- New layers:
+  - `layer_additive_attention()`
+  - `layer_conv_lstm_1d()`
+  - `layer_conv_lstm_3d()`
+
 - `layer_lstm()` default value for `recurrent_activation` changed from `"hard_sigmoid"` to `"sigmoid"`.
 
-- New layers `layer_conv_lstm_1d()` and `layer_conv_lstm_3d()`.
-
 - `layer_cudnn_gru()` and `layer_cudnn_lstm()` are deprecated. `layer_gru()` and `layer_lstm()` will
 automatically use CuDNN if it is available.
 

R/layer-attention.R

Lines changed: 37 additions & 4 deletions
@@ -131,8 +131,41 @@ layer_multi_head_attention <- function(
   ))
 }
 
-# TODO: finish + document: https://www.tensorflow.org/api_docs/python/tf/keras/layers/AdditiveAttention
-layer_additive_attention <- function(object, use_scale=TRUE, ...) {
-  args <- capture_args(match.call())
-  create_layer(keras$layers$AdditiveAttention, object, args)
+
+#' Additive attention layer, a.k.a. Bahdanau-style attention
+#'
+#' @details
+#' Inputs are `query` tensor of shape `[batch_size, Tq, dim]`, `value` tensor of
+#' shape `[batch_size, Tv, dim]` and `key` tensor of shape
+#' `[batch_size, Tv, dim]`. The calculation follows the steps:
+#'
+#' 1. Reshape `query` and `key` into shapes `[batch_size, Tq, 1, dim]`
+#'    and `[batch_size, 1, Tv, dim]` respectively.
+#' 2. Calculate scores with shape `[batch_size, Tq, Tv]` as a non-linear
+#'    sum: `scores = tf.reduce_sum(tf.tanh(query + key), axis=-1)`
+#' 3. Use scores to calculate a distribution with shape
+#'    `[batch_size, Tq, Tv]`: `distribution = tf$nn$softmax(scores)`.
+#' 4. Use `distribution` to create a linear combination of `value` with
+#'    shape `[batch_size, Tq, dim]`:
+#'    `return tf$matmul(distribution, value)`.
+#'
+#' @param use_scale If `TRUE`, will create a variable to scale the attention scores.
+#'
+#' @param causal Boolean. Set to `TRUE` for decoder self-attention. Adds a mask such
+#'   that position `i` cannot attend to positions `j > i`. This prevents the
+#'   flow of information from the future towards the past.
+#'
+#' @param dropout Float between 0 and 1. Fraction of the units to drop for the
+#'   attention scores.
+#' @param ... standard layer arguments.
+#'
+#' @seealso
+#' + <https://www.tensorflow.org/api_docs/python/tf/keras/layers/AdditiveAttention>
+#' + <https://keras.io/api/layers/attention_layers/additive_attention/>
+#' @export
+layer_additive_attention <-
+function(object, use_scale = TRUE, ..., causal = FALSE, dropout = 0)
+{
+  args <- capture_args(match.call(), NULL, ignore = "object")
+  create_layer(keras$layers$AdditiveAttention, object, args)
 }
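The @details steps documented above translate directly into a few tensorflow operations. As an illustration only (not part of the commit), here is a sketch of the score computation using the tensorflow R package; all shapes are arbitrary toy values, and the learned `scale` variable used when `use_scale = TRUE` is omitted:

    library(tensorflow)

    # Toy shapes: batch_size = 2, Tq = 3 query steps, Tv = 4 value steps, dim = 5.
    query <- tf$random$normal(c(2L, 3L, 5L))
    value <- tf$random$normal(c(2L, 4L, 5L))
    key   <- value  # Keras uses `value` as the key when no key is supplied

    # 1. Reshape for broadcasting: query -> [2, 3, 1, 5], key -> [2, 1, 4, 5].
    q <- tf$expand_dims(query, 2L)
    k <- tf$expand_dims(key, 1L)

    # 2. Non-linear additive scores, shape [2, 3, 4].
    scores <- tf$reduce_sum(tf$tanh(q + k), axis = -1L)

    # 3. Attention distribution over the value steps.
    distribution <- tf$nn$softmax(scores)

    # 4. Weighted sum of the values, shape [2, 3, 5].
    output <- tf$matmul(distribution, value)

And a minimal usage sketch of the new layer itself, assuming (as with `layer_attention()`) that it can be called on a list of query and value tensors; the input names and dimensions here are hypothetical:

    library(keras)

    query_input <- layer_input(shape = c(8, 16))   # (Tq = 8, dim = 16)
    value_input <- layer_input(shape = c(12, 16))  # (Tv = 12, dim = 16)

    # Bahdanau-style attention over the value sequence for each query step;
    # output shape is (batch_size, 8, 16).
    attended <- layer_additive_attention(list(query_input, value_input))

    model <- keras_model(list(query_input, value_input), attended)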

man/layer_additive_attention.Rd

Lines changed: 51 additions & 0 deletions
Some generated files are not rendered by default.

pkgdown/_pkgdown.yml

Lines changed: 1 addition & 0 deletions
@@ -248,6 +248,7 @@ reference:
   contents:
   - layer_attention
   - layer_multi_head_attention
+  - layer_additive_attention
 
 - title: "Layer Wrappers"
   contents:
