Commit ca11331: "vignette tweaks"
1 parent: 4be23b2

File tree: 1 file changed (+13, -17)

vignettes/new-guides/customizing_what_happens_in_fit.Rmd

Lines changed: 13 additions & 17 deletions
@@ -66,20 +66,18 @@ Let's see how that works.
 library(tensorflow)
 library(tfdatasets)
 library(keras)
-
+library(magrittr, include.only = "%<>%")

 # -- we start by defining some helpers we'll use later --
-# like zip() in python
-zip <- function(...) purrr::transpose(list(...))
+# xyz for zipper. Like zip() in python
+xyz <- function(...) purrr::transpose(list(...))

 map_and_name <- function(.x, .f, ...) {
   out <- purrr::map(.x, .f[-2L], ...)
   names(out) <- purrr::map_chr(.x, .f[-3L], ...)
   out
 }

-envir::import_from(magrittr, `%<>%`)
-
 stopifnot(tf_version() >= "2.2") # Requires TensorFlow 2.2 or later.
 ```

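For context on the renamed helper: `purrr::transpose()` turns a list of equal-length lists "inside out", so `xyz()` pairs up its arguments element-wise, much like Python's `zip()`. A minimal sketch of that behavior (the toy values below are illustrative, not from the vignette):

```r
library(purrr)

# Same helper as defined in the hunk above: pair arguments element-wise,
# analogous to Python's zip().
xyz <- function(...) purrr::transpose(list(...))

grads <- list("grad_a", "grad_b")  # stand-ins for gradient tensors
vars  <- list("var_a", "var_b")    # stand-ins for trainable variables

pairs <- xyz(grads, vars)
# pairs is list(list("grad_a", "var_a"), list("grad_b", "var_b")),
# i.e. the (gradient, variable) pairing that optimizer$apply_gradients() expects.
```
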
@@ -118,7 +116,7 @@ CustomModel(keras$Model) %py_class% {
     trainable_vars <- self$trainable_variables
     gradients <- tape$gradient(loss, trainable_vars)
     # Update weights
-    self$optimizer$apply_gradients(zip(gradients, trainable_vars))
+    self$optimizer$apply_gradients(xyz(gradients, trainable_vars))
     # Update metrics (includes the metric that tracks the loss)
     self$compiled_metrics$update_state(y, y_pred)

@@ -172,7 +170,7 @@ CustomModel(keras$Model) %py_class% {
     gradients <- tape$gradient(loss, trainable_vars)

     # Update weights
-    self$optimizer$apply_gradients(zip(gradients, trainable_vars))
+    self$optimizer$apply_gradients(xyz(gradients, trainable_vars))

     # Compute our own metrics
     loss_tracker$update_state(loss)
@@ -237,7 +235,7 @@ CustomModel(keras$Model) %py_class% {
     gradients <- tape$gradient(loss, trainable_vars)

     # Update weights
-    self$optimizer$apply_gradients(zip(gradients, trainable_vars))
+    self$optimizer$apply_gradients(xyz(gradients, trainable_vars))

     # Update the metrics.
     # Metrics are configured in `compile()`.
@@ -333,15 +331,15 @@ generator <- keras_model_sequential(name = "generator", input_shape = c(latent_d
   layer_conv_2d(1, c(7, 7), padding = "same", activation = "sigmoid")
 ```

-Here's a feature-complete GAN class, overriding `compile()` to use its own signature, and implementing the entire GAN algorithm in 17 lines in `train_step`:
+Here's a feature-complete GAN class, overriding `compile()` to use its own signature, and implementing the entire GAN algorithm in just a few lines in `train_step`:

 ```{r}
 GAN(keras$Model) %py_class% {
   `__init__` <- function(discriminator, generator, latent_dim) {
-    super()$`__init__`()
-    self$discriminator <- discriminator
-    self$generator <- generator
-    self$latent_dim <- latent_dim
+    super()$`__init__`()
+    self$discriminator <- discriminator
+    self$generator <- generator
+    self$latent_dim <- latent_dim
   }

   compile <- function(d_optimizer, g_optimizer, loss_fn) {
@@ -379,7 +377,7 @@ GAN(keras$Model) %py_class% {
       d_loss <- self$loss_fn(labels, predictions)
     })
     grads <- tape$gradient(d_loss, self$discriminator$trainable_weights)
-    self$d_optimizer$apply_gradients(zip(grads, self$discriminator$trainable_weights))
+    self$d_optimizer$apply_gradients(xyz(grads, self$discriminator$trainable_weights))


     # Sample random points in the latent space
@@ -395,10 +393,9 @@ GAN(keras$Model) %py_class% {
       g_loss <- self$loss_fn(misleading_labels, predictions)
     })
     grads <- tape$gradient(g_loss, self$generator$trainable_weights)
-    self$g_optimizer$apply_gradients(zip(grads, self$generator$trainable_weights))
+    self$g_optimizer$apply_gradients(xyz(grads, self$generator$trainable_weights))
     list(d_loss = d_loss, g_loss = g_loss)
   }
-
 }
 ```

@@ -419,7 +416,6 @@ dataset <- all_digits %>%
   dataset_shuffle(buffer_size = 1024) %>%
   dataset_batch(batch_size)

-
 gan <- GAN(discriminator = discriminator,
            generator = generator,
            latent_dim = latent_dim)
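
For orientation (not part of this commit): given the `compile()` signature shown above, compiling and fitting the GAN might look roughly like the sketch below. The optimizer and loss constructors, learning rate, and epoch count are assumptions, not taken from the vignette.

```r
# Hypothetical usage sketch: assumes the overridden compile() stores
# d_optimizer, g_optimizer, and loss_fn on the model, as its signature suggests.
gan$compile(
  d_optimizer = tf$keras$optimizers$Adam(learning_rate = 0.0003),  # assumed values
  g_optimizer = tf$keras$optimizers$Adam(learning_rate = 0.0003),
  loss_fn     = tf$keras$losses$BinaryCrossentropy(from_logits = TRUE)
)

# train_step() returns list(d_loss = ..., g_loss = ...), so both losses
# are reported during training.
gan$fit(dataset, epochs = 1L)
```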
