Skip to content

Commit fee2fa0

Browse files
Turgut AbdullayevTurgut Abdullayev
authored andcommitted
compatibility with tf2
1 parent 425db71 commit fee2fa0

File tree

1 file changed

+19
-21
lines changed

1 file changed

+19
-21
lines changed

vignettes/examples/eager_cvae.R

Lines changed: 19 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -5,17 +5,14 @@
55
#' https://blogs.rstudio.com/tensorflow/posts/2018-10-22-mmd-vae/
66

77

8+
89
library(keras)
9-
use_implementation("tensorflow")
1010
library(tensorflow)
11-
tfe_enable_eager_execution(device_policy = "silent")
12-
1311
library(tfdatasets)
1412
library(dplyr)
1513
library(ggplot2)
1614
library(glue)
1715

18-
1916
# Setup and preprocessing -------------------------------------------------
2017

2118
fashion <- dataset_fashion_mnist()
@@ -24,9 +21,9 @@ c(train_images, train_labels) %<-% fashion$train
2421
c(test_images, test_labels) %<-% fashion$test
2522

2623
train_x <-
27-
train_images %>% `/`(255) %>% k_reshape(c(60000, 28, 28, 1))
24+
train_images %>% `/`(255) %>% k_reshape(c(60000, 28, 28, 1)) %>% k_cast(dtype='float32')
2825
test_x <-
29-
test_images %>% `/`(255) %>% k_reshape(c(10000, 28, 28, 1))
26+
test_images %>% `/`(255) %>% k_reshape(c(10000, 28, 28, 1)) %>% k_cast(dtype='float32')
3027

3128
class_names = c('T-shirt/top',
3229
'Trouser',
@@ -63,7 +60,8 @@ encoder_model <- function(name = NULL) {
6360
filters = 32,
6461
kernel_size = 3,
6562
strides = 2,
66-
activation = "relu"
63+
activation = "relu",
64+
dtype='float32'
6765
)
6866
self$conv2 <-
6967
layer_conv_2d(
@@ -88,7 +86,7 @@ encoder_model <- function(name = NULL) {
8886

8987
decoder_model <- function(name = NULL) {
9088
keras_model_custom(name = name, function(self) {
91-
self$dense <- layer_dense(units = 7 * 7 * 32, activation = "relu")
89+
self$dense <- layer_dense(units = 7 * 7 * 32, activation = "relu",dtype='float32')
9290
self$reshape <- layer_reshape(target_shape = c(7, 7, 32))
9391
self$deconv1 <-
9492
layer_conv_2d_transpose(
@@ -126,22 +124,22 @@ decoder_model <- function(name = NULL) {
126124
}
127125

128126
reparameterize <- function(mean, logvar) {
129-
eps <- k_random_normal(shape = mean$shape, dtype = tf$float64)
127+
eps <- k_random_normal(shape = mean$shape, dtype = tf$float32)
130128
eps * k_exp(logvar * 0.5) + mean
131129
}
132130

133131

134132
# Loss and optimizer ------------------------------------------------------
135133

136134
normal_loglik <- function(sample, mean, logvar, reduce_axis = 2) {
137-
loglik <- k_constant(0.5, dtype = tf$float64) *
138-
(k_log(2 * k_constant(pi, dtype = tf$float64)) +
139-
logvar +
140-
k_exp(-logvar) * (sample - mean) ^ 2)
135+
loglik <- k_constant(0.5, dtype = tf$float32) *
136+
(k_log(2 * k_constant(pi, dtype = tf$float32)) +
137+
logvar +
138+
k_exp(-logvar) * (sample - mean) ^ 2)
141139
- k_sum(loglik, axis = reduce_axis)
142140
}
143141

144-
optimizer <- tf$train$AdamOptimizer(1e-4)
142+
optimizer <- tf$keras$optimizers$Adam(1e-4)
145143

146144

147145

@@ -151,7 +149,7 @@ num_examples_to_generate <- 64
151149

152150
random_vector_for_generation <-
153151
k_random_normal(shape = list(num_examples_to_generate, latent_dim),
154-
dtype = tf$float64)
152+
dtype = tf$float32)
155153

156154
generate_random_clothes <- function(epoch) {
157155
predictions <-
@@ -267,8 +265,8 @@ for (epoch in seq_len(num_epochs)) {
267265
-k_sum(crossentropy_loss)
268266
logpz <-
269267
normal_loglik(z,
270-
k_constant(0, dtype = tf$float64),
271-
k_constant(0, dtype = tf$float64))
268+
k_constant(0, dtype = tf$float32),
269+
k_constant(0, dtype = tf$float32))
272270
logqz_x <- normal_loglik(z, mean, logvar)
273271
loss <- -k_mean(logpx_z + logpz - logqz_x)
274272

@@ -284,12 +282,10 @@ for (epoch in seq_len(num_epochs)) {
284282

285283
optimizer$apply_gradients(purrr::transpose(list(
286284
encoder_gradients, encoder$variables
287-
)),
288-
global_step = tf$train$get_or_create_global_step())
285+
)))
289286
optimizer$apply_gradients(purrr::transpose(list(
290287
decoder_gradients, decoder$variables
291-
)),
292-
global_step = tf$train$get_or_create_global_step())
288+
)))
293289

294290
})
295291

@@ -313,3 +309,5 @@ for (epoch in seq_len(num_epochs)) {
313309
}
314310
}
315311

312+
313+

0 commit comments

Comments
 (0)