55# ' https://blogs.rstudio.com/tensorflow/posts/2018-10-22-mmd-vae/
66
77
8+
89library(keras )
9- use_implementation(" tensorflow" )
1010library(tensorflow )
11- tfe_enable_eager_execution(device_policy = " silent" )
12-
1311library(tfdatasets )
1412library(dplyr )
1513library(ggplot2 )
1614library(glue )
1715
18-
1916# Setup and preprocessing -------------------------------------------------
2017
2118fashion <- dataset_fashion_mnist()
@@ -24,9 +21,9 @@ c(train_images, train_labels) %<-% fashion$train
2421c(test_images , test_labels ) %<- % fashion $ test
2522
2623train_x <-
27- train_images %> % `/`(255 ) %> % k_reshape(c(60000 , 28 , 28 , 1 ))
24+ train_images %> % `/`(255 ) %> % k_reshape(c(60000 , 28 , 28 , 1 )) % > % k_cast( dtype = ' float32 ' )
2825test_x <-
29- test_images %> % `/`(255 ) %> % k_reshape(c(10000 , 28 , 28 , 1 ))
26+ test_images %> % `/`(255 ) %> % k_reshape(c(10000 , 28 , 28 , 1 )) % > % k_cast( dtype = ' float32 ' )
3027
3128class_names = c(' T-shirt/top' ,
3229 ' Trouser' ,
@@ -63,7 +60,8 @@ encoder_model <- function(name = NULL) {
6360 filters = 32 ,
6461 kernel_size = 3 ,
6562 strides = 2 ,
66- activation = " relu"
63+ activation = " relu" ,
64+ dtype = ' float32'
6765 )
6866 self $ conv2 <-
6967 layer_conv_2d(
@@ -126,22 +124,22 @@ decoder_model <- function(name = NULL) {
126124}
127125
128126reparameterize <- function (mean , logvar ) {
129- eps <- k_random_normal(shape = mean $ shape , dtype = tf $ float64 )
127+ eps <- k_random_normal(shape = mean $ shape , dtype = tf $ float32 )
130128 eps * k_exp(logvar * 0.5 ) + mean
131129}
132130
133131
134132# Loss and optimizer ------------------------------------------------------
135133
136134normal_loglik <- function (sample , mean , logvar , reduce_axis = 2 ) {
137- loglik <- k_constant(0.5 , dtype = tf $ float64 ) *
138- (k_log(2 * k_constant(pi , dtype = tf $ float64 )) +
139- logvar +
140- k_exp(- logvar ) * (sample - mean ) ^ 2 )
135+ loglik <- k_constant(0.5 ) *
136+ (k_log(2 * k_constant(pi )) +
137+ logvar +
138+ k_exp(- logvar ) * (sample - mean ) ^ 2 )
141139 - k_sum(loglik , axis = reduce_axis )
142140}
143141
144- optimizer <- tf $ train $ AdamOptimizer (1e-4 )
142+ optimizer <- tf $ keras $ optimizers $ Adam (1e-4 )
145143
146144
147145
@@ -151,7 +149,7 @@ num_examples_to_generate <- 64
151149
152150random_vector_for_generation <-
153151 k_random_normal(shape = list (num_examples_to_generate , latent_dim ),
154- dtype = tf $ float64 )
152+ dtype = tf $ float32 )
155153
156154generate_random_clothes <- function (epoch ) {
157155 predictions <-
@@ -216,7 +214,7 @@ show_grid <- function(epoch) {
216214 z_sample <- matrix (c(grid_x [i ], grid_y [j ]), ncol = 2 )
217215 column <-
218216 rbind(column ,
219- (decoder(z_sample ) %> % tf $ nn $ sigmoid() %> % as.numeric()) %> % matrix (ncol = img_size ))
217+ (decoder(k_cast( z_sample , ' float32 ' ) ) %> % tf $ nn $ sigmoid() %> % as.numeric()) %> % matrix (ncol = img_size ))
220218 }
221219 rows <- cbind(rows , column )
222220 }
@@ -267,8 +265,8 @@ for (epoch in seq_len(num_epochs)) {
267265 - k_sum(crossentropy_loss )
268266 logpz <-
269267 normal_loglik(z ,
270- k_constant(0 , dtype = tf $ float64 ),
271- k_constant(0 , dtype = tf $ float64 ))
268+ k_constant(0 , dtype = tf $ float32 ),
269+ k_constant(0 , dtype = tf $ float32 ))
272270 logqz_x <- normal_loglik(z , mean , logvar )
273271 loss <- - k_mean(logpx_z + logpz - logqz_x )
274272
@@ -284,12 +282,10 @@ for (epoch in seq_len(num_epochs)) {
284282
285283 optimizer $ apply_gradients(purrr :: transpose(list (
286284 encoder_gradients , encoder $ variables
287- )),
288- global_step = tf $ train $ get_or_create_global_step())
285+ )))
289286 optimizer $ apply_gradients(purrr :: transpose(list (
290287 decoder_gradients , decoder $ variables
291- )),
292- global_step = tf $ train $ get_or_create_global_step())
288+ )))
293289
294290 })
295291
@@ -313,3 +309,5 @@ for (epoch in seq_len(num_epochs)) {
313309 }
314310}
315311
312+
313+
0 commit comments