@@ -5,20 +5,19 @@
#' https://blogs.rstudio.com/tensorflow/posts/2019-01-24-vq-vae/

library(keras)
-use_implementation("tensorflow")
library(tensorflow)
-tfe_enable_eager_execution(device_policy = "silent")
-
-use_session_with_seed(7778,
-                      disable_gpu = FALSE,
-                      disable_parallel_cpu = FALSE)
-
-tfp <- import("tensorflow_probability")
-tfd <- tfp$distributions
-
+library(tfprobability)
library(tfdatasets)
+
library(dplyr)
library(glue)
+
+# curry has to be installed from GitHub because the CRAN version lacks the "set_defaults" function
+if (!('devtools' %in% rownames(installed.packages()))) {
+  install.packages('devtools')
+}
+devtools::install_github('thomasp85/curry')
+
library(curry)

moving_averages <- tf$python$training$moving_averages
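
The tfprobability package exposes each TensorFlow Probability distribution as a snake_case constructor prefixed with tfd_, which is why the tfd$... calls further down change to tfd_...(). A minimal sketch of the pattern (standalone, not part of this commit):

library(tfprobability)
d <- tfd_bernoulli(logits = c(0, 1.5))  # was: tfd$Bernoulli(logits = ...)
tfd_mean(d)                             # distribution methods are wrapped as tfd_* too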
@@ -63,7 +62,14 @@ write_png <- function(dataset, epoch, desc, images) {

np <- import("numpy")

-# download from: https://github.com/rois-codh/kmnist
+# download from: https://github.com/rois-codh/kmnist via the download_data() helper below
+download_data <- function() {
+  if (!file.exists('kmnist-train-imgs.npz')) {
+    download.file('http://codh.rois.ac.jp/kmnist/dataset/kmnist/kmnist-train-imgs.npz',
+                  destfile = 'kmnist-train-imgs.npz')
+  }
+}
+download_data()
kuzushiji <- np$load("kmnist-train-imgs.npz")
kuzushiji <- kuzushiji$get("arr_0")

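A quick sanity check on the loaded array (the expected dimensions are taken from the KMNIST release: 60,000 training images of 28 x 28 pixels):

dim(kuzushiji)  # expected: 60000 28 28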
@@ -90,7 +96,7 @@ batch %>% dim()
# Params ------------------------------------------------------------------

learning_rate <- 0.001
-latent_size <- 1
+latent_size <- 1L
num_codes <- 64L
code_size <- 16L
base_depth <- 32
@@ -214,7 +220,7 @@ decoder_model <- function(name = NULL,
        self$deconv6() %>%
        # output shape: 7 28 28 1
        self$conv1()
-      tfd$Independent(tfd$Bernoulli(logits = x),
+      tfd_independent(tfd_bernoulli(logits = x),
                      reinterpreted_batch_ndims = length(output_shape))
    }
  })
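
A sketch of what the tfd_independent wrapper does for the decoder output (shapes assumed here: a batch of 8 images of 28 x 28 x 1, with tensorflow and tfprobability loaded as above): reinterpreting the pixel dimensions as event dimensions makes log_prob() return one value per image instead of one per pixel.

px <- tf$zeros(shape(8, 28, 28, 1))      # stand-in for decoder logits
d  <- tfd_independent(tfd_bernoulli(logits = px),
                      reinterpreted_batch_ndims = 3L)
d$batch_shape   # (8)         -- one distribution per image
d$event_shape   # (28, 28, 1) -- pixels form the event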
@@ -228,16 +234,16 @@ vector_quantizer_model <-
  keras_model_custom(name = name, function(self) {
    self$num_codes <- num_codes
    self$code_size <- code_size
-    self$codebook <- tf$get_variable("codebook",
+    self$codebook <- tf$compat$v1$get_variable("codebook",
                                     shape = c(num_codes, code_size),
                                     dtype = tf$float32)
-    self$ema_count <- tf$get_variable(
+    self$ema_count <- tf$compat$v1$get_variable(
      name = "ema_count",
      shape = c(num_codes),
      initializer = tf$constant_initializer(0),
      trainable = FALSE
    )
-    self$ema_means = tf$get_variable(
+    self$ema_means = tf$compat$v1$get_variable(
      name = "ema_means",
      initializer = self$codebook$initialized_value(),
      trainable = FALSE
@@ -308,7 +314,7 @@ update_ema <- function(vector_quantizer,
  updated_ema_means <-
    updated_ema_means / tf$expand_dims(updated_ema_count, axis = -1L)

-  tf$assign(vector_quantizer$codebook, updated_ema_means)
+  tf$compat$v1$assign(vector_quantizer$codebook, updated_ema_means)
}


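For orientation, the quantities assigned above are standard exponential moving averages; schematically (the decay value here is a placeholder, the script sets its own):

ema_step <- function(ema, x, decay = 0.99) {
  decay * ema + (1 - decay) * x   # blend running estimate toward new batch statistic
}
ema_step(c(1, 1), c(0, 2))        # -> 0.99 1.01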
@@ -321,7 +327,7 @@ decoder <- decoder_model(input_size = latent_size * code_size,
vector_quantizer <-
  vector_quantizer_model(num_codes = num_codes, code_size = code_size)

-optimizer <- tf$train$AdamOptimizer(learning_rate = learning_rate)
+optimizer <- tf$optimizers$Adam(learning_rate = learning_rate)

checkpoint_dir <- "./vq_vae_checkpoints"

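A TF2 Keras optimizer such as tf$optimizers$Adam keeps its own step counter, which is why the global_step bookkeeping disappears from the apply_gradients() calls below. A minimal illustration (assuming tensorflow is loaded):

opt <- tf$optimizers$Adam(learning_rate = 0.001)
opt$iterations  # variable incremented automatically by each apply_gradients()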
@@ -365,7 +371,7 @@ for (epoch in seq_len(num_epochs)) {

    commitment_loss <- tf$reduce_mean(tf$square(codes - tf$stop_gradient(nearest_codebook_entries)))

-    prior_dist <- tfd$Multinomial(total_count = 1,
+    prior_dist <- tfd_multinomial(total_count = 1,
                                  logits = tf$zeros(c(latent_size, num_codes)))
    prior_loss <- -tf$reduce_mean(tf$reduce_sum(prior_dist$log_prob(one_hot_assignments), 1L))

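With total_count = 1 and all-zero logits, tfd_multinomial is just a uniform one-hot distribution over the num_codes codebook entries, so every assignment has log-probability -log(num_codes). A small check (values here are illustrative, with 4 codes):

u <- tfd_multinomial(total_count = 1, logits = tf$zeros(c(1L, 4L)))
one_hot <- tf$constant(matrix(c(0, 1, 0, 0), nrow = 1), dtype = tf$float32)
tfd_log_prob(u, one_hot)  # ≈ log(1/4), regardless of which entry is hot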
@@ -379,12 +385,10 @@ for (epoch in seq_len(num_epochs)) {

    optimizer$apply_gradients(purrr::transpose(list(
      encoder_gradients, encoder$variables
-    )),
-    global_step = tf$train$get_or_create_global_step())
+    )))
    optimizer$apply_gradients(purrr::transpose(list(
      decoder_gradients, decoder$variables
-    )),
-    global_step = tf$train$get_or_create_global_step())
+    )))

    update_ema(vector_quantizer,
               one_hot_assignments,
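
As a side note on the call shape above: purrr::transpose() turns list(gradients, variables) into the list of (gradient, variable) pairs that apply_gradients() expects. A tiny standalone illustration:

str(purrr::transpose(list(list("g1", "g2"), list("v1", "v2"))))
# -> two pairs: list("g1", "v1") and list("g2", "v2")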