55# ' https://blogs.rstudio.com/tensorflow/posts/2019-01-08-getting-started-with-tf-probability/
66
77library(keras )
8- use_implementation(" tensorflow" )
98library(tensorflow )
10- tfe_enable_eager_execution(device_policy = " silent" )
11-
12- tfp <- import(" tensorflow_probability" )
13- tfd <- tfp $ distributions
14-
9+ library(tfprobability )
1510library(tfdatasets )
1611library(dplyr )
1712library(glue )
@@ -79,6 +74,17 @@ np <- import("numpy")
7974
8075# assume data have been downloaded from https://github.com/rois-codh/kmnist
8176# and stored in /tmp
77+ download_data = function (){
78+ if (! dir.exists(' tmp' )) {
79+ dir.create(' tmp' )
80+ }
81+ if (! file.exists(' tmp/kmnist-train-imgs.npz' )) {
82+ download.file(' http://codh.rois.ac.jp/kmnist/dataset/kmnist/kmnist-train-imgs.npz' ,
83+ destfile = file.path(" tmp" , basename(' kmnist-train-imgs.npz' )))
84+ }
85+ }
86+ download_data()
87+
8288kuzushiji <- np $ load(" /tmp/kmnist-train-imgs.npz" )
8389kuzushiji <- kuzushiji $ get(" arr_0" )
8490
@@ -98,8 +104,8 @@ train_dataset <- tensor_slices_dataset(train_images) %>%
98104
99105# Params ------------------------------------------------------------------
100106
101- latent_dim <- 2
102- mixture_components <- 16
107+ latent_dim <- 2L
108+ mixture_components <- 16L
103109
104110
105111# Model -------------------------------------------------------------------
@@ -132,8 +138,8 @@ encoder_model <- function(name = NULL) {
132138 self $ conv2() %> %
133139 self $ flatten() %> %
134140 self $ dense()
135- tfd $ MultivariateNormalDiag (loc = x [, 1 : latent_dim ],
136- scale_diag = tf $ nn $ softplus(x [, (latent_dim + 1 ): (2 * latent_dim )] + 1e-5 ))
141+ tfd_multivariate_normal_diag (loc = x [, 1 : latent_dim ],
142+ scale_diag = tf $ nn $ softplus(x [, (latent_dim + 1 ): (2 * latent_dim )] + 1e-5 ))
137143 }
138144 })
139145}
@@ -178,7 +184,7 @@ decoder_model <- function(name = NULL) {
178184 self $ deconv2() %> %
179185 self $ deconv3()
180186
181- tfd $ Independent( tfd $ Bernoulli (logits = x ),
187+ tfd_independent(tfd_bernoulli (logits = x ),
182188 reinterpreted_batch_ndims = 3L )
183189
184190 }
@@ -192,30 +198,30 @@ learnable_prior_model <-
192198
193199 keras_model_custom(name = name , function (self ) {
194200 self $ loc <-
195- tf $ get_variable(
201+ tf $ compat $ v1 $ get_variable(
196202 name = " loc" ,
197203 shape = list (mixture_components , latent_dim ),
198204 dtype = tf $ float32
199205 )
200- self $ raw_scale_diag <- tf $ get_variable(
206+ self $ raw_scale_diag <- tf $ compat $ v1 $ get_variable(
201207 name = " raw_scale_diag" ,
202208 shape = c(mixture_components , latent_dim ),
203209 dtype = tf $ float32
204210 )
205211 self $ mixture_logits <-
206- tf $ get_variable(
212+ tf $ compat $ v1 $ get_variable(
207213 name = " mixture_logits" ,
208214 shape = c(mixture_components ),
209215 dtype = tf $ float32
210216 )
211217
212218 function (x , mask = NULL ) {
213- tfd $ MixtureSameFamily (
214- components_distribution = tfd $ MultivariateNormalDiag (
219+ tfd_mixture_same_family (
220+ components_distribution = tfd_multivariate_normal_diag (
215221 loc = self $ loc ,
216222 scale_diag = tf $ nn $ softplus(self $ raw_scale_diag )
217223 ),
218- mixture_distribution = tfd $ Categorical (logits = self $ mixture_logits )
224+ mixture_distribution = tfd_categorical (logits = self $ mixture_logits )
219225 )
220226 }
221227 })
@@ -234,8 +240,7 @@ compute_kl_loss <-
234240 }
235241
236242
237- global_step <- tf $ train $ get_or_create_global_step()
238- optimizer <- tf $ train $ AdamOptimizer(1e-4 )
243+ optimizer <- tf $ optimizers $ Adam(1e-4 )
239244
240245
241246# Training loop -----------------------------------------------------------
@@ -253,7 +258,6 @@ checkpoint_prefix <- file.path(checkpoint_dir, "ckpt")
253258checkpoint <-
254259 tf $ train $ Checkpoint(
255260 optimizer = optimizer ,
256- global_step = global_step ,
257261 encoder = encoder ,
258262 decoder = decoder ,
259263 latent_prior_model = latent_prior_model
@@ -284,7 +288,7 @@ for (epoch in seq_len(num_epochs)) {
284288 compute_kl_loss(latent_prior ,
285289 approx_posterior ,
286290 approx_posterior_sample )
287-
291+
288292 loss <- kl_loss + avg_nll
289293 })
290294
@@ -299,18 +303,15 @@ for (epoch in seq_len(num_epochs)) {
299303
300304 optimizer $ apply_gradients(purrr :: transpose(list (
301305 encoder_gradients , encoder $ variables
302- )),
303- global_step = tf $ train $ get_or_create_global_step())
306+ )))
304307 optimizer $ apply_gradients(purrr :: transpose(list (
305308 decoder_gradients , decoder $ variables
306- )),
307- global_step = tf $ train $ get_or_create_global_step())
309+ )))
308310 optimizer $ apply_gradients(purrr :: transpose(list (
309311 prior_gradients , latent_prior_model $ variables
310- )),
311- global_step = tf $ train $ get_or_create_global_step())
312+ )))
312313
313- })
314+ })
314315
315316 checkpoint $ save(file_prefix = checkpoint_prefix )
316317
@@ -329,3 +330,4 @@ for (epoch in seq_len(num_epochs)) {
329330 show_grid(epoch )
330331 }
331332}
333+
0 commit comments