55# ' https://blogs.rstudio.com/tensorflow/posts/2018-09-20-eager-pix2pix
66
77library(keras )
8- use_implementation(" tensorflow" )
9-
108library(tensorflow )
11-
12- tfe_enable_eager_execution(device_policy = " silent" )
9+ library(tfautograph )
1310
1411library(tfdatasets )
1512library(purrr )
1613
17- restore <- TRUE
14+ restore <- FALSE
1815
1916data_dir <- " facades"
2017
@@ -24,74 +21,83 @@ batches_per_epoch <- buffer_size / batch_size
2421img_width <- 256L
2522img_height <- 256L
2623
27- load_image <- function (image_file , is_train ) {
24+ w <- 512L
25+ w2 <- 256L
2826
29- image <- tf $ read_file(image_file )
27+ load <- function (image_file ) {
28+ image <- tf $ io $ read_file(image_file )
3029 image <- tf $ image $ decode_jpeg(image )
31-
32- w <- as.integer(k_shape(image )[2 ])
33- w2 <- as.integer(w / 2L )
3430 real_image <- image [ , 1L : w2 , ]
3531 input_image <- image [ , (w2 + 1L ): w , ]
36-
3732 input_image <- k_cast(input_image , tf $ float32 )
3833 real_image <- k_cast(real_image , tf $ float32 )
39-
40- if (is_train ) {
41- input_image <-
42- tf $ image $ resize_images(input_image ,
43- c(286L , 286L ),
44- align_corners = TRUE ,
45- method = 2 )
46- real_image <- tf $ image $ resize_images(real_image ,
47- c(286L , 286L ),
48- align_corners = TRUE ,
49- method = 2 )
50-
51- stacked_image <-
52- k_stack(list (input_image , real_image ), axis = 1 )
53- cropped_image <-
54- tf $ random_crop(stacked_image , size = c(2L , img_height , img_width , 3L ))
55- c(input_image , real_image ) %<- % list (cropped_image [1 , , , ], cropped_image [2 , , , ])
56-
57- if (runif(1 ) > 0.5 ) {
58- input_image <- tf $ image $ flip_left_right(input_image )
59- real_image <- tf $ image $ flip_left_right(real_image )
60- }
61- } else {
62- input_image <-
63- tf $ image $ resize_images(
64- input_image ,
65- size = c(img_height , img_width ),
66- align_corners = TRUE ,
67- method = 2
68- )
69- real_image <-
70- tf $ image $ resize_images(
71- real_image ,
72- size = c(img_height , img_width ),
73- align_corners = TRUE ,
74- method = 2
75- )
76- }
34+ list (input_image , real_image )
35+ }
36+
37+ resize <- function (input_image , real_image , height , width ) {
38+ input_image <- tf $ image $ resize(input_image , list (height , width ),
39+ method = tf $ image $ ResizeMethod $ NEAREST_NEIGHBOR )
40+ real_image <- tf $ image $ resize(real_image , list (height , width ),
41+ method = tf $ image $ ResizeMethod $ NEAREST_NEIGHBOR )
42+
43+ list (input_image , real_image )
44+ }
7745
46+ random_crop <- function (input_image , real_image ) {
47+ stacked_image <-
48+ k_stack(list (input_image , real_image ), axis = 1 )
49+ cropped_image <-
50+ tf $ image $ random_crop(stacked_image , size = c(2L , img_height , img_width , 3L ))
51+ c(input_image , real_image ) %<- % list (cropped_image [1 , , , ], cropped_image [2 , , , ])
52+ list (input_image , real_image )
53+ }
54+
55+ normalize <- function (input_image , real_image ) {
7856 input_image <- (input_image / 127.5 ) - 1
7957 real_image <- (real_image / 127.5 ) - 1
8058
8159 list (input_image , real_image )
60+
61+ }
62+
63+ random_jitter <- function (input_image , real_image ) {
64+ # resizing to 286 x 286 x 3
65+ c(input_image , real_image ) %<- % resize(input_image , real_image , 286L , 286L )
66+ # randomly cropping to 256 x 256 x 3
67+ c(input_image , real_image ) %<- % random_crop(input_image , real_image )
68+ if (tf $ random $ uniform(shape = list ()) > 0.5 ) {
69+ # random mirroring
70+ input_image <- tf $ image $ flip_left_right(input_image )
71+ real_image <- tf $ image $ flip_left_right(real_image )
72+ }
73+ list (input_image , real_image )
74+ }
75+
76+ random_jitter <- tf_function(autograph(random_jitter ))
77+
78+ load_image_train <- function (image_file ) {
79+ c(input_image , real_image ) %<- % load(image_file )
80+ c(input_image , real_image ) %<- % random_jitter(input_image , real_image )
81+ # c(input_image, real_image) %<-% normalize(input_image, real_image)
82+ list (input_image , real_image )
83+ }
84+
85+ load_image_test <- function (image_file ) {
86+ c(input_image , real_image ) %<- % load(image_file )
87+ c(input_image , real_image ) %<- % resize(input_image , real_image , img_height , img_width )
88+ c(input_image , real_image ) %<- % normalize(input_image , real_image )
89+ list (input_image , real_image )
8290}
8391
8492train_dataset <-
8593 tf $ data $ Dataset $ list_files(file.path(data_dir , " train/*.jpg" )) %> %
8694 dataset_shuffle(buffer_size ) %> %
87- dataset_map(function (image )
88- tf $ py_func(load_image , list (image , TRUE ), list (tf $ float32 , tf $ float32 ))) %> %
95+ dataset_map(load_image_train , num_parallel_calls = 1 ) %> %
8996 dataset_batch(batch_size )
9097
9198test_dataset <-
9299 tf $ data $ Dataset $ list_files(file.path(data_dir , " test/*.jpg" )) %> %
93- dataset_map(function (image )
94- tf $ py_func(load_image , list (image , TRUE ), list (tf $ float32 , tf $ float32 ))) %> %
100+ dataset_map(load_image_test , num_parallel_calls = 1 ) %> %
95101 dataset_batch(batch_size )
96102
97103
@@ -120,7 +126,7 @@ downsample <- function(filters,
120126 if (self $ apply_batchnorm ) {
121127 x %> % self $ batchnorm(training = training )
122128 }
123- cat(" downsample (generator) output: " , x $ shape $ as_list(), " \n " )
129+ # cat("downsample (generator) output: ", x$shape$as_list(), "\n")
124130 x %> % layer_activation_leaky_relu()
125131 }
126132
@@ -155,7 +161,7 @@ upsample <- function(filters,
155161 }
156162 x %> % layer_activation(" relu" )
157163 concat <- k_concatenate(list (x , x2 ))
158- cat(" upsample (generator) output: " , concat $ shape $ as_list(), " \n " )
164+ # cat("upsample (generator) output: ", concat$shape$as_list(), "\n")
159165 concat
160166 }
161167 })
@@ -217,7 +223,7 @@ generator <- function(name = "generator") {
217223 x15 <-
218224 self $ up7(list (x14 , x1 ), training = training ) # (bs, 128, 128, 128)
219225 x16 <- self $ last(x15 ) # (bs, 256, 256, 3)
220- cat(" generator output: " , x16 $ shape $ as_list(), " \n " )
226+ # cat("generator output: ", x16$shape$as_list(), "\n")
221227 x16
222228 }
223229 })
@@ -293,7 +299,7 @@ discriminator <- function(name = "discriminator") {
293299 layer_activation_leaky_relu() %> %
294300 self $ zero_pad2() %> % # (bs, 33, 33, 512)
295301 self $ last() # (bs, 30, 30, 1)
296- cat(" discriminator output: " , x $ shape $ as_list(), " \n " )
302+ # cat("discriminator output: ", x$shape$as_list(), "\n")
297303 x
298304 }
299305 })
@@ -303,30 +309,24 @@ discriminator <- function(name = "discriminator") {
303309generator <- generator()
304310discriminator <- discriminator()
305311
306- generator $ call = tf $ contrib $ eager $ defun(generator $ call )
307- discriminator $ call = tf $ contrib $ eager $ defun(discriminator $ call )
312+ cross_entropy = tf $ keras $ losses $ BinaryCrossentropy(from_logits = TRUE )
308313
309314discriminator_loss <- function (real_output , generated_output ) {
310- real_loss <-
311- tf $ losses $ sigmoid_cross_entropy(multi_class_labels = tf $ ones_like(real_output ),
312- logits = real_output )
313- generated_loss <-
314- tf $ losses $ sigmoid_cross_entropy(multi_class_labels = tf $ zeros_like(generated_output ),
315- logits = generated_output )
315+ real_loss <- cross_entropy(k_ones_like(real_output ), real_output )
316+ generated_loss <- cross_entropy(k_zeros_like(generated_output ), generated_output )
316317 real_loss + generated_loss
317318}
318319
319320lambda <- 100
320321generator_loss <-
321322 function (disc_judgment , generated_output , target ) {
322- gan_loss <-
323- tf $ losses $ sigmoid_cross_entropy(tf $ ones_like(disc_judgment ), disc_judgment )
323+ gan_loss <- cross_entropy(tf $ ones_like(disc_judgment ), disc_judgment )
324324 l1_loss <- tf $ reduce_mean(tf $ abs(target - generated_output ))
325325 gan_loss + (lambda * l1_loss )
326326 }
327327
328- discriminator_optimizer <- tf $ train $ AdamOptimizer (2e-4 , beta1 = 0.5 )
329- generator_optimizer <- tf $ train $ AdamOptimizer (2e-4 , beta1 = 0.5 )
328+ discriminator_optimizer <- tf $ optimizers $ Adam (2e-4 , beta_1 = 0.5 )
329+ generator_optimizer <- tf $ optimizers $ Adam (2e-4 , beta_1 = 0.5 )
330330
331331checkpoint_dir <- " ./checkpoints_pix2pix"
332332checkpoint_prefix <- file.path(checkpoint_dir , " ckpt" )
@@ -387,17 +387,17 @@ train <- function(dataset, num_epochs) {
387387 })
388388 })
389389 generator_gradients <- gen_tape $ gradient(gen_loss ,
390- generator $ variables )
390+ generator $ trainable_variables )
391391 discriminator_gradients <- disc_tape $ gradient(disc_loss ,
392- discriminator $ variables )
392+ discriminator $ trainable_variables )
393393
394394 generator_optimizer $ apply_gradients(transpose(list (
395395 generator_gradients ,
396- generator $ variables
396+ generator $ trainable_variables
397397 )))
398398 discriminator_optimizer $ apply_gradients(transpose(
399399 list (discriminator_gradients ,
400- discriminator $ variables )
400+ discriminator $ trainable_variables )
401401 ))
402402
403403 })
0 commit comments