Commit e1b22a5

transfer_learning.Rmd vignette edits

1 parent c0b5334 · commit e1b22a5
2 files changed: +44, -50 lines

NEWS.md

Lines changed: 2 additions & 0 deletions
@@ -24,6 +24,8 @@
 - `layer_cudnn_gru()` and `layer_cudnn_lstm()` are deprecated. `layer_gru()` and `layer_lstm()` will
   automatically use CuDNN if it is available.

+- New vignette: "Transfer learning and fine-tuning".
+
 - New function `%<-active%`, an ergonomic wrapper around `makeActiveBinding()`
   for constructing Python `@property` decorated methods in `%py_class%`.

vignettes/new-guides/transfer_learning.Rmd

Lines changed: 42 additions & 50 deletions
@@ -1,11 +1,11 @@
 ---
-title: "Transfer learning & fine-tuning"
+title: "Transfer learning and fine-tuning"
 author: "[fchollet](https://twitter.com/fchollet), [t-kalinowski](https://github.com/t-kalinowski)"
 date: 2021/10/15
-description: "Complete guide to transfer learning & fine-tuning in Keras."
+description: "Complete guide to transfer learning and fine-tuning in Keras."
 output: rmarkdown::html_vignette
 vignette: >
-  %\VignetteIndexEntry{Transfer learning & fine-tuning}
+  %\VignetteIndexEntry{Transfer learning and fine-tuning}
   %\VignetteEngine{knitr::rmarkdown}
   %\VignetteEncoding{UTF-8}
 ---
@@ -44,7 +44,7 @@ very low learning rate. This can potentially achieve meaningful improvements, by
 incrementally adapting the pretrained features to the new data.

 First, we will go over the Keras `trainable` API in detail, which underlies most
-transfer learning & fine-tuning workflows.
+transfer learning and fine-tuning workflows.

 Then, we'll demonstrate the typical workflow by taking a model pretrained on the
 ImageNet dataset, and retraining it on the Kaggle "cats vs dogs" classification
@@ -95,7 +95,7 @@ printf("trainable_weights: %s", length(layer$trainable_weights))
 printf("non_trainable_weights: %s", length(layer$non_trainable_weights))
 ```

-Layers & models also feature a boolean attribute `trainable`. Its value can be changed.
+Layers and models also feature a boolean attribute `trainable`. Its value can be changed.
 Setting `layer$trainable` to `FALSE` moves all the layer's weights from trainable to
 non-trainable. This is called "freezing" the layer: the state of a frozen layer won't
 be updated during training (either when training with `fit()` or when training with
@@ -138,25 +138,25 @@ final_layer1_weights_values <- get_weights(layer1)
 stopifnot(all.equal(initial_layer1_weights_values, final_layer1_weights_values))
 ```

-Do not confuse the `layer$trainable` attribute with the `training` argument in a layer instance's `call` signature
-`layer(training =)` (which controls whether the layer should run its forward pass in
-inference mode or training mode). For more information, see the
-[Keras FAQ](
+Do not confuse the `layer$trainable` attribute with the `training` argument in a
+layer instance's `call` signature `layer(training =)` (which controls whether
+the layer should run its forward pass in inference mode or training mode).
+For more information, see the [Keras FAQ](
 https://keras.io/getting_started/faq/#whats-the-difference-between-the-training-argument-in-call-and-the-trainable-attribute).

 ## Recursive setting of the `trainable` attribute

 If you set `trainable = FALSE` on a model or on any layer that has sublayers,
-all children layers become non-trainable as well.
+all child layers become non-trainable as well.

 **Example:**
 ```{r}
 inner_model <- keras_model_sequential(input_shape = c(3)) %>%
 layer_dense(3, activation = "relu") %>%
-layer_dense(3, activation = "relu")
+layer_dense(3, activation = "relu")

-model <- keras_model_sequential(input_shape = c(3)) %>%
-inner_model() %>%
+model <- keras_model_sequential(input_shape = c(3)) %>%
+inner_model() %>%
 layer_dense(3, activation = "sigmoid")

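The hunk above ends before the rest of the example is visible. As a minimal sketch (not part of the commit), this is how the recursive behaviour described here can be checked with the keras R API: freezing the outer model propagates `trainable = FALSE` down to `inner_model` and its layers.

```r
# Assumes `model` and `inner_model` from the example in the hunk above.
model$trainable <- FALSE                               # freeze the outer model
stopifnot(isFALSE(inner_model$trainable))              # propagated to the sub-model
stopifnot(isFALSE(inner_model$layers[[1]]$trainable))  # ... and to its individual layers
stopifnot(length(model$trainable_weights) == 0)        # nothing left to train
```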
@@ -184,7 +184,7 @@ Note that an alternative, more lightweight workflow could also be:
 3. Use that output as input data for a new, smaller model.

 A key advantage of that second workflow is that you only run the base model once on
-your data, rather than once per epoch of training. So it's a lot faster & cheaper.
+your data, rather than once per epoch of training. So it's a lot faster and cheaper.

 An issue with that second workflow, though, is that it doesn't allow you to dynamically
 modify the input data of your new model during training, which is required when doing
@@ -217,29 +217,29 @@ Create a new model on top.
 ```{r}
 inputs <- layer_input(c(150, 150, 3))

-outputs <- inputs %>%
+outputs <- inputs %>%
 # We make sure that the base_model is running in inference mode here,
 # by passing `training=FALSE`. This is important for fine-tuning, as you will
 # learn in a few paragraphs.
 base_model(training=FALSE) %>%
-
+
 # Convert features of shape `base_model$output_shape[-1]` to vectors
-layer_global_average_pooling_2d() %>%
-
+layer_global_average_pooling_2d() %>%
+
 # A Dense classifier with a single unit (binary classification)
 layer_dense(1)
-
+
 model <- keras_model(inputs, outputs)
 ```


 Train the model on new data.

 ```{r, eval = FALSE}
-model %>%
+model %>%
 compile(optimizer = optimizer_adam(),
 loss = loss_binary_crossentropy(from_logits = TRUE),
-metrics = metric_binary_accuracy()) %>%
+metrics = metric_binary_accuracy()) %>%
 fit(new_dataset, epochs = 20, callbacks = ..., validation_data = ...)
 ```
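The chunk above assumes a `base_model` created and frozen earlier in the vignette, outside the lines shown in this hunk. A hedged sketch of what that setup step could look like with the keras R API (the vignette's actual code may differ):

```r
library(keras)

# Hypothetical setup for the chunk above: a pretrained base without its
# ImageNet classification head, with all of its weights frozen.
base_model <- application_xception(
  weights = "imagenet",          # start from ImageNet-pretrained weights
  input_shape = c(150, 150, 3),
  include_top = FALSE            # leave off the ImageNet classifier
)
base_model$trainable <- FALSE    # freeze: exclude the base weights from training
```

Note that `compile()` captures the trainable state of the model's weights, so if you flip `trainable` later (as in the fine-tuning step below) you need to call `compile()` again for the change to take effect.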

@@ -276,7 +276,7 @@ model %>% compile(
 optimizer = optimizer_adam(1e-5), # Very low learning rate
 loss = loss_binary_crossentropy(from_logits = TRUE),
 metrics = metric_binary_accuracy()
-)
+)

 # Train end-to-end. Be careful to stop before you overfit!
 model %>% fit(new_dataset, epochs=10, callbacks=..., validation_data=...)
@@ -301,21 +301,21 @@ Many image models contain `BatchNormalization` layers. That layer is a special c
 - `BatchNormalization` contains 2 non-trainable weights that get updated during
 training. These are the variables tracking the mean and variance of the inputs.
 - When you set `bn_layer$trainable = FALSE`, the `BatchNormalization` layer will
-run in inference mode, and will not update its mean & variance statistics. This is not
+run in inference mode, and will not update its mean and variance statistics. This is not
 the case for other layers in general, as
-[weight trainability & inference/training modes are two orthogonal concepts](
+[weight trainability and inference/training modes are two orthogonal concepts](
 https://keras.io/getting_started/faq/#whats-the-difference-between-the-training-argument-in-call-and-the-trainable-attribute).
 But the two are tied in the case of the `BatchNormalization` layer.
 - When you unfreeze a model that contains `BatchNormalization` layers in order to do
 fine-tuning, you should keep the `BatchNormalization` layers in inference mode by
-passing `training=TRUE` when calling the base model.
+passing `training = FALSE` when calling the base model.
 Otherwise the updates applied to the non-trainable weights will suddenly destroy
 what the model has learned.

 You'll see this pattern in action in the end-to-end example at the end of this guide.


-## Transfer learning & fine-tuning with a custom training loop
+## Transfer learning and fine-tuning with a custom training loop

 If instead of `fit()`, you are using your own low-level training loop, the workflow
 stays essentially the same. You should be careful to only take into account the list
@@ -363,7 +363,7 @@ while(!is.null(batch <- iter_next(new_dataset))) {
 gradients <- tape$gradient(loss_value, model$trainable_weights)
 # Update the weights of the model.
 optimizer$apply_gradients(xyz(gradients, model$trainable_weights))
-}
+}
 ```
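The loop in the hunk above pairs gradients with trainable weights through a call to `xyz()`, which reads like a placeholder. Below is a minimal sketch of one gradient step, assuming `zip_lists()` from the keras package as the pairing helper and an already-built `model`, `loss_fn`, and `optimizer`; it is an illustration, not the vignette's own code.

```r
library(keras)
library(tensorflow)

# One gradient step over a single batch: only `model$trainable_weights`
# (the unfrozen top layers) receive updates; frozen base weights are untouched.
train_step <- function(x, y) {
  with(tf$GradientTape() %as% tape, {
    logits <- model(x, training = TRUE)   # forward pass
    loss_value <- loss_fn(y, logits)
  })
  gradients <- tape$gradient(loss_value, model$trainable_weights)
  optimizer$apply_gradients(zip_lists(gradients, model$trainable_weights))
  loss_value
}
```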


@@ -372,7 +372,7 @@ Likewise for fine-tuning.
 ## An end-to-end example: fine-tuning an image classification model on a cats vs. dogs dataset

 To solidify these concepts, let's walk you through a concrete end-to-end transfer
-learning & fine-tuning example. We will load the Xception model, pre-trained on
+learning and fine-tuning example. We will load the Xception model, pre-trained on
 ImageNet, and use it on the Kaggle "cats vs. dogs" classification dataset.

 ### Getting the data
@@ -400,7 +400,6 @@ c(train_ds, validation_ds, test_ds) %<-% tfds$load(
 printf("Number of training samples: %d", length(train_ds))
 printf("Number of validation samples: %d", length(validation_ds) )
 printf("Number of test samples: %d", length(test_ds))
-
 ```

 These are the first 9 images in the training dataset -- as you can see, they're all
@@ -415,7 +414,7 @@ train_ds %>%
 iterate(function(batch) {
 c(image, label) %<-% batch
 plot(as.raster(image, max = 255))
-title(sprintf("label: %s size: %s",
+title(sprintf("label: %s size: %s",
 label, paste(dim(image), collapse = " x ")))
 })
 ```
@@ -453,12 +452,12 @@ validation_ds %<>% dataset_map(function(x, y) list(tf$image$resize(x, size), y))
 test_ds %<>% dataset_map(function(x, y) list(tf$image$resize(x, size), y))
 ```

-Besides, let's batch the data and use caching & prefetching to optimize loading speed.
+Besides, let's batch the data and use caching and prefetching to optimize loading speed.
 ```{r}
 dataset_cache_batch_prefetch <- function(dataset, batch_size = 32, buffer_size = 10) {
-dataset %>%
-dataset_cache() %>%
-dataset_batch(batch_size) %>%
+dataset %>%
+dataset_cache() %>%
+dataset_batch(batch_size) %>%
 dataset_prefetch(buffer_size)
 }

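The lines that apply this helper to the three splits fall outside the hunk shown; purely as an illustration, the usage would look something like this (reusing the `%<>%` idiom from the chunks above):

```r
# Illustrative only: cache, batch, and prefetch each split with the helper
# defined in the hunk above.
train_ds      %<>% dataset_cache_batch_prefetch()
validation_ds %<>% dataset_cache_batch_prefetch()
test_ds       %<>% dataset_cache_batch_prefetch()
```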
@@ -474,19 +473,19 @@ When you don't have a large image dataset, it's a good practice to artificially
 the training images, such as random horizontal flipping or small random rotations. This
 helps expose the model to different aspects of the training data while slowing down
 overfitting.
-
+
 ```{r}
-data_augmentation <- keras_model_sequential() %>%
-layer_random_flip("horizontal") %>%
+data_augmentation <- keras_model_sequential() %>%
+layer_random_flip("horizontal") %>%
 layer_random_rotation(.1)
 ```

 Let's visualize what the first image of the first batch looks like after various random
 transformations:

 ```{r}
-batch <- train_ds %>%
-dataset_take(1) %>%
+batch <- train_ds %>%
+dataset_take(1) %>%
 as_iterator() %>% iter_next()

 c(images, labels) %<-% batch
@@ -499,8 +498,8 @@ plot_image <- function(image, main = deparse1(substitute(image))) {
 as.array() %>% # convert from tensor to R array
 as.raster(max = 255) %>%
 plot()
-
-if(!is.null(main))
+
+if(!is.null(main))
 title(main)
 }

@@ -509,13 +508,6 @@ plot_image(first_image)
 plot_image(augmented_image)
 plot_image(data_augmentation(first_image, training = TRUE), "augmented 2")
 plot_image(data_augmentation(first_image, training = TRUE), "augmented 3")
-#
-# augmented_image %>%
-# k_squeeze(1) %>% # drop batch dim
-# as.array() %>% as.raster(max = 255) %>%
-# plot()
-# title(as.array(labels[1]))
-
 ```

@@ -549,12 +541,12 @@ inputs = layer_input(shape = c(150, 150, 3))

 outputs <- inputs %>%
 data_augmentation() %>% # Apply random data augmentation
-
+
 # Pre-trained Xception weights require that input be scaled
 # from (0, 255) to a range of (-1., +1.), the rescaling layer
 # outputs: `(inputs * scale) + offset`
 layer_rescaling(scale = 1 / 127.5, offset = -1) %>%
-
+
 # The base model contains batchnorm layers. We want to keep them in inference mode
 # when we unfreeze the base model for fine-tuning, so we make sure that the
 # base_model is running in inference mode here.
