The goal of this lab is to learn useful image representations using Self-Supervised Learning (SSL), specifically the SimCLR [1] method, without relying on ground-truth labels during the training phase. We will train a deep CNN from scratch to extract general visual features using a contrastive loss. The learned representation's quality will then be evaluated on a standard downstream task: image classification using a simple non-parametric classifier on the CIFAR-10 dataset.
We provide the overall structure of the training in the file ssl.py, which can be found here. To complete this assignment, fill in the template with the full self-supervised training pipeline and submit it to the upload system. Do not rename the key functions and classes that are to be implemented.
The core idea behind SimCLR is contrastive learning. For each image in a batch, we generate two different versions (or “views”) of it using strong random data augmentations (e.g., cropping, color changes, rotation). These two views are considered a positive pair because they originate from the same underlying image. All other augmented views in the same batch (originating from different images) are considered negative pairs.
The learning objective is to train a network that pulls the representations of positive pairs closer together in an embedding space, while simultaneously pushing apart the representations of negative pairs. By learning to distinguish between different views of the same image and views of different images, the network implicitly learns meaningful visual features.
The SimCLR framework uses a specific network structure:
Encoder $f(\cdot)$: a deep CNN (e.g., a ResNet) that maps an augmented image $x$ to the representation $h = f(x)$; these features are what we use for downstream tasks.
Projection Head $g(\cdot)$: a small MLP that maps $h$ to the vector $z = g(h)$ on which the contrastive loss is computed; it is discarded after pre-training.
Implement the class SimCLR, which wraps the Encoder and Projection Head and defines the forward pass.
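A minimal sketch of such a wrapper is shown below. The ResNet-18 backbone, the 512-d feature dimension, and the 128-d projection dimension are our assumptions for illustration, not requirements of the template:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models import resnet18


class SimCLR(nn.Module):
    """Encoder + projection head; both are trained with the contrastive loss."""

    def __init__(self, proj_dim: int = 128):
        super().__init__()
        # Encoder: a ResNet-18 trained from scratch; its classification
        # head is replaced by identity so it outputs raw 512-d features.
        self.encoder = resnet18(weights=None)
        feat_dim = self.encoder.fc.in_features  # 512 for ResNet-18
        self.encoder.fc = nn.Identity()
        # Projection head: a 2-layer MLP mapping features into the space
        # where the contrastive loss is computed.
        self.projection_head = nn.Sequential(
            nn.Linear(feat_dim, feat_dim),
            nn.ReLU(inplace=True),
            nn.Linear(feat_dim, proj_dim),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h = self.encoder(x)          # representation used downstream
        z = self.projection_head(h)  # projection used by the loss
        return F.normalize(z, dim=1) # unit norm: dot product = cosine similarity
```

Returning $\ell_2$-normalized projections makes the dot products in the loss below equal to cosine similarities.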
This contrastive loss aims to maximize the agreement between positive pairs $(z_i,z_j)$ compared to negative pairs $(z_i,z_k)$. For a positive pair $(i,j)$ from a batch of $2N$ augmented views, the loss for view $i$ is:
$$ \ell_{i,j} = -\log \frac{\exp(\mathbf{z}_i \cdot \mathbf{z}_j / \tau)}{\sum^{2N}_{k=1} 1_{[k \neq i]} \, \exp(\mathbf{z}_i \cdot \mathbf{z}_k / \tau)} $$
where $\mathbf{z}_i$ and $\mathbf{z}_j$ are the $\ell_2$-normalized projections of the two views of the same image (so their dot product is the cosine similarity), $\tau > 0$ is a temperature hyperparameter, and $1_{[k \neq i]}$ is an indicator function that equals $1$ iff $k \neq i$.
The loss essentially applies the standard softmax cross-entropy loss framework commonly used for classification. In the contrastive setting, the “classification” task is to correctly identify the positive pair (another view of the same image) among all other negative examples within the current batch, using feature cosine similarity as the logit.
Note: SimCLR abbreviates this loss as NT-Xent (normalized temperature-scaled cross-entropy). In practice you may also encounter the InfoNCE loss, which is essentially a generalization of NT-Xent; the two names are often used interchangeably.
Implement this loss function in xent_loss.
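One possible vectorized implementation is sketched below, assuming the projections of the two views are already $\ell_2$-normalized and stacked into a single $(2N, d)$ tensor with rows $i$ and $i+N$ forming a positive pair (this argument layout is our assumption; adapt it to the template's signature):

```python
import torch
import torch.nn.functional as F


def xent_loss(z: torch.Tensor, tau: float = 0.5) -> torch.Tensor:
    """NT-Xent loss over 2N normalized projections; rows i and i+N are positives."""
    n2 = z.shape[0]          # 2N
    n = n2 // 2
    sim = z @ z.T / tau      # (2N, 2N) cosine similarities scaled by temperature
    # Exclude self-similarity: the k != i constraint in the denominator.
    mask = torch.eye(n2, dtype=torch.bool, device=z.device)
    sim = sim.masked_fill(mask, float("-inf"))
    # For row i, the positive sits at row (i + N) mod 2N.
    targets = (torch.arange(n2, device=z.device) + n) % n2
    # Softmax cross-entropy: identify the positive among the 2N - 1 candidates.
    return F.cross_entropy(sim, targets)
```

With projections z1 and z2 of the two views, the symmetric batch loss averaged over all $2N$ views is then xent_loss(torch.cat([z1, z2], dim=0)).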
For intuition, consider the simplified case of a single negative $\mathbf{z}_k$ and unit-norm embeddings; the loss then reduces to a triplet-style objective on squared Euclidean distances:
\begin{align*} \ell_{i,j} &= -\log \frac{\exp \left( \mathbf{z}_i \cdot \mathbf{z}_j / \tau \right)}{\exp \left( \mathbf{z}_i \cdot \mathbf{z}_j / \tau \right) + \exp \left( \mathbf{z}_i \cdot \mathbf{z}_k / \tau \right)} \\ &= \log \left( 1 + \exp \left( \frac{\mathbf{z}_i \cdot \mathbf{z}_k - \mathbf{z}_i \cdot \mathbf{z}_j}{\tau} \right) \right) \\ &\approx \exp \left( \frac{\mathbf{z}_i \cdot \mathbf{z}_k - \mathbf{z}_i \cdot \mathbf{z}_j}{\tau} \right) \quad \text{(first-order Taylor expansion of } \log(1+x)\text{)} \\ &\approx 1 + \frac{1}{\tau} \left( \mathbf{z}_i \cdot \mathbf{z}_k - \mathbf{z}_i \cdot \mathbf{z}_j \right) \quad \text{(first-order Taylor expansion of } \exp\text{)} \\ &= 1 - \frac{1}{2\tau} \left( \|\mathbf{z}_i - \mathbf{z}_k\|^2 - \|\mathbf{z}_i - \mathbf{z}_j\|^2 \right) \quad \text{(using } \|\mathbf{z}\| = 1\text{)} \\ &\propto \|\mathbf{z}_i - \mathbf{z}_j\|^2 - \|\mathbf{z}_i - \mathbf{z}_k\|^2 + 2\tau \end{align*}
Generating informative positive pairs $(x_i,x_j)$ through strong data augmentation $t \sim \mathcal{T}$ is crucial for the success of SimCLR. The augmentations should significantly alter the image appearance while preserving its core semantic content.
Your task is to define a strong augmentation pipeline using torchvision.transforms. Recommended transforms include:
transforms.RandomResizedCrop
transforms.RandomHorizontalFlip
transforms.RandomApply
transforms.ColorJitter
transforms.RandomGrayscale
transforms.GaussianBlur
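One possible pipeline built from the transforms listed above is sketched here; the jitter strengths, application probabilities, and blur kernel size are our assumptions and should be tuned experimentally:

```python
from torchvision import transforms

# Strong stochastic augmentations; each call produces a different view.
train_transform = transforms.Compose([
    transforms.RandomResizedCrop(224, scale=(0.2, 1.0)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomApply(
        [transforms.ColorJitter(brightness=0.4, contrast=0.4,
                                saturation=0.4, hue=0.1)],
        p=0.8,
    ),
    transforms.RandomGrayscale(p=0.2),
    transforms.RandomApply([transforms.GaussianBlur(kernel_size=23)], p=0.5),
    transforms.ToTensor(),
])
```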
Implement a dataset wrapper class MultiViewDataset to apply this transform pipeline twice independently to each input image, producing the two views needed for training. Experiment with different augmentations and compare their impact on training and test performance. Each image view should have a resolution of $224 \times 224$ after this preprocessing, to allow for batched processing.
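A minimal sketch of such a wrapper, assuming it wraps any dataset that returns (image, label) pairs and discards the label (the exact interface expected by the template may differ):

```python
from torch.utils.data import Dataset


class MultiViewDataset(Dataset):
    """Wraps a labeled dataset and returns two independently augmented views."""

    def __init__(self, base_dataset, transform):
        self.base_dataset = base_dataset
        self.transform = transform

    def __len__(self):
        return len(self.base_dataset)

    def __getitem__(self, idx):
        img, _ = self.base_dataset[idx]  # label is ignored (SSL setting)
        # Applying the stochastic transform twice yields two distinct views.
        return self.transform(img), self.transform(img)
```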
We will pre-train our model on ImageNette, a 10-class subset of ImageNet. The labels provided with ImageNette are ignored and the dataset is treated as unlabeled, mimicking a real-world scenario where large amounts of unlabeled data are available.
Implement the full training pipeline in the train function: extract the feature embeddings of the $2N$ image views and optimize the network via the NT-Xent loss. You can follow the training implementation from the previous lab for inspiration.
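A sketch of one possible training loop, reusing the SimCLR model, MultiViewDataset loader, and xent_loss from the sketches above; the Adam optimizer, learning rate, and epoch count are our assumptions:

```python
import torch


def train(model, loader, epochs: int = 100, lr: float = 3e-4, device: str = "cuda"):
    """Self-supervised pre-training with the NT-Xent loss."""
    model = model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    for epoch in range(epochs):
        total = 0.0
        for view1, view2 in loader:  # two augmented views per image
            view1, view2 = view1.to(device), view2.to(device)
            # Embed all 2N views; rows i and i + N form the positive pairs.
            z = model(torch.cat([view1, view2], dim=0))
            loss = xent_loss(z)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total += loss.item()
        print(f"epoch {epoch}: mean loss {total / len(loader):.4f}")
```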
We will evaluate the quality of the features learned during pre-training by testing their performance on a downstream task: classification of the CIFAR-10 dataset with a non-parametric approach, a k-Nearest-Neighbor classifier, specifically 1-NN. This tests how well the self-supervised features transfer to a standard classification task without further fine-tuning (as was done in the previous lab).
Implement the function test to compute the accuracy: extract encoder features for the CIFAR-10 training and test sets, assign each test image the label of its nearest training feature (1-NN), and report the fraction of correctly classified test images.
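A sketch of such a 1-NN evaluation, assuming the frozen SimCLR encoder from above and cosine similarity on normalized features as the metric; for brevity the full similarity matrix is computed in one step, so chunk it if memory is tight:

```python
import torch
import torch.nn.functional as F


@torch.no_grad()
def test(model, train_loader, test_loader, device: str = "cuda"):
    """1-NN classification accuracy of encoder features on CIFAR-10."""
    model.eval().to(device)

    def extract(loader):
        feats, labels = [], []
        for x, y in loader:
            # Use encoder features only; the projection head is discarded.
            feats.append(F.normalize(model.encoder(x.to(device)), dim=1))
            labels.append(y.to(device))
        return torch.cat(feats), torch.cat(labels)

    train_feats, train_labels = extract(train_loader)
    test_feats, test_labels = extract(test_loader)
    # 1-NN: each test image receives the label of the most similar
    # training feature (largest cosine similarity).
    nearest = (test_feats @ train_feats.T).argmax(dim=1)
    pred = train_labels[nearest]
    return (pred == test_labels).float().mean().item()
```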
(optional implementation detail for interested students)
SimCLR is known to benefit from larger batch sizes: more negatives per batch increase the probability of encountering hard negatives during training. When training on a GPU, the available VRAM is the limiting factor that dictates the batch size it can accommodate. Several tricks can relax this constraint by trading off training speed or numerical precision, for example gradient checkpointing (activations are recomputed during the backward pass instead of stored) and mixed-precision training (activations and gradients are kept in half precision).
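As an illustration of the latter, a minimal mixed-precision sketch using torch.cuda.amp; model, loader, optimizer, and xent_loss are the assumed names from the sketches above:

```python
import torch

scaler = torch.cuda.amp.GradScaler()

for view1, view2 in loader:
    optimizer.zero_grad()
    # Run the forward pass and loss in half precision where safe.
    with torch.cuda.amp.autocast():
        z = model(torch.cat([view1.cuda(), view2.cuda()], dim=0))
        loss = xent_loss(z)
    # Scale the loss to avoid float16 gradient underflow.
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
```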
Besides these implementation tricks, recent work has shown that removing the positive-pair similarity from the loss denominator improves performance for smaller batch sizes [3].