Skip to content

Commit 4812dff

Browse files
author
Sigrid Keydana
committed
factor out is_keras_tensor due to 2.0
1 parent 2d87e68 commit 4812dff

File tree

1 file changed

+10
-3
lines changed

1 file changed

+10
-3
lines changed

R/utils.R

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -224,14 +224,13 @@ keras_array <- function(x, dtype = NULL) {
224224
if (
225225
tf_version() >= "1.12" &&
226226
(
227-
tensorflow::tf$contrib$framework$is_tensor(x) ||
228-
is.list(x) && all(vapply(x, tensorflow::tf$contrib$framework$is_tensor, logical(1)))
227+
is_keras_tensor(x) || is.list(x) && all(vapply(x, is_keras_tensor, logical(1)))
229228
)
230229
) {
231230
return(x)
232231
}
233232
} else {
234-
if ((keras_version() >= "2.2.0") && k_is_tensor(x)) {
233+
if ((keras_version() >= "2.2.0") && is_keras_tensor(x)) {
235234
return(x)
236235
}
237236
}
@@ -381,3 +380,11 @@ as_shape <- function(x) {
381380
as.integer(d)
382381
})
383382
}
383+
384+
is_keras_tensor <- function(x) {
385+
if (is_tensorflow_implementation()) {
386+
if (tensorflow::tf_version() >= "2.0") tensorflow::tf$is_tensor(x) else tensorflow::tf$contrib$framework$is_tensor(x)
387+
} else {
388+
k_is_tensor(x)
389+
}
390+
}

0 commit comments

Comments
 (0)