Skip to content

Commit 0e90135

Browse files
author
Sigrid Keydana
authored
Merge pull request #1023 from henry090/master
update examples
2 parents 425db71 + ff455f5 commit 0e90135

File tree

2 files changed

+27
-32
lines changed

2 files changed

+27
-32
lines changed

vignettes/examples/eager_cvae.R

Lines changed: 19 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -5,17 +5,14 @@
55
#' https://blogs.rstudio.com/tensorflow/posts/2018-10-22-mmd-vae/
66

77

8+
89
library(keras)
9-
use_implementation("tensorflow")
1010
library(tensorflow)
11-
tfe_enable_eager_execution(device_policy = "silent")
12-
1311
library(tfdatasets)
1412
library(dplyr)
1513
library(ggplot2)
1614
library(glue)
1715

18-
1916
# Setup and preprocessing -------------------------------------------------
2017

2118
fashion <- dataset_fashion_mnist()
@@ -24,9 +21,9 @@ c(train_images, train_labels) %<-% fashion$train
2421
c(test_images, test_labels) %<-% fashion$test
2522

2623
train_x <-
27-
train_images %>% `/`(255) %>% k_reshape(c(60000, 28, 28, 1))
24+
train_images %>% `/`(255) %>% k_reshape(c(60000, 28, 28, 1)) %>% k_cast(dtype='float32')
2825
test_x <-
29-
test_images %>% `/`(255) %>% k_reshape(c(10000, 28, 28, 1))
26+
test_images %>% `/`(255) %>% k_reshape(c(10000, 28, 28, 1)) %>% k_cast(dtype='float32')
3027

3128
class_names = c('T-shirt/top',
3229
'Trouser',
@@ -63,7 +60,8 @@ encoder_model <- function(name = NULL) {
6360
filters = 32,
6461
kernel_size = 3,
6562
strides = 2,
66-
activation = "relu"
63+
activation = "relu",
64+
dtype='float32'
6765
)
6866
self$conv2 <-
6967
layer_conv_2d(
@@ -126,22 +124,22 @@ decoder_model <- function(name = NULL) {
126124
}
127125

128126
reparameterize <- function(mean, logvar) {
129-
eps <- k_random_normal(shape = mean$shape, dtype = tf$float64)
127+
eps <- k_random_normal(shape = mean$shape, dtype = tf$float32)
130128
eps * k_exp(logvar * 0.5) + mean
131129
}
132130

133131

134132
# Loss and optimizer ------------------------------------------------------
135133

136134
normal_loglik <- function(sample, mean, logvar, reduce_axis = 2) {
137-
loglik <- k_constant(0.5, dtype = tf$float64) *
138-
(k_log(2 * k_constant(pi, dtype = tf$float64)) +
139-
logvar +
140-
k_exp(-logvar) * (sample - mean) ^ 2)
135+
loglik <- k_constant(0.5) *
136+
(k_log(2 * k_constant(pi)) +
137+
logvar +
138+
k_exp(-logvar) * (sample - mean) ^ 2)
141139
- k_sum(loglik, axis = reduce_axis)
142140
}
143141

144-
optimizer <- tf$train$AdamOptimizer(1e-4)
142+
optimizer <- tf$keras$optimizers$Adam(1e-4)
145143

146144

147145

@@ -151,7 +149,7 @@ num_examples_to_generate <- 64
151149

152150
random_vector_for_generation <-
153151
k_random_normal(shape = list(num_examples_to_generate, latent_dim),
154-
dtype = tf$float64)
152+
dtype = tf$float32)
155153

156154
generate_random_clothes <- function(epoch) {
157155
predictions <-
@@ -216,7 +214,7 @@ show_grid <- function(epoch) {
216214
z_sample <- matrix(c(grid_x[i], grid_y[j]), ncol = 2)
217215
column <-
218216
rbind(column,
219-
(decoder(z_sample) %>% tf$nn$sigmoid() %>% as.numeric()) %>% matrix(ncol = img_size))
217+
(decoder(k_cast(z_sample,'float32')) %>% tf$nn$sigmoid() %>% as.numeric()) %>% matrix(ncol = img_size))
220218
}
221219
rows <- cbind(rows, column)
222220
}
@@ -267,8 +265,8 @@ for (epoch in seq_len(num_epochs)) {
267265
-k_sum(crossentropy_loss)
268266
logpz <-
269267
normal_loglik(z,
270-
k_constant(0, dtype = tf$float64),
271-
k_constant(0, dtype = tf$float64))
268+
k_constant(0, dtype = tf$float32),
269+
k_constant(0, dtype = tf$float32))
272270
logqz_x <- normal_loglik(z, mean, logvar)
273271
loss <- -k_mean(logpx_z + logpz - logqz_x)
274272

@@ -284,12 +282,10 @@ for (epoch in seq_len(num_epochs)) {
284282

285283
optimizer$apply_gradients(purrr::transpose(list(
286284
encoder_gradients, encoder$variables
287-
)),
288-
global_step = tf$train$get_or_create_global_step())
285+
)))
289286
optimizer$apply_gradients(purrr::transpose(list(
290287
decoder_gradients, decoder$variables
291-
)),
292-
global_step = tf$train$get_or_create_global_step())
288+
)))
293289

294290
})
295291

@@ -313,3 +309,5 @@ for (epoch in seq_len(num_epochs)) {
313309
}
314310
}
315311

312+
313+

vignettes/examples/eager_styletransfer.R

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,7 @@
77

88
library(keras)
99
use_implementation("tensorflow")
10-
use_session_with_seed(7777, disable_gpu = FALSE, disable_parallel_cpu = FALSE)
1110
library(tensorflow)
12-
tfe_enable_eager_execution(device_policy = "silent")
1311

1412
library(purrr)
1513
library(glue)
@@ -210,13 +208,12 @@ run_style_transfer <- function(content_path,
210208
gram_matrix(feature))
211209

212210
init_image <- load_and_process_image(content_path)
213-
init_image <-
214-
tf$contrib$eager$Variable(init_image, dtype = "float32")
211+
init_image <- tf$Variable(tf$cast(init_image,dtype = "float32"))
215212

216213
optimizer <-
217-
tf$train$AdamOptimizer(learning_rate = 1,
218-
beta1 = 0.99,
219-
epsilon = 1e-1)
214+
tf$optimizers$Adam(learning_rate = 1,
215+
beta_1 = 0.99,
216+
epsilon = 1e-1)
220217

221218
c(best_loss, best_image) %<-% list(Inf, NULL)
222219
loss_weights <- list(style_weight, content_weight)
@@ -242,16 +239,16 @@ run_style_transfer <- function(content_path,
242239

243240
end_time <- Sys.time()
244241

245-
if (k_cast_to_floatx(loss) < best_loss) {
246-
best_loss <- k_cast_to_floatx(loss)
242+
if (as.array(loss) < best_loss) {
243+
best_loss <- as.array(loss)
247244
best_image <- init_image
248245
}
249246

250247
if (i %% 50 == 0) {
251248
glue("Iteration: {i}") %>% print()
252249
glue(
253-
"Total loss: {k_cast_to_floatx(loss)}, style loss: {k_cast_to_floatx(style_score)},
254-
content loss: {k_cast_to_floatx(content_score)}, total variation loss: {k_cast_to_floatx(variation_score)},
250+
"Total loss: {as.array(loss)}, style loss: {as.array(style_score)},
251+
content loss: {as.array(content_score)}, total variation loss: {as.array(variation_score)},
255252
time for 1 iteration: {(Sys.time() - start_time) %>% round(2)}"
256253
) %>% print()
257254

0 commit comments

Comments
 (0)