Skip to content

Commit b1f375b

Browse files
authored
Merge pull request #1000 from atroiano/master
Adding Layer Attention
2 parents 63eddea + 41be25c commit b1f375b

16 files changed

+114
-0
lines changed

NAMESPACE

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -302,6 +302,7 @@ export(layer_activation_thresholded_relu)
302302
export(layer_activity_regularization)
303303
export(layer_add)
304304
export(layer_alpha_dropout)
305+
export(layer_attention)
305306
export(layer_average)
306307
export(layer_average_pooling_1d)
307308
export(layer_average_pooling_2d)

R/layer-attention.R

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
2+
#' Creates attention layer
3+
#'
4+
#' Dot-product attention layer, a.k.a. Luong-style attention.
5+
#'
6+
#' @inheritParams layer_dense
7+
#'
8+
#' @param inputs a list of inputs first should be the query tensor, the second the value tensor
9+
#' @param use_scale If True, will create a scalar variable to scale the attention scores.
10+
#' @param causal Boolean. Set to True for decoder self-attention. Adds a mask such that position i cannot attend to positions j > i.
11+
#' This prevents the flow of information from the future towards the past.
12+
#'
13+
#' @family core layers
14+
#' @family attention layers
15+
#'
16+
#' @export
17+
layer_attention <- function(inputs,use_scale=FALSE, causal = FALSE, batch_size = NULL, dtype = NULL,
18+
name = NULL, trainable = NULL, weights = NULL) {
19+
if (!is_tensorflow_implementation() || !tensorflow::tf_version() >= "1.14")
20+
stop("layer_dense_features requires TensorFlow implementation and version >= 1.14")
21+
create_layer(keras$layers$Attention, inputs, list(
22+
use_scale = use_scale,
23+
causal = causal,
24+
batch_size = batch_size,
25+
dtype = dtype,
26+
name = name,
27+
trainable = trainable,
28+
weights = weights)
29+
)
30+
31+
32+
}
33+
34+
35+

man/layer_activation.Rd

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/layer_activity_regularization.Rd

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/layer_attention.Rd

Lines changed: 58 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/layer_dense.Rd

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/layer_dense_features.Rd

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/layer_dropout.Rd

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/layer_flatten.Rd

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/layer_input.Rd

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)