Skip to content

Commit 1db416b

Browse files
committed
added check for TF implementation
updated test case
1 parent 7a8ce05 commit 1db416b

File tree

2 files changed

+8
-4
lines changed

2 files changed

+8
-4
lines changed

R/layer-attention.R

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616
#' @export
1717
layer_attention <- function(inputs,use_scale=FALSE, causal = FALSE, batch_size = NULL, dtype = NULL,
1818
name = NULL, trainable = NULL, weights = NULL) {
19+
if (!is_tensorflow_implementation() || !tensorflow::tf_version() >= "1.14")
20+
stop("layer_dense_features requires TensorFlow implementation and version >= 1.14")
1921
create_layer(keras$layers$Attention, inputs, list(
2022
use_scale = use_scale,
2123
causal = causal,

tests/testthat/test-layers.R

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -498,10 +498,12 @@ test_call_succeeds("layer_separable_conv_1d", required_version = "2.1.3", {
498498
}
499499
})
500500

501-
test_call_succeeds('layer_attention',{
502-
input_1 = layer_input(shape=c(4,5))
503-
input_2 = layer_input(shape=c(4,5))
504-
layer_attention(c(input_1,input_2))
501+
test_call_succeeds('layer_attention'{
502+
if (is_tensorflow_implementation() && tensorflow::tf_version() >= "1.14"){
503+
input_1 = layer_input(shape=c(4,5))
504+
input_2 = layer_input(shape=c(4,5))
505+
layer_attention(c(input_1,input_2))
506+
}
505507
})
506508

507509
test_call_succeeds("layer_dense_features", required_version = "2.1.3", {

0 commit comments

Comments
 (0)