Commit 56cf6a4

"tfprobability" instead of "reticulate::import"
1 parent: a411196

1 file changed (+30 −28)

vignettes/examples/tfprob_vae.R
@@ -5,13 +5,8 @@
 #' https://blogs.rstudio.com/tensorflow/posts/2019-01-08-getting-started-with-tf-probability/
 
 library(keras)
-use_implementation("tensorflow")
 library(tensorflow)
-tfe_enable_eager_execution(device_policy = "silent")
-
-tfp <- import("tensorflow_probability")
-tfd <- tfp$distributions
-
+library(tfprobability)
 library(tfdatasets)
 library(dplyr)
 library(glue)
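The substance of the commit is visible in this first hunk: distribution constructors previously reached through a reticulate module handle (tfd$...) become prefixed R functions from the tfprobability package (tfd_...). A minimal sketch of the new style, not part of the commit, assuming tfprobability and a working TensorFlow Probability installation:

library(tfprobability)
# old style: tfp <- import("tensorflow_probability"); tfd <- tfp$distributions
# new style: tfd_* wrapper functions
d <- tfd_multivariate_normal_diag(loc = c(0, 0), scale_diag = c(1, 1))
tfd_sample(d, 5L)          # draw 5 samples
tfd_log_prob(d, c(0, 0))   # log-density at the mean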
@@ -79,6 +74,17 @@ np <- import("numpy")
 
 # assume data have been downloaded from https://github.com/rois-codh/kmnist
 # and stored in /tmp
+download_data <- function() {
+  if (!dir.exists('tmp')) {
+    dir.create('tmp')
+  }
+  if (!file.exists('tmp/kmnist-train-imgs.npz')) {
+    download.file('http://codh.rois.ac.jp/kmnist/dataset/kmnist/kmnist-train-imgs.npz',
+                  destfile = file.path("tmp", basename('kmnist-train-imgs.npz')))
+  }
+}
+download_data()
+
 kuzushiji <- np$load("/tmp/kmnist-train-imgs.npz")
 kuzushiji <- kuzushiji$get("arr_0")
 
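Note that the added download_data() writes to a relative tmp/ directory while the unchanged np$load() call reads from the absolute /tmp/, so one of the two paths will need adjusting on most setups. A quick sanity check after downloading, not part of the commit (the KMNIST training set is 60,000 images of 28x28 pixels):

np <- reticulate::import("numpy")
kuzushiji <- np$load("tmp/kmnist-train-imgs.npz")$get("arr_0")
dim(kuzushiji)   # expect 60000 28 28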
@@ -98,8 +104,8 @@ train_dataset <- tensor_slices_dataset(train_images) %>%
 
 # Params ------------------------------------------------------------------
 
-latent_dim <- 2
-mixture_components <- 16
+latent_dim <- 2L
+mixture_components <- 16L
 
 
 # Model -------------------------------------------------------------------
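The L suffix is the other small fix here: R numeric literals are doubles by default, and reticulate passes doubles to Python as floats, which TensorFlow shape and size arguments reject. To illustrate (not part of the commit):

is.integer(2)    # FALSE: a double, arrives in Python as 2.0
is.integer(2L)   # TRUE: arrives in Python as the integer 2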
@@ -132,8 +138,8 @@ encoder_model <- function(name = NULL) {
       self$conv2() %>%
       self$flatten() %>%
       self$dense()
-    tfd$MultivariateNormalDiag(loc = x[, 1:latent_dim],
-                               scale_diag = tf$nn$softplus(x[, (latent_dim + 1):(2 * latent_dim)] + 1e-5))
+    tfd_multivariate_normal_diag(loc = x[, 1:latent_dim],
+                                 scale_diag = tf$nn$softplus(x[, (latent_dim + 1):(2 * latent_dim)] + 1e-5))
   }
  })
}
@@ -178,7 +184,7 @@ decoder_model <- function(name = NULL) {
       self$deconv2() %>%
       self$deconv3()
 
-    tfd$Independent(tfd$Bernoulli(logits = x),
+    tfd_independent(tfd_bernoulli(logits = x),
                     reinterpreted_batch_ndims = 3L)
 
  }
@@ -192,30 +198,30 @@ learnable_prior_model <-
 
   keras_model_custom(name = name, function(self) {
     self$loc <-
-      tf$get_variable(
+      tf$compat$v1$get_variable(
        name = "loc",
        shape = list(mixture_components, latent_dim),
        dtype = tf$float32
      )
-    self$raw_scale_diag <- tf$get_variable(
+    self$raw_scale_diag <- tf$compat$v1$get_variable(
      name = "raw_scale_diag",
      shape = c(mixture_components, latent_dim),
      dtype = tf$float32
    )
     self$mixture_logits <-
-      tf$get_variable(
+      tf$compat$v1$get_variable(
        name = "mixture_logits",
        shape = c(mixture_components),
        dtype = tf$float32
      )
 
     function (x, mask = NULL) {
-      tfd$MixtureSameFamily(
-        components_distribution = tfd$MultivariateNormalDiag(
+      tfd_mixture_same_family(
+        components_distribution = tfd_multivariate_normal_diag(
          loc = self$loc,
          scale_diag = tf$nn$softplus(self$raw_scale_diag)
        ),
-       mixture_distribution = tfd$Categorical(logits = self$mixture_logits)
+       mixture_distribution = tfd_categorical(logits = self$mixture_logits)
      )
    }
  })
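For reference, the same mixture-of-Gaussians prior can be built standalone with the wrappers this commit switches to. A minimal sketch with dummy constant parameters in place of the trainable variables above (reticulate converts the R matrices to arrays of shape mixture_components x latent_dim):

prior <- tfd_mixture_same_family(
  components_distribution = tfd_multivariate_normal_diag(
    loc = matrix(0, nrow = 16, ncol = 2),
    scale_diag = matrix(1, nrow = 16, ncol = 2)
  ),
  mixture_distribution = tfd_categorical(logits = rep(0, 16))
)
tfd_sample(prior, 1L)   # one draw from the 2-d mixture prior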
@@ -234,8 +240,7 @@ compute_kl_loss <-
 }
 
 
-global_step <- tf$train$get_or_create_global_step()
-optimizer <- tf$train$AdamOptimizer(1e-4)
+optimizer <- tf$optimizers$Adam(1e-4)
 
 
 # Training loop -----------------------------------------------------------
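tf$optimizers$Adam is the TF 2.x Keras-style optimizer; unlike tf$train$AdamOptimizer it maintains its own step counter, which is why the explicit global_step disappears here and from the checkpoint and apply_gradients() calls below. A hypothetical illustration, not part of the commit:

optimizer <- tf$optimizers$Adam(1e-4)
optimizer$iterations   # a counter the optimizer increments on each apply_gradients()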
@@ -253,7 +258,6 @@ checkpoint_prefix <- file.path(checkpoint_dir, "ckpt")
 checkpoint <-
   tf$train$Checkpoint(
     optimizer = optimizer,
-    global_step = global_step,
     encoder = encoder,
     decoder = decoder,
     latent_prior_model = latent_prior_model
@@ -284,7 +288,7 @@ for (epoch in seq_len(num_epochs)) {
       compute_kl_loss(latent_prior,
                       approx_posterior,
                       approx_posterior_sample)
-
+
       loss <- kl_loss + avg_nll
     })
 
@@ -299,18 +303,15 @@ for (epoch in seq_len(num_epochs)) {
 
   optimizer$apply_gradients(purrr::transpose(list(
     encoder_gradients, encoder$variables
-  )),
-  global_step = tf$train$get_or_create_global_step())
+  )))
   optimizer$apply_gradients(purrr::transpose(list(
     decoder_gradients, decoder$variables
-  )),
-  global_step = tf$train$get_or_create_global_step())
+  )))
   optimizer$apply_gradients(purrr::transpose(list(
     prior_gradients, latent_prior_model$variables
-  )),
-  global_step = tf$train$get_or_create_global_step())
+  )))
 
-  })
+  })
 
   checkpoint$save(file_prefix = checkpoint_prefix)
 
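apply_gradients() expects a list of (gradient, variable) pairs; purrr::transpose() reshapes the two parallel lists into that form, and with the TF 2.x optimizer no global_step argument is passed. A small shape demo using placeholder strings instead of real tensors:

purrr::transpose(list(list("grad1", "grad2"), list("var1", "var2")))
# => list(list("grad1", "var1"), list("grad2", "var2"))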
@@ -329,3 +330,4 @@ for (epoch in seq_len(num_epochs)) {
     show_grid(epoch)
   }
 }
+