Skip to content

Commit 78a4de6

Browse files
author
Sigrid Keydana
authored
Merge pull request #1021 from rstudio/update/pix2pix-example
(minimal) update for TF 2
2 parents 26f0b7f + 4692fb0 commit 78a4de6

File tree

1 file changed

+73
-73
lines changed

1 file changed

+73
-73
lines changed

vignettes/examples/eager_pix2pix.R

Lines changed: 73 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -5,16 +5,13 @@
55
#' https://blogs.rstudio.com/tensorflow/posts/2018-09-20-eager-pix2pix
66

77
library(keras)
8-
use_implementation("tensorflow")
9-
108
library(tensorflow)
11-
12-
tfe_enable_eager_execution(device_policy = "silent")
9+
library(tfautograph)
1310

1411
library(tfdatasets)
1512
library(purrr)
1613

17-
restore <- TRUE
14+
restore <- FALSE
1815

1916
data_dir <- "facades"
2017

@@ -24,74 +21,83 @@ batches_per_epoch <- buffer_size / batch_size
2421
img_width <- 256L
2522
img_height <- 256L
2623

27-
load_image <- function(image_file, is_train) {
24+
w <- 512L
25+
w2 <- 256L
2826

29-
image <- tf$read_file(image_file)
27+
load <- function(image_file) {
28+
image <- tf$io$read_file(image_file)
3029
image <- tf$image$decode_jpeg(image)
31-
32-
w <- as.integer(k_shape(image)[2])
33-
w2 <- as.integer(w / 2L)
3430
real_image <- image[ , 1L:w2, ]
3531
input_image <- image[ , (w2 + 1L):w, ]
36-
3732
input_image <- k_cast(input_image, tf$float32)
3833
real_image <- k_cast(real_image, tf$float32)
39-
40-
if (is_train) {
41-
input_image <-
42-
tf$image$resize_images(input_image,
43-
c(286L, 286L),
44-
align_corners = TRUE,
45-
method = 2)
46-
real_image <- tf$image$resize_images(real_image,
47-
c(286L, 286L),
48-
align_corners = TRUE,
49-
method = 2)
50-
51-
stacked_image <-
52-
k_stack(list(input_image, real_image), axis = 1)
53-
cropped_image <-
54-
tf$random_crop(stacked_image, size = c(2L, img_height, img_width, 3L))
55-
c(input_image, real_image) %<-% list(cropped_image[1, , , ], cropped_image[2, , , ])
56-
57-
if (runif(1) > 0.5) {
58-
input_image <- tf$image$flip_left_right(input_image)
59-
real_image <- tf$image$flip_left_right(real_image)
60-
}
61-
} else {
62-
input_image <-
63-
tf$image$resize_images(
64-
input_image,
65-
size = c(img_height, img_width),
66-
align_corners = TRUE,
67-
method = 2
68-
)
69-
real_image <-
70-
tf$image$resize_images(
71-
real_image,
72-
size = c(img_height, img_width),
73-
align_corners = TRUE,
74-
method = 2
75-
)
76-
}
34+
list(input_image, real_image)
35+
}
36+
37+
resize <- function(input_image, real_image, height, width) {
38+
input_image <- tf$image$resize(input_image, list(height, width),
39+
method = tf$image$ResizeMethod$NEAREST_NEIGHBOR)
40+
real_image <- tf$image$resize(real_image, list(height, width),
41+
method = tf$image$ResizeMethod$NEAREST_NEIGHBOR)
42+
43+
list(input_image, real_image)
44+
}
7745

46+
random_crop <- function(input_image, real_image) {
47+
stacked_image <-
48+
k_stack(list(input_image, real_image), axis = 1)
49+
cropped_image <-
50+
tf$image$random_crop(stacked_image, size = c(2L, img_height, img_width, 3L))
51+
c(input_image, real_image) %<-% list(cropped_image[1, , , ], cropped_image[2, , , ])
52+
list(input_image, real_image)
53+
}
54+
55+
normalize <- function(input_image, real_image) {
7856
input_image <- (input_image / 127.5) - 1
7957
real_image <- (real_image / 127.5) - 1
8058

8159
list(input_image, real_image)
60+
61+
}
62+
63+
random_jitter <- function(input_image, real_image) {
64+
# resizing to 286 x 286 x 3
65+
c(input_image, real_image) %<-% resize(input_image, real_image, 286L, 286L)
66+
# randomly cropping to 256 x 256 x 3
67+
c(input_image, real_image) %<-% random_crop(input_image, real_image)
68+
if (tf$random$uniform(shape = list()) > 0.5) {
69+
# random mirroring
70+
input_image <- tf$image$flip_left_right(input_image)
71+
real_image <- tf$image$flip_left_right(real_image)
72+
}
73+
list(input_image, real_image)
74+
}
75+
76+
random_jitter <- tf_function(autograph(random_jitter))
77+
78+
load_image_train <- function(image_file) {
79+
c(input_image, real_image) %<-% load(image_file)
80+
c(input_image, real_image) %<-% random_jitter(input_image, real_image)
81+
#c(input_image, real_image) %<-% normalize(input_image, real_image)
82+
list(input_image, real_image)
83+
}
84+
85+
load_image_test <- function(image_file) {
86+
c(input_image, real_image) %<-% load(image_file)
87+
c(input_image, real_image) %<-% resize(input_image, real_image, img_height, img_width)
88+
c(input_image, real_image) %<-% normalize(input_image, real_image)
89+
list(input_image, real_image)
8290
}
8391

8492
train_dataset <-
8593
tf$data$Dataset$list_files(file.path(data_dir, "train/*.jpg")) %>%
8694
dataset_shuffle(buffer_size) %>%
87-
dataset_map(function(image)
88-
tf$py_func(load_image, list(image, TRUE), list(tf$float32, tf$float32))) %>%
95+
dataset_map(load_image_train, num_parallel_calls = 1) %>%
8996
dataset_batch(batch_size)
9097

9198
test_dataset <-
9299
tf$data$Dataset$list_files(file.path(data_dir, "test/*.jpg")) %>%
93-
dataset_map(function(image)
94-
tf$py_func(load_image, list(image, TRUE), list(tf$float32, tf$float32))) %>%
100+
dataset_map(load_image_test, num_parallel_calls = 1) %>%
95101
dataset_batch(batch_size)
96102

97103

@@ -120,7 +126,7 @@ downsample <- function(filters,
120126
if (self$apply_batchnorm) {
121127
x %>% self$batchnorm(training = training)
122128
}
123-
cat("downsample (generator) output: ", x$shape$as_list(), "\n")
129+
#cat("downsample (generator) output: ", x$shape$as_list(), "\n")
124130
x %>% layer_activation_leaky_relu()
125131
}
126132

@@ -155,7 +161,7 @@ upsample <- function(filters,
155161
}
156162
x %>% layer_activation("relu")
157163
concat <- k_concatenate(list(x, x2))
158-
cat("upsample (generator) output: ", concat$shape$as_list(), "\n")
164+
#cat("upsample (generator) output: ", concat$shape$as_list(), "\n")
159165
concat
160166
}
161167
})
@@ -217,7 +223,7 @@ generator <- function(name = "generator") {
217223
x15 <-
218224
self$up7(list(x14, x1), training = training) # (bs, 128, 128, 128)
219225
x16 <- self$last(x15) # (bs, 256, 256, 3)
220-
cat("generator output: ", x16$shape$as_list(), "\n")
226+
#cat("generator output: ", x16$shape$as_list(), "\n")
221227
x16
222228
}
223229
})
@@ -293,7 +299,7 @@ discriminator <- function(name = "discriminator") {
293299
layer_activation_leaky_relu() %>%
294300
self$zero_pad2() %>% # (bs, 33, 33, 512)
295301
self$last() # (bs, 30, 30, 1)
296-
cat("discriminator output: ", x$shape$as_list(), "\n")
302+
#cat("discriminator output: ", x$shape$as_list(), "\n")
297303
x
298304
}
299305
})
@@ -303,30 +309,24 @@ discriminator <- function(name = "discriminator") {
303309
generator <- generator()
304310
discriminator <- discriminator()
305311

306-
generator$call = tf$contrib$eager$defun(generator$call)
307-
discriminator$call = tf$contrib$eager$defun(discriminator$call)
312+
cross_entropy = tf$keras$losses$BinaryCrossentropy(from_logits = TRUE)
308313

309314
discriminator_loss <- function(real_output, generated_output) {
310-
real_loss <-
311-
tf$losses$sigmoid_cross_entropy(multi_class_labels = tf$ones_like(real_output),
312-
logits = real_output)
313-
generated_loss <-
314-
tf$losses$sigmoid_cross_entropy(multi_class_labels = tf$zeros_like(generated_output),
315-
logits = generated_output)
315+
real_loss <- cross_entropy(k_ones_like(real_output), real_output)
316+
generated_loss <- cross_entropy(k_zeros_like(generated_output), generated_output)
316317
real_loss + generated_loss
317318
}
318319

319320
lambda <- 100
320321
generator_loss <-
321322
function(disc_judgment, generated_output, target) {
322-
gan_loss <-
323-
tf$losses$sigmoid_cross_entropy(tf$ones_like(disc_judgment), disc_judgment)
323+
gan_loss <- cross_entropy(tf$ones_like(disc_judgment), disc_judgment)
324324
l1_loss <- tf$reduce_mean(tf$abs(target - generated_output))
325325
gan_loss + (lambda * l1_loss)
326326
}
327327

328-
discriminator_optimizer <- tf$train$AdamOptimizer(2e-4, beta1 = 0.5)
329-
generator_optimizer <- tf$train$AdamOptimizer(2e-4, beta1 = 0.5)
328+
discriminator_optimizer <- tf$optimizers$Adam(2e-4, beta_1 = 0.5)
329+
generator_optimizer <- tf$optimizers$Adam(2e-4, beta_1 = 0.5)
330330

331331
checkpoint_dir <- "./checkpoints_pix2pix"
332332
checkpoint_prefix <- file.path(checkpoint_dir, "ckpt")
@@ -387,17 +387,17 @@ train <- function(dataset, num_epochs) {
387387
})
388388
})
389389
generator_gradients <- gen_tape$gradient(gen_loss,
390-
generator$variables)
390+
generator$trainable_variables)
391391
discriminator_gradients <- disc_tape$gradient(disc_loss,
392-
discriminator$variables)
392+
discriminator$trainable_variables)
393393

394394
generator_optimizer$apply_gradients(transpose(list(
395395
generator_gradients,
396-
generator$variables
396+
generator$trainable_variables
397397
)))
398398
discriminator_optimizer$apply_gradients(transpose(
399399
list(discriminator_gradients,
400-
discriminator$variables)
400+
discriminator$trainable_variables)
401401
))
402402

403403
})

0 commit comments

Comments
 (0)