This page is located in the archive. Go to the latest version of these course pages.

Lab 4: Finetuning

Fine-tuning a pretrained CNN for a new task.

Skills: creating a dataset from an image folder, data preprocessing, loading pretrained models, working remotely with a GPU server, training part of a model, hyper-parameter search.


In this lab we start from a model already pretrained on the ImageNet classification dataset (1000 categories and 1.2 million images) and adapt it to a small-scale but otherwise challenging classification problem.

  • This allows us to work with a large-scale model at moderate computational cost, since our fine-tuning dataset is small.
  • We will see that a pretrained network has already learned powerful visual features, which will greatly simplify our task.
  • We will consider several fine-tuning variants, adjusting either part of the network or all of its layers.



Fortunately, many excellent pretrained architectures are available in PyTorch. You can use one of the following models:

  1. VGG11 https://pytorch.org/hub/pytorch_vision_vgg/, the model considered in the CNN lecture.
  2. Squeezenet1_0 https://pytorch.org/hub/pytorch_vision_squeezenet/, which has far fewer parameters and uses ‘fire’ modules similar to the example on CNN lecture slide 18. It is about 4 times faster to train but achieves somewhat lower accuracy on ImageNet.

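import torchvision.models
model = torchvision.models.vgg11(pretrained=True)
# pretrained=True is deprecated since torchvision 0.13; the equivalent call there is:
# model = torchvision.models.vgg11(weights=torchvision.models.VGG11_Weights.IMAGENET1K_V1)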

You might get the 'CERTIFICATE_VERIFY_FAILED' error, meaning that the model cannot be downloaded over a secure (HTTPS) connection. In this case, use the following insecure workaround:

import torchvision.models
# model_urls only exists in older torchvision releases
from torchvision.models.vgg import model_urls
# from torchvision.models.squeezenet import model_urls
# rewrite every download URL from https:// to http://
for k in model_urls.keys():
    model_urls[k] = model_urls[k].replace('https://', 'http://')
model = torchvision.models.vgg11(pretrained=True)
# model = torchvision.models.squeezenet1_0(pretrained=True)

You can see the structure of the loaded model by calling print(model). You can also open the source code defining the network architecture https://github.com/pytorch/vision/blob/master/torchvision/models/vgg.py. The network is usually defined as a hierarchy of Modules, where each Module is either an elementary layer (e.g. Conv2d, Linear, ReLU) or a container (e.g. Sequential).


Download the dataset we prepared:

  1. Butterflies (35Mb)

The dataset contains 224×224 color images from 10 categories.

Part 1 (2p)

Here we will practice data loading and preprocessing techniques.

  • Create a dataset and a loader for the training images. We can use the existing dataset interface ImageFolder, which loads images from disk:
    import torch
    from torchvision import datasets, transforms
    train_data = datasets.ImageFolder('../data/butterflies/train', transforms.ToTensor())
    train_loader = torch.utils.data.DataLoader(train_data, batch_size=4, shuffle=True, num_workers=0)
  • Perform standardization of the data: on the training set, compute the mean and standard deviation per color channel over all pixels of all training images. Think about how to do this incrementally with mini-batches, without loading the whole dataset into memory at once.

We will then hard-code these constant values in the preprocessing, so that they do not have to be recomputed every time.
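A minimal sketch of the incremental computation, assuming the loader yields (images, labels) batches of shape (B, C, H, W) as produced by ToTensor. It accumulates per-channel sums of x and x² in float64 and returns the population standard deviation sqrt(E[x²] − E[x]²), so only one mini-batch is in memory at a time. The function name channel_mean_std is hypothetical, and the random tensors stand in for the butterfly images.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


def channel_mean_std(loader):
    """Per-channel mean and std over all pixels of all images, one mini-batch at a time."""
    total = total_sq = None  # running per-channel sums of x and x^2
    count = 0                # number of pixels per channel seen so far
    for images, _ in loader:
        x = images.double()  # (B, C, H, W); float64 limits rounding error in the sums
        if total is None:
            total = torch.zeros(x.shape[1], dtype=torch.float64)
            total_sq = torch.zeros(x.shape[1], dtype=torch.float64)
        total += x.sum(dim=(0, 2, 3))
        total_sq += (x ** 2).sum(dim=(0, 2, 3))
        count += x.shape[0] * x.shape[2] * x.shape[3]
    mean = total / count
    # population std: sqrt(E[x^2] - E[x]^2); clamp guards against tiny negative rounding
    std = (total_sq / count - mean ** 2).clamp_min(0).sqrt()
    return mean, std


# Usage on random stand-in data; with the real data call channel_mean_std(train_loader)
fake = TensorDataset(torch.rand(20, 3, 8, 8), torch.zeros(20, dtype=torch.long))
mean, std = channel_mean_std(DataLoader(fake, batch_size=4))
print(mean.tolist(), std.tolist())
```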

  • Add transforms.Normalize with the statistics you found to your dataset constructor. This standardizes each input channel for better-conditioned training and also matches what the pretrained model expects (it was trained on normalized ImageNet images). Apply the same transform to the test dataset as well.
  • From the train dataset, create two loaders: one used for optimizing the model parameters (train_loader) and one used for validation, i.e. for selecting hyperparameters (val_loader). This is similar to Lab 2, using SubsetRandomSampler.

Part 2 (4p)

We will first try training only the last layer of the network on the new data, i.e. we use the network as a feature extractor and learn a linear classifier on top of it, as if it were a logistic regression model on some features. We need to do the following:

  1. Load the vgg11 model
  2. Freeze all parameters of the model, so that they will not be trained, by

for param in model.parameters():
    param.requires_grad = False

  3. In your model architecture, identify and delete the “classifier” part that maps “features” to the scores of the 1000 ImageNet classes.
  4. Add a new “classifier” module that consists of one or more linear layers with randomly initialized weights and outputs scores for 10 classes (our dataset). Linear layers constructed anew have randomly initialized parameters with requires_grad = True by default, i.e. they will be trainable. Consider using torch.nn.BatchNorm1d (after linear layers) or torch.nn.Dropout (after activations) inside your classifier block.
  5. Train the network and choose the best hyperparameters by cross-validation. Find a suitable learning rate as follows. First, determine the rough order of magnitude of the learning rate by trying the learning rates $0.1, 0.01, 0.001, 0.0001$ and comparing the training loss after 5 epochs. Knowing the rough value, select a grid of 5 learning rate values around it with which to perform full cross-validation. Evaluate the validation accuracy after each epoch (as in Lab 2) and keep track of the parameter vector that achieves the best validation accuracy (saving the best so far). This way, we automatically select the epoch at which it was best to stop. Choose the learning rate that achieves the best validation accuracy. Apply regularization if needed (e.g. dropout or weight decay).
  6. Report the full learning setup that you used: base network, classifier architecture, optimizer, learning rate and other hyper-parameters. Report plots of the training and validation metrics (loss, accuracy) versus epochs for the selected hyper-parameters. Report the final test classification accuracy.
  7. The network may make errors on the test data (we expect a few). For each of these cases, display and report: 1) the input test image, 2) its correct class label, 3) the class labels and network confidences (predictive probabilities) of the top 3 network predictions (the classes with the highest predictive probability).
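The procedure above (freezing, a new head, training with best-epoch selection, top-3 predictions) could be sketched as follows. This assumes SGD with momentum 0.9, a two-layer head with 256 hidden units, and validation accuracy as the selection criterion; make_feature_extractor, fit, accuracy and top3 are hypothetical names. The code works with any model exposing a classifier attribute that receives flattened features, such as torchvision's VGG11. In training mode BatchNorm1d needs more than one sample per batch.

```python
import copy

import torch
import torch.nn as nn


def make_feature_extractor(model, in_features, num_classes=10, hidden=256, p_drop=0.5):
    """Freeze every pretrained parameter and replace model.classifier with a new trainable head."""
    for param in model.parameters():
        param.requires_grad = False  # pretrained weights stay fixed
    # newly constructed layers are randomly initialized and trainable by default
    model.classifier = nn.Sequential(
        nn.Linear(in_features, hidden),
        nn.BatchNorm1d(hidden),
        nn.ReLU(),
        nn.Dropout(p_drop),
        nn.Linear(hidden, num_classes),
    )
    return model


def accuracy(model, loader, device):
    model.eval()  # BatchNorm uses running statistics, Dropout is off
    correct = total = 0
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            correct += (model(x).argmax(dim=1) == y).sum().item()
            total += y.numel()
    return correct / total


def fit(model, train_loader, val_loader, lr, epochs, device="cpu", weight_decay=0.0):
    """Train the trainable parameters; restore and report the epoch with the best val accuracy."""
    model.to(device)
    params = [p for p in model.parameters() if p.requires_grad]  # only the new head here
    optimizer = torch.optim.SGD(params, lr=lr, momentum=0.9, weight_decay=weight_decay)
    loss_fn = nn.CrossEntropyLoss()
    best_acc, best_state, history = -1.0, None, []
    for _ in range(epochs):
        model.train()
        total_loss, n = 0.0, 0
        for x, y in train_loader:
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad()
            loss = loss_fn(model(x), y)
            loss.backward()
            optimizer.step()
            total_loss += loss.item() * y.numel()
            n += y.numel()
        val_acc = accuracy(model, val_loader, device)
        history.append((total_loss / n, val_acc))  # (mean train loss, val accuracy) per epoch
        if val_acc > best_acc:  # save the best parameters so far
            best_acc, best_state = val_acc, copy.deepcopy(model.state_dict())
    model.load_state_dict(best_state)
    return best_acc, history


def top3(model, images):
    """Top-3 predictive probabilities and class indices, each of shape (B, 3)."""
    model.eval()
    with torch.no_grad():
        probs = torch.softmax(model(images), dim=1)
    return probs.topk(3, dim=1)
```

For VGG11 one would call make_feature_extractor(model, in_features=512 * 7 * 7) on the pretrained model and run fit once per learning rate in the grid, keeping the run with the highest returned validation accuracy; top3 applied to misclassified test images gives the values to report.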

Part 3 (2p)

Depending on the size of the dataset and how much it differs from ImageNet, the following option may give better results than training the last layer only.

  1. Finetune the whole model. For this, load the pretrained model, do not freeze any parameters, and replace the output layer with a new one (of the appropriate size and randomly initialized). A smaller learning rate is recommended for fine-tuning.
  2. Report the chosen parameters, the training loss, and the training and test accuracies achieved. Which of the fine-tuning approaches obtained the best result?
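A minimal sketch of the whole-model setup, assuming (as in torchvision's VGG11) that model.classifier is an nn.Sequential ending in an nn.Linear; SqueezeNet ends its classifier with a 1×1 convolution and pooling, so this exact code would not apply there. The optional make_optimizer helper (a hypothetical name) shows one common variant of "a smaller learning rate": a smaller rate for the pretrained layers than for the new output layer. A single small rate for all parameters is an equally valid choice.

```python
import torch
import torch.nn as nn


def make_full_finetune(model, num_classes=10):
    """Keep all pretrained layers trainable; swap only the final Linear of model.classifier."""
    for param in model.parameters():
        param.requires_grad = True
    last = model.classifier[-1]  # assumed to be an nn.Linear (true for torchvision VGG)
    model.classifier[-1] = nn.Linear(last.in_features, num_classes)  # random init
    return model


def make_optimizer(model, lr_backbone, lr_head, momentum=0.9):
    """SGD with one parameter group for the pretrained layers and one for the new output layer."""
    head = list(model.classifier[-1].parameters())
    head_ids = {id(p) for p in head}
    backbone = [p for p in model.parameters() if id(p) not in head_ids]
    return torch.optim.SGD(
        [{"params": backbone, "lr": lr_backbone}, {"params": head, "lr": lr_head}],
        momentum=momentum,
    )
```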

Part 4 (2p)

For the whole-model finetuning (as described in Part 3), try the following regularization method, which aims at stochastic smoothing of the loss function:

\begin{align*} &\hat \theta = \theta^t + \varepsilon, \ \ \ \ \varepsilon\sim \mathcal{N}(0,\sigma^2)\\ &\theta^{t+1} = \theta^t - \alpha \frac{d \mathcal{L}(\theta)}{d \theta} \Big|_{\theta = \hat \theta} \end{align*}

In practice, it can be implemented as follows. For each training step:

  1. clone the whole model and perturb all trainable parameters with Gaussian noise of standard deviation $\sigma$;
  2. compute the gradient $g$ in the perturbed model;
  3. copy the gradient from the perturbed model to the unperturbed model;
  4. make the optimization step for the unperturbed model.
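The four steps above could look like this for a single mini-batch, assuming a standard PyTorch optimizer built over model.parameters() and a loss such as nn.CrossEntropyLoss; noisy_step is a hypothetical name. copy.deepcopy duplicates the whole network at every step, which is simple but costly for VGG11; a cheaper variant saves the noise, adds it in place, backpropagates, and removes it again before the optimizer step (equal up to floating-point rounding). Buffers such as BatchNorm running statistics are only updated in the discarded copy.

```python
import copy

import torch
import torch.nn as nn


def noisy_step(model, optimizer, loss_fn, x, y, sigma):
    """One update theta <- theta - alpha * dL/dtheta evaluated at theta + eps, eps ~ N(0, sigma^2)."""
    perturbed = copy.deepcopy(model)  # 1. clone the whole model
    with torch.no_grad():
        for p in perturbed.parameters():
            if p.requires_grad:
                p.add_(torch.randn_like(p) * sigma)  # theta_hat = theta + eps
    perturbed.zero_grad()
    loss = loss_fn(perturbed(x), y)  # 2. gradient at theta_hat
    loss.backward()
    optimizer.zero_grad()
    for p, p_hat in zip(model.parameters(), perturbed.parameters()):  # 3. copy gradients back
        if p.requires_grad:
            p.grad = None if p_hat.grad is None else p_hat.grad.clone()
    optimizer.step()  # 4. step the unperturbed parameters
    return loss.item()
```

With sigma = 0 the update reduces to an ordinary gradient step, which is a convenient sanity check before running the cross-validation over sigma.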

The parameter $\sigma$ needs to be chosen by cross-validation. Fix the learning rate to the one found in Part 3.

Important Technicalities

  • The training and test modes of batch normalization (and dropout) layers differ, so it is important to control the state: call model.train() during training and model.eval() during validation or testing (model.eval() is equivalent to model.train(False)).
  • If you train on the GPU server, make sure the computation actually runs on the GPU and not on the CPU by moving both the model and the input data there.
  • Please post on the forum if you find other important guidelines.
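A small sketch of the first two points, using a hypothetical toy network in place of the fine-tuned model: model.train() before each training step, model.eval() before validation or testing, and the model and every input batch placed on the same device. The same calls apply unchanged to the VGG and SqueezeNet models.

```python
import torch
import torch.nn as nn

# Use the GPU when one is visible, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Toy stand-in for the fine-tuned network (hypothetical layer sizes)
model = nn.Sequential(
    nn.Linear(8, 16), nn.BatchNorm1d(16), nn.ReLU(), nn.Dropout(0.5), nn.Linear(16, 10)
).to(device)  # moves all parameters and buffers to `device`


def train_step(model, optimizer, x, y):
    model.train()  # BatchNorm uses batch statistics, Dropout is active
    x, y = x.to(device), y.to(device)  # inputs must live on the model's device
    optimizer.zero_grad()
    loss = nn.functional.cross_entropy(model(x), y)
    loss.backward()
    optimizer.step()
    return loss.item()


def predict(model, x):
    model.eval()  # BatchNorm uses running statistics, Dropout is off
    with torch.no_grad():  # no autograd graph needed for validation/testing
        return model(x.to(device)).cpu()
```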
courses/bev033dle/labs/lab3_finetune/start.txt · Last modified: 2022/04/28 16:05 by shekhole