11---
2- title : " Transfer learning & fine-tuning"
2+ title : " Transfer learning and fine-tuning"
33author : " [fchollet](https://twitter.com/fchollet), [t-kalinowski](https://github.com/t-kalinowski)"
44date : 2021/10/15
5- description : " Complete guide to transfer learning & fine-tuning in Keras."
5+ description : " Complete guide to transfer learning and fine-tuning in Keras."
66output : rmarkdown::html_vignette
77vignette : >
8- %\VignetteIndexEntry{Transfer learning & fine-tuning}
8+ %\VignetteIndexEntry{Transfer learning and fine-tuning}
99 %\VignetteEngine{knitr::rmarkdown}
1010 %\VignetteEncoding{UTF-8}
1111---
@@ -44,7 +44,7 @@ very low learning rate. This can potentially achieve meaningful improvements, by
4444 incrementally adapting the pretrained features to the new data.
4545
4646First, we will go over the Keras ` trainable ` API in detail, which underlies most
47- transfer learning & fine-tuning workflows.
47+ transfer learning and fine-tuning workflows.
4848
4949Then, we'll demonstrate the typical workflow by taking a model pretrained on the
5050ImageNet dataset, and retraining it on the Kaggle "cats vs dogs" classification
@@ -95,7 +95,7 @@ printf("trainable_weights: %s", length(layer$trainable_weights))
9595printf("non_trainable_weights: %s", length(layer$non_trainable_weights))
9696```
9797
98- Layers & models also feature a boolean attribute ` trainable ` . Its value can be changed.
98+ Layers and models also feature a boolean attribute ` trainable ` . Its value can be changed.
9999Setting ` layer$trainable ` to ` FALSE ` moves all the layer's weights from trainable to
100100non-trainable. This is called "freezing" the layer: the state of a frozen layer won't
101101be updated during training (either when training with ` fit() ` or when training with
@@ -138,25 +138,25 @@ final_layer1_weights_values <- get_weights(layer1)
138138stopifnot(all.equal(initial_layer1_weights_values, final_layer1_weights_values))
139139```
140140
141- Do not confuse the ` layer$trainable ` attribute with the ` training ` argument in a layer instance's ` call ` signature
142- ` layer(training =) ` (which controls whether the layer should run its forward pass in
143- inference mode or training mode). For more information, see the
144- [ Keras FAQ] (
141+ Do not confuse the ` layer$trainable ` attribute with the ` training ` argument in a
142+ layer instance's ` call ` signature ` layer(training =) ` (which controls whether
143+ the layer should run its forward pass in inference mode or training mode).
144+ For more information, see the [ Keras FAQ] (
145145 https://keras.io/getting_started/faq/#whats-the-difference-between-the-training-argument-in-call-and-the-trainable-attribute ).
146146
147147## Recursive setting of the ` trainable ` attribute
148148
149149If you set ` trainable = FALSE ` on a model or on any layer that has sublayers,
150- all children layers become non-trainable as well.
150+ all child layers become non-trainable as well.
151151
152152** Example:**
153153``` {r}
154154inner_model <- keras_model_sequential(input_shape = c(3)) %>%
155155 layer_dense(3, activation = "relu") %>%
156- layer_dense(3, activation = "relu")
156+ layer_dense(3, activation = "relu")
157157
158- model <- keras_model_sequential(input_shape = c(3)) %>%
159- inner_model() %>%
158+ model <- keras_model_sequential(input_shape = c(3)) %>%
159+ inner_model() %>%
160160 layer_dense(3, activation = "sigmoid")
161161
162162
@@ -184,7 +184,7 @@ Note that an alternative, more lightweight workflow could also be:
1841843 . Use that output as input data for a new, smaller model.
185185
186186A key advantage of that second workflow is that you only run the base model once on
187- your data, rather than once per epoch of training. So it's a lot faster & cheaper.
187+ your data, rather than once per epoch of training. So it's a lot faster and cheaper.
188188
189189An issue with that second workflow, though, is that it doesn't allow you to dynamically
190190modify the input data of your new model during training, which is required when doing
@@ -217,29 +217,29 @@ Create a new model on top.
217217``` {r}
218218inputs <- layer_input(c(150, 150, 3))
219219
220- outputs <- inputs %>%
220+ outputs <- inputs %>%
221221 # We make sure that the base_model is running in inference mode here,
222222 # by passing `training=FALSE`. This is important for fine-tuning, as you will
223223 # learn in a few paragraphs.
224224 base_model(training=FALSE) %>%
225-
225+
226226 # Convert features of shape `base_model$output_shape[-1]` to vectors
227- layer_global_average_pooling_2d() %>%
228-
227+ layer_global_average_pooling_2d() %>%
228+
229229 # A Dense classifier with a single unit (binary classification)
230230 layer_dense(1)
231-
231+
232232model <- keras_model(inputs, outputs)
233233```
234234
235235
236236Train the model on new data.
237237
238238``` {r, eval = FALSE}
239- model %>%
239+ model %>%
240240 compile(optimizer = optimizer_adam(),
241241 loss = loss_binary_crossentropy(from_logits = TRUE),
242- metrics = metric_binary_accuracy()) %>%
242+ metrics = metric_binary_accuracy()) %>%
243243 fit(new_dataset, epochs = 20, callbacks = ..., validation_data = ...)
244244```
245245
@@ -276,7 +276,7 @@ model %>% compile(
276276 optimizer = optimizer_adam(1e-5), # Very low learning rate
277277 loss = loss_binary_crossentropy(from_logits = TRUE),
278278 metrics = metric_binary_accuracy()
279- )
279+ )
280280
281281# Train end-to-end. Be careful to stop before you overfit!
282282model %>% fit(new_dataset, epochs=10, callbacks=..., validation_data=...)
@@ -301,21 +301,21 @@ Many image models contain `BatchNormalization` layers. That layer is a special c
301301- ` BatchNormalization ` contains 2 non-trainable weights that get updated during
302302training. These are the variables tracking the mean and variance of the inputs.
303303- When you set ` bn_layer$trainable = FALSE ` , the ` BatchNormalization ` layer will
304- run in inference mode, and will not update its mean & variance statistics. This is not
304+ run in inference mode, and will not update its mean and variance statistics. This is not
305305the case for other layers in general, as
306- [ weight trainability & inference/training modes are two orthogonal concepts] (
306+ [ weight trainability and inference/training modes are two orthogonal concepts] (
307307 https://keras.io/getting_started/faq/#whats-the-difference-between-the-training-argument-in-call-and-the-trainable-attribute ).
308308But the two are tied in the case of the ` BatchNormalization ` layer.
309309- When you unfreeze a model that contains ` BatchNormalization ` layers in order to do
310310fine-tuning, you should keep the ` BatchNormalization ` layers in inference mode by
311- passing ` training=TRUE ` when calling the base model.
311+ passing ` training = FALSE ` when calling the base model.
312312Otherwise the updates applied to the non-trainable weights will suddenly destroy
313313what the model has learned.
314314
315315You'll see this pattern in action in the end-to-end example at the end of this guide.
316316
317317
318- ## Transfer learning & fine-tuning with a custom training loop
318+ ## Transfer learning and fine-tuning with a custom training loop
319319
320320If instead of ` fit() ` , you are using your own low-level training loop, the workflow
321321stays essentially the same. You should be careful to only take into account the list
@@ -363,7 +363,7 @@ while(!is.null(batch <- iter_next(new_dataset))) {
363363 gradients <- tape$gradient(loss_value, model$trainable_weights)
364364 # Update the weights of the model.
365365 optimizer$apply_gradients(xyz(gradients, model$trainable_weights))
366- }
366+ }
367367```
368368
369369
@@ -372,7 +372,7 @@ Likewise for fine-tuning.
372372## An end-to-end example: fine-tuning an image classification model on a cats vs. dogs dataset
373373
374374To solidify these concepts, let's walk you through a concrete end-to-end transfer
375- learning & fine-tuning example. We will load the Xception model, pre-trained on
375+ learning and fine-tuning example. We will load the Xception model, pre-trained on
376376 ImageNet, and use it on the Kaggle "cats vs. dogs" classification dataset.
377377
378378### Getting the data
@@ -400,7 +400,6 @@ c(train_ds, validation_ds, test_ds) %<-% tfds$load(
400400printf("Number of training samples: %d", length(train_ds))
401401printf("Number of validation samples: %d", length(validation_ds) )
402402printf("Number of test samples: %d", length(test_ds))
403-
404403```
405404
406405These are the first 9 images in the training dataset -- as you can see, they're all
@@ -415,7 +414,7 @@ train_ds %>%
415414 iterate(function(batch) {
416415 c(image, label) %<-% batch
417416 plot(as.raster(image, max = 255))
418- title(sprintf("label: %s size: %s",
417+ title(sprintf("label: %s size: %s",
419418 label, paste(dim(image), collapse = " x ")))
420419 })
421420```
@@ -453,12 +452,12 @@ validation_ds %<>% dataset_map(function(x, y) list(tf$image$resize(x, size), y))
453452test_ds %<>% dataset_map(function(x, y) list(tf$image$resize(x, size), y))
454453```
455454
456- Besides, let's batch the data and use caching & prefetching to optimize loading speed.
455+ Besides, let's batch the data and use caching and prefetching to optimize loading speed.
457456``` {r}
458457dataset_cache_batch_prefetch <- function(dataset, batch_size = 32, buffer_size = 10) {
459- dataset %>%
460- dataset_cache() %>%
461- dataset_batch(batch_size) %>%
458+ dataset %>%
459+ dataset_cache() %>%
460+ dataset_batch(batch_size) %>%
462461 dataset_prefetch(buffer_size)
463462}
464463
@@ -474,19 +473,19 @@ When you don't have a large image dataset, it's a good practice to artificially
474473the training images, such as random horizontal flipping or small random rotations. This
475474helps expose the model to different aspects of the training data while slowing down
476475 overfitting.
477-
476+
478477``` {r}
479- data_augmentation <- keras_model_sequential() %>%
480- layer_random_flip("horizontal") %>%
478+ data_augmentation <- keras_model_sequential() %>%
479+ layer_random_flip("horizontal") %>%
481480 layer_random_rotation(.1)
482481```
483482
484483Let's visualize what the first image of the first batch looks like after various random
485484 transformations:
486485
487486``` {r}
488- batch <- train_ds %>%
489- dataset_take(1) %>%
487+ batch <- train_ds %>%
488+ dataset_take(1) %>%
490489 as_iterator() %>% iter_next()
491490
492491c(images, labels) %<-% batch
@@ -499,8 +498,8 @@ plot_image <- function(image, main = deparse1(substitute(image))) {
499498 as.array() %>% # convert from tensor to R array
500499 as.raster(max = 255) %>%
501500 plot()
502-
503- if(!is.null(main))
501+
502+ if(!is.null(main))
504503 title(main)
505504}
506505
@@ -509,13 +508,6 @@ plot_image(first_image)
509508plot_image(augmented_image)
510509plot_image(data_augmentation(first_image, training = TRUE), "augmented 2")
511510plot_image(data_augmentation(first_image, training = TRUE), "augmented 3")
512- #
513- # augmented_image %>%
514- # k_squeeze(1) %>% # drop batch dim
515- # as.array() %>% as.raster(max = 255) %>%
516- # plot()
517- # title(as.array(labels[1]))
518-
519511```
520512
521513
@@ -549,12 +541,12 @@ inputs = layer_input(shape = c(150, 150, 3))
549541
550542outputs <- inputs %>%
551543 data_augmentation() %>% # Apply random data augmentation
552-
544+
553545 # Pre-trained Xception weights requires that input be scaled
554546 # from (0, 255) to a range of (-1., +1.), the rescaling layer
555547 # outputs: `(inputs * scale) + offset`
556548 layer_rescaling(scale = 1 / 127.5, offset = -1) %>%
557-
549+
558550 # The base model contains batchnorm layers. We want to keep them in inference mode
559551 # when we unfreeze the base model for fine-tuning, so we make sure that the
560552 # base_model is running in inference mode here.
0 commit comments