@@ -131,8 +131,41 @@ layer_multi_head_attention <- function(
   ))
 }
 
-# TODO: finish + document: https://www.tensorflow.org/api_docs/python/tf/keras/layers/AdditiveAttention
-layer_additive_attention <- function(object, use_scale = TRUE, ...) {
-  args <- capture_args(match.call())
-  create_layer(keras$layers$AdditiveAttention, object, args)
+
+#' Additive attention layer, a.k.a. Bahdanau-style attention
+#'
+#' @details
+#' Inputs are a `query` tensor of shape `[batch_size, Tq, dim]`, a `value` tensor
+#' of shape `[batch_size, Tv, dim]` and a `key` tensor of shape
+#' `[batch_size, Tv, dim]`. The calculation follows the steps:
+#'
+#' 1. Reshape `query` and `key` into shapes `[batch_size, Tq, 1, dim]`
+#'    and `[batch_size, 1, Tv, dim]` respectively.
+#' 2. Calculate scores with shape `[batch_size, Tq, Tv]` as a non-linear
+#'    sum: `scores = tf$reduce_sum(tf$tanh(query + key), axis = -1L)`.
+#' 3. Use scores to calculate a distribution with shape
+#'    `[batch_size, Tq, Tv]`: `distribution = tf$nn$softmax(scores)`.
+#' 4. Use `distribution` to create a linear combination of `value` with
+#'    shape `[batch_size, Tq, dim]`:
+#'    `return(tf$matmul(distribution, value))`.
+#'
+#' @param use_scale If `TRUE`, will create a variable to scale the attention scores.
+#'
+#' @param causal Boolean. Set to `TRUE` for decoder self-attention. Adds a mask such
+#'   that position `i` cannot attend to positions `j > i`. This prevents the
+#'   flow of information from the future towards the past.
+#'
+#' @param dropout Float between 0 and 1. Fraction of the units to drop for the
+#'   attention scores.
+#' @param ... standard layer arguments.
+#'
+#' @seealso
+#' + <https://www.tensorflow.org/api_docs/python/tf/keras/layers/AdditiveAttention>
+#' + <https://keras.io/api/layers/attention_layers/additive_attention/>
+#' @export
+layer_additive_attention <-
+function(object, use_scale = TRUE, ..., causal = FALSE, dropout = 0)
+{
+  args <- capture_args(match.call(), NULL, ignore = "object")
+  create_layer(keras$layers$AdditiveAttention, object, args)
 }
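
To make the four documented steps concrete, here is a minimal sketch of the same computation written against the tensorflow R package. It is not part of the commit: the tensor sizes are made up, and the learned `scale` variable that `use_scale = TRUE` adds in step 2 is deliberately omitted.

```r
library(tensorflow)

batch_size <- 2L; Tq <- 4L; Tv <- 6L; dim <- 8L

query <- tf$random$normal(shape(batch_size, Tq, dim))
key   <- tf$random$normal(shape(batch_size, Tv, dim))
value <- tf$random$normal(shape(batch_size, Tv, dim))

# 1. Reshape for broadcasting: query -> [batch_size, Tq, 1, dim],
#    key -> [batch_size, 1, Tv, dim].
q <- tf$expand_dims(query, axis = 2L)
k <- tf$expand_dims(key, axis = 1L)

# 2. Non-linear additive scores, shape [batch_size, Tq, Tv].
#    (With use_scale = TRUE the layer also multiplies tanh(q + k) by a
#    learned scale vector; that variable is omitted in this sketch.)
scores <- tf$reduce_sum(tf$tanh(q + k), axis = -1L)

# 3. Softmax over the value timesteps, shape [batch_size, Tq, Tv].
distribution <- tf$nn$softmax(scores)

# 4. Linear combination of `value`, shape [batch_size, Tq, dim].
attended <- tf$matmul(distribution, value)
```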
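The new `causal` and `dropout` arguments are forwarded to `keras$layers$AdditiveAttention` through `capture_args()`. Below is a usage sketch of the updated signature, assuming the keras package is attached; the input shapes, the 64-feature dimension and the 0.1 dropout rate are purely illustrative.

```r
library(keras)

# Query and value sequences with an unspecified number of timesteps and
# 64 features each (shapes chosen only for illustration).
query_seq <- layer_input(shape = c(NA, 64))
value_seq <- layer_input(shape = c(NA, 64))

# Bahdanau-style attention over the value sequence; `dropout` is applied to
# the attention scores, and `causal = TRUE` would add the look-ahead mask.
attended <- layer_additive_attention(list(query_seq, value_seq),
                                     use_scale = TRUE, dropout = 0.1)

model <- keras_model(inputs = list(query_seq, value_seq), outputs = attended)
```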