
Commit b42a816

Added attention layer function
Updated test cases to test basic usage
1 parent 63eddea commit b42a816

2 files changed: 29 additions, 0 deletions

R/layers_attention.R

Lines changed: 23 additions & 0 deletions
@@ -0,0 +1,23 @@
+#' Dot-product attention layer, a.k.a. Luong-style attention.
+#'
+#' Computes attention scores as the dot product between the query and value
+#' tensors and returns a softmax-weighted combination of the value tensor.
+#'
+#' @inheritParams layer_dense
+#' @param inputs A list of inputs: the first element should be the query
+#'   tensor and the second the value tensor.
+#' @param use_scale If `TRUE`, will create a scalar variable to scale the
+#'   attention scores.
+#' @param causal Boolean. Set to `TRUE` for decoder self-attention. Adds a
+#'   mask such that position `i` cannot attend to positions `j > i`. This
+#'   prevents the flow of information from the future towards the past.
+#'
+#' @family core layers
+#' @family attention layers
+#'
+#' @export
+layer_attention <- function(inputs, use_scale = FALSE, causal = FALSE) {
+  create_layer(tf$keras$layers$Attention, object = inputs,
+               args = list(use_scale = use_scale, causal = causal))
+}
+
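
For context, a minimal sketch of how the new layer might be wired into a small functional model. The input shapes, the `use_scale` choice, and the model assembly are illustrative assumptions, not part of this commit:

library(keras)

# Query and value tensors with matching feature dimensions
# (the shapes here are illustrative assumptions).
query <- layer_input(shape = c(4, 5))
value <- layer_input(shape = c(4, 5))

# Attend over `value` using `query`; `use_scale = TRUE` adds a learned
# scalar that scales the attention scores, per the parameter docs above.
attended <- layer_attention(c(query, value), use_scale = TRUE)

model <- keras_model(inputs = list(query, value), outputs = attended)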

tests/testthat/test-layers.R

Lines changed: 6 additions & 0 deletions
@@ -498,6 +498,12 @@ test_call_succeeds("layer_separable_conv_1d", required_version = "2.1.3", {
   }
 })

+test_call_succeeds("layer_attention", {
+  input_1 <- layer_input(shape = c(4, 5))
+  input_2 <- layer_input(shape = c(4, 5))
+  layer_attention(c(input_1, input_2))
+})
+
 test_call_succeeds("layer_dense_features", required_version = "2.1.3", {
   if (is_tensorflow_implementation() && tensorflow::tf_version() >= "1.14") {
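
A hypothetical sketch of the `causal` flag for decoder self-attention, where the same tensor serves as both query and value (the shape is an assumption, not from this commit):

# With causal = TRUE, a mask keeps position i from attending to
# positions j > i, so no information flows from the future to the past.
x <- layer_input(shape = c(10, 8))
self_attended <- layer_attention(c(x, x), causal = TRUE)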
