
Commit f42c7b7

"tf==2.0.0" compatibility & "tfprobability"
1 parent 56cf6a4 commit f42c7b7

1 file changed (+27, -23 lines)


vignettes/examples/vq_vae.R

Lines changed: 27 additions & 23 deletions
@@ -5,20 +5,19 @@
 #' https://blogs.rstudio.com/tensorflow/posts/2019-01-24-vq-vae/
 
 library(keras)
-use_implementation("tensorflow")
 library(tensorflow)
-tfe_enable_eager_execution(device_policy = "silent")
-
-use_session_with_seed(7778,
-                      disable_gpu = FALSE,
-                      disable_parallel_cpu = FALSE)
-
-tfp <- import("tensorflow_probability")
-tfd <- tfp$distributions
-
+library(tfprobability)
 library(tfdatasets)
+
 library(dplyr)
 library(glue)
+
+# curry has to be installed from github because CRAN version has no "set_defaults" function
+if(!('devtools' %in% rownames(installed.packages()) )) {
+  install.packages('devtools')
+}
+devtools::install_github('thomasp85/curry')
+
 library(curry)
 
 moving_averages <- tf$python$training$moving_averages
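
TF 2.0 executes eagerly by default, which is why "tfe_enable_eager_execution()" and the session-based seeding are gone, and the tfprobability package now stands in for the manual import("tensorflow_probability"). The curry dependency is pulled in for its "set_defaults()" helper, which pre-fills constructor arguments. A minimal sketch of that pattern (the layer and arguments are illustrative, not taken from this file):

library(keras)
library(curry)

# Illustrative only: bake shared arguments into a layer constructor once,
# then call the shorter version throughout the model code.
default_conv <- set_defaults(layer_conv_2d, list(padding = "same"))

# The two calls below build the same layer:
default_conv(filters = 32L, kernel_size = 3L)
layer_conv_2d(filters = 32L, kernel_size = 3L, padding = "same")
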
@@ -63,7 +62,14 @@ write_png <- function(dataset, epoch, desc, images) {
 
 np <- import("numpy")
 
-# download from: https://github.com/rois-codh/kmnist
+# download from: https://github.com/rois-codh/kmnist via "download_data()" function
+download_data = function(){
+  if(!file.exists('kmnist-train-imgs.npz')) {
+    download.file('http://codh.rois.ac.jp/kmnist/dataset/kmnist/kmnist-train-imgs.npz',
+                  destfile = 'kmnist-train-imgs.npz')
+  }
+}
+download_data()
 kuzushiji <- np$load("kmnist-train-imgs.npz")
 kuzushiji <- kuzushiji$get("arr_0")
 
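
A portability note on the new helper, offered as a variant rather than a fix: download.file() writes in text mode on Windows unless told otherwise, and ".npz" is not one of the extensions that switches it to binary automatically, so an explicit mode avoids a corrupted archive:

download_data <- function() {
  if (!file.exists('kmnist-train-imgs.npz')) {
    download.file('http://codh.rois.ac.jp/kmnist/dataset/kmnist/kmnist-train-imgs.npz',
                  destfile = 'kmnist-train-imgs.npz',
                  mode = 'wb')  # binary mode, so the archive survives on Windows
  }
}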

@@ -90,7 +96,7 @@ batch %>% dim()
 # Params ------------------------------------------------------------------
 
 learning_rate <- 0.001
-latent_size <- 1
+latent_size <- 1L
 num_codes <- 64L
 code_size <- 16L
 base_depth <- 32
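
The 1 to 1L change is about type rather than value: plain R numbers are doubles, and TensorFlow wants Python integers wherever a value feeds a shape or an index, so the L suffix matters once latent_size is used in shapes. Illustrative:

latent_size <- 1L
num_codes   <- 64L
tf$zeros(c(latent_size, num_codes))  # ok: integer shape (1, 64)
# tf$zeros(c(1, 64)) would pass the floats 1.0 and 64.0 and can error in TF 2
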
@@ -214,7 +220,7 @@ decoder_model <- function(name = NULL,
       self$deconv6() %>%
       # output shape: 7 28 28 1
       self$conv1()
-    tfd$Independent(tfd$Bernoulli(logits = x),
+    tfd_independent(tfd_bernoulli(logits = x),
                     reinterpreted_batch_ndims = length(output_shape))
   }
 })
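
tfd_independent() and tfd_bernoulli() are tfprobability's snake_case wrappers for the tfp$distributions classes used before, so the swap is one-to-one. A self-contained sketch of the kind of distribution the decoder head builds:

library(tensorflow)
library(tfprobability)

# Pixel-wise Bernoulli whose last three dims form one event, so log_prob()
# returns one value per image instead of one per pixel.
d <- tfd_independent(
  tfd_bernoulli(logits = tf$zeros(c(28L, 28L, 1L))),
  reinterpreted_batch_ndims = 3L
)
d$log_prob(tf$zeros(c(28L, 28L, 1L)))  # scalar
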
@@ -228,16 +234,16 @@ vector_quantizer_model <-
   keras_model_custom(name = name, function(self) {
     self$num_codes <- num_codes
     self$code_size <- code_size
-    self$codebook <- tf$get_variable("codebook",
+    self$codebook <- tf$compat$v1$get_variable("codebook",
                                      shape = c(num_codes, code_size),
                                      dtype = tf$float32)
-    self$ema_count <- tf$get_variable(
+    self$ema_count <- tf$compat$v1$get_variable(
       name = "ema_count",
       shape = c(num_codes),
       initializer = tf$constant_initializer(0),
       trainable = FALSE
     )
-    self$ema_means = tf$get_variable(
+    self$ema_means = tf$compat$v1$get_variable(
       name = "ema_means",
       initializer = self$codebook$initialized_value(),
       trainable = FALSE
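
Routing through tf$compat$v1$get_variable() is the minimal-change migration, since TF 2.0 removes these calls from the top-level namespace. A native TF 2 alternative (a sketch of the idea, not what this commit does) would construct tf$Variable objects directly:

# Sketch only; dimensions reuse num_codes = 64L and code_size = 16L from above.
codebook  <- tf$Variable(tf$random$uniform(c(64L, 16L)), name = "codebook")
ema_count <- tf$Variable(tf$zeros(64L), name = "ema_count", trainable = FALSE)
ema_means <- tf$Variable(codebook$read_value(), name = "ema_means", trainable = FALSE)
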
@@ -308,7 +314,7 @@ update_ema <- function(vector_quantizer,
   updated_ema_means <-
     updated_ema_means / tf$expand_dims(updated_ema_count, axis = -1L)
 
-  tf$assign(vector_quantizer$codebook, updated_ema_means)
+  tf$compat$v1$assign(vector_quantizer$codebook, updated_ema_means)
 }
 
 
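
Same migration pattern as the variable creation above: tf$compat$v1$assign() keeps the TF 1 calling convention, while the TF 2-native spelling is a method on the variable itself:

# TF 2-native equivalent of the compat call (sketch):
vector_quantizer$codebook$assign(updated_ema_means)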

@@ -321,7 +327,7 @@ decoder <- decoder_model(input_size = latent_size * code_size,
 vector_quantizer <-
   vector_quantizer_model(num_codes = num_codes, code_size = code_size)
 
-optimizer <- tf$train$AdamOptimizer(learning_rate = learning_rate)
+optimizer <- tf$optimizers$Adam(learning_rate = learning_rate)
 
 checkpoint_dir <- "./vq_vae_checkpoints"
 
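
tf$train$AdamOptimizer no longer exists in TF 2.0; tf$optimizers$Adam is the Keras optimizer that replaces it, and it pairs with tf$GradientTape as in this standalone sketch (the variable and loss are illustrative):

optimizer <- tf$optimizers$Adam(learning_rate = 0.001)
w <- tf$Variable(1)

with(tf$GradientTape() %as% tape, {
  loss <- tf$square(w)
})
grads <- tape$gradient(loss, list(w))
# same (gradient, variable) pairing as in the training loop below
optimizer$apply_gradients(purrr::transpose(list(grads, list(w))))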

@@ -365,7 +371,7 @@ for (epoch in seq_len(num_epochs)) {
 
     commitment_loss <- tf$reduce_mean(tf$square(codes - tf$stop_gradient(nearest_codebook_entries)))
 
-    prior_dist <- tfd$Multinomial(total_count = 1,
+    prior_dist <- tfd_multinomial(total_count = 1,
                                   logits = tf$zeros(c(latent_size, num_codes)))
     prior_loss <- -tf$reduce_mean(tf$reduce_sum(prior_dist$log_prob(one_hot_assignments), 1L))
 
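
With all-zero logits the multinomial prior is uniform over the codebook, so prior_loss is the mean cross-entropy of the one-hot code assignments against a uniform distribution. The smallest case, with illustrative numbers:

# Uniform prior over 4 codes: any one-hot draw has log-probability log(1/4).
prior <- tfd_multinomial(total_count = 1, logits = tf$zeros(c(1L, 4L)))
prior$log_prob(tf$one_hot(0L, 4L))  # about -1.386, i.e. log(0.25)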

@@ -379,12 +385,10 @@ for (epoch in seq_len(num_epochs)) {
 
     optimizer$apply_gradients(purrr::transpose(list(
       encoder_gradients, encoder$variables
-    )),
-    global_step = tf$train$get_or_create_global_step())
+    )))
     optimizer$apply_gradients(purrr::transpose(list(
       decoder_gradients, decoder$variables
-    )),
-    global_step = tf$train$get_or_create_global_step())
+    )))
 
     update_ema(vector_quantizer,
                one_hot_assignments,
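
Dropping global_step completes the optimizer migration: a TF 2 Keras optimizer keeps its own step counter, incremented on every apply_gradients() call, so nothing needs to be threaded through by hand:

optimizer$iterations  # int64 variable; advances with each apply_gradients() call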
