Skip to content

Commit 476f035

Browse files
Turgut AbdullayevTurgut Abdullayev
authored andcommitted
"k_cast_to_floatx" to "as.array"
1 parent fee2fa0 commit 476f035

File tree

1 file changed

+8
-11
lines changed

1 file changed

+8
-11
lines changed

vignettes/examples/eager_styletransfer.R

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,7 @@
77

88
library(keras)
99
use_implementation("tensorflow")
10-
use_session_with_seed(7777, disable_gpu = FALSE, disable_parallel_cpu = FALSE)
1110
library(tensorflow)
12-
tfe_enable_eager_execution(device_policy = "silent")
1311

1412
library(purrr)
1513
library(glue)
@@ -210,13 +208,12 @@ run_style_transfer <- function(content_path,
210208
gram_matrix(feature))
211209

212210
init_image <- load_and_process_image(content_path)
213-
init_image <-
214-
tf$contrib$eager$Variable(init_image, dtype = "float32")
211+
init_image <- tf$Variable(tf$cast(init_image,dtype = "float32"))
215212

216213
optimizer <-
217-
tf$train$AdamOptimizer(learning_rate = 1,
218-
beta1 = 0.99,
219-
epsilon = 1e-1)
214+
tf$optimizers$Adam(learning_rate = 1,
215+
beta_1 = 0.99,
216+
epsilon = 1e-1)
220217

221218
c(best_loss, best_image) %<-% list(Inf, NULL)
222219
loss_weights <- list(style_weight, content_weight)
@@ -242,16 +239,16 @@ run_style_transfer <- function(content_path,
242239

243240
end_time <- Sys.time()
244241

245-
if (k_cast_to_floatx(loss) < best_loss) {
246-
best_loss <- k_cast_to_floatx(loss)
242+
if (as.array(loss) < best_loss) {
243+
best_loss <- as.array(loss)
247244
best_image <- init_image
248245
}
249246

250247
if (i %% 50 == 0) {
251248
glue("Iteration: {i}") %>% print()
252249
glue(
253-
"Total loss: {k_cast_to_floatx(loss)}, style loss: {k_cast_to_floatx(style_score)},
254-
content loss: {k_cast_to_floatx(content_score)}, total variation loss: {k_cast_to_floatx(variation_score)},
250+
"Total loss: {as.array(loss)}, style loss: {as.array(style_score)},
251+
content loss: {as.array(content_score)}, total variation loss: {as.array(variation_score)},
255252
time for 1 iteration: {(Sys.time() - start_time) %>% round(2)}"
256253
) %>% print()
257254

0 commit comments

Comments
 (0)