
# Lab 6: Metric Learning

Image retrieval, mean average precision, training embeddings (representations) with triplet loss and smoothAP loss.

### Templates

Use the provided template. The package contains:

• tools.py — tools for working with the dataset.
• lab.py — template to start with.
• view.ipynb — template for visualization of results.
• ./models/* — pretrained models.

### Introduction

In this lab we want to use a neural network to embed images $x$ as feature vectors $f \in \mathbb{R}^d$. These feature vectors can then be used to retrieve similar images by finding the nearest neighbors in the embedding space.

### Part 1. Retrieval (4p)

We will use the FashionMNIST dataset and start with a network trained for classification. We will use its last hidden layer representation as the embedding. The network is a small convolutional network defined as follows:

    from torch import nn


    class ConvNet(nn.Sequential):
        def __init__(self, num_classes: int = 10) -> None:
            layers = []
            layers += [nn.Conv2d(1, 32, kernel_size=3)]
            layers += [nn.ReLU(inplace=True)]
            layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
            layers += [nn.Conv2d(32, 32, kernel_size=3)]
            layers += [nn.ReLU(inplace=True)]
            layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
            layers += [nn.Conv2d(32, 64, kernel_size=3)]
            layers += [nn.ReLU(inplace=True)]
            layers += [nn.Flatten()]
            layers += [nn.Linear(64 * 2 * 2, num_classes)]
            super().__init__(*layers)
            self.layers = layers

        def features(self, x):
            # Apply everything except the final classification layer.
            f = nn.Sequential(*self.layers[:-1]).forward(x)
            # L2-normalize each feature vector (row) to unit length.
            f = nn.functional.normalize(f, p=2, dim=1)
            return f

The network was trained for classification and achieves about $90\%$ accuracy.

1. Implement an image retrieval system. Follow the Jupyter notebook template. The template loads the test set and a pretrained model. The pretrained model has a method features(x) which, given an input batch of samples $x$, outputs a batch of normalized feature vectors $f$ (i.e. $\|f_{i,:}\|_2 = 1$ for all $i$). Your task is to retrieve the 50 closest images in the test set for each query image. Implement the function distances, which computes the squared Euclidean distance from the query feature vector to the feature vectors of all images in the test set, without using loops. Use the test set provided in the template (it contains 100 test examples of each class).

Hint: Use torch.cdist, or exploit the fact that the feature vectors are normalized to simplify the squared Euclidean distance and learn how to use the torch.einsum function. Which of the two is easier to understand in someone else's code?
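The two options in the hint can be compared on synthetic data. The following is only a minimal sketch, not the template's distances function: the helper names sq_dists_cdist and sq_dists_dot and the random stand-in features are assumptions made for illustration. The second variant relies on the identity $\|q - f\|_2^2 = 2 - 2\langle q, f\rangle$, which holds only for unit-norm vectors.

```python
import torch


def sq_dists_cdist(q, feats):
    # q: (d,) query feature, feats: (n, d) database features.
    # torch.cdist returns (non-squared) Euclidean distances, so square them.
    return torch.cdist(q[None, :], feats).squeeze(0) ** 2


def sq_dists_dot(q, feats):
    # Valid only for unit-norm vectors: ||q - f||^2 = 2 - 2 <q, f>.
    return 2.0 - 2.0 * torch.einsum("d,nd->n", q, feats)


torch.manual_seed(0)
# Random unit-norm vectors standing in for the output of features(x).
feats = torch.nn.functional.normalize(torch.randn(1000, 64), p=2, dim=1)
q = feats[0]

d1 = sq_dists_cdist(q, feats)
d2 = sq_dists_dot(q, feats)

# Position 0 of the sorted list is the query itself, so skip it.
top50 = torch.argsort(d2)[1:51]
```

Both variants should agree up to floating-point error, so the choice is mostly about readability.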

Display the retrieved images as in the sample result below, where the first image in each row is the query and retrieved positives/negatives have green/red border:

2. Compute the Mean Average Precision (mAP). You will need the following notions. For each query we sort all items in the dataset by increasing distance from the query. For computing mAP, the query itself should be excluded from the set of retrieved items. Let the relevance ${\rm rel}(i)$ be $1$ if the item at position $i$ in this list is of the same class as the query and $0$ otherwise. If we consider the first $k$ items in the list, the precision of the retrieval system at $k$ is defined as $${\rm prec}(k) = \frac{\sum_{i=1}^k {\rm rel}(i)}{k},$$ i.e. the proportion of relevant items among the $k$ retrieved items. The recall of the retrieval system at $k$ is defined as $${\rm recall}(k) = \frac{\sum_{i=1}^k {\rm rel}(i)}{T},$$ i.e. the ratio of relevant items in the list of length $k$ to the total number $T$ of relevant items in the dataset (images of the same class as the query). The average precision is defined as the area under the precision-recall curve, which can be computed as $${\rm AP} = \sum_{k=1}^N{\rm prec}(k) \Delta {\rm recall}(k) = \sum_{k=1}^N{\rm prec}(k) \frac{{\rm rel}(k)}{T},$$ where $\Delta {\rm recall}(k)$ is the change in recall from item $k-1$ to item $k$. See also Wikipedia. Note that $N$ is the number of all examples in the dataset, i.e. we do not limit the list to the 50 best as in the previous task. The Mean Average Precision (mAP) is the mean of ${\rm AP}$ over random queries.

Your task is to implement a function evaluate_AP (avoid using loops) which computes the average precision given the distances from a query to all other data points and their labels. Using this function, implement evaluate_mAP(net, dataset), which computes the mean average precision using the first 100 (random) dataset points as queries, as provided in the template. The expected result for the classification network is ${\rm mAP} = 0.58$. Plot the precision-recall curve. Here we want to show the mean precision and recall, i.e. averaged across all queries.
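The definitions of ${\rm prec}(k)$, ${\rm recall}(k)$ and ${\rm AP}$ translate into a few cumulative sums. The sketch below is one possible loop-free formulation on a toy example. The function name average_precision and its signature are hypothetical (the template's evaluate_AP may take different arguments), and it assumes that the query has already been excluded and that at least one relevant item exists.

```python
import torch


def average_precision(dists, labels, query_label):
    # dists:  (N,) distances from the query to all other items (query excluded)
    # labels: (N,) class labels of those items
    order = torch.argsort(dists)                        # increasing distance
    rel = (labels[order] == query_label).float()        # rel(k)
    k = torch.arange(1, rel.numel() + 1, dtype=torch.float32)
    hits = torch.cumsum(rel, dim=0)                     # relevant items in top k
    T = rel.sum()                                       # assumed to be > 0
    prec = hits / k                                     # prec(k)
    recall = hits / T                                   # recall(k)
    ap = (prec * rel).sum() / T                         # sum_k prec(k) rel(k) / T
    return ap, prec, recall


# Toy example: five items, three of which share the query's class (3).
dists = torch.tensor([0.1, 0.4, 0.2, 0.9, 0.5])
labels = torch.tensor([3, 3, 7, 7, 3])
ap, prec, recall = average_precision(dists, labels, query_label=3)
```

In this toy example the sorted relevance list is $(1, 0, 1, 1, 0)$, so ${\rm AP} = (1 + 2/3 + 3/4)/3 = 29/36 \approx 0.806$. Averaging prec and recall over queries gives the points of the mean precision-recall curve.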

Report and discuss: method used to compute all Euclidean distances, figure with retrieved images from 1, precision-recall curve and mAP.

### Part 2. Metric Learning (6p)

In this part we will learn the embedding to directly facilitate the retrieval performance by utilizing specialized loss functions. We will use exactly the same network architecture as in Part 1 but we will train it differently.

#### Triplet loss

For an anchor (query) image $a$, let the positive example $p$ be of the same class as $a$ and the negative example $n$ be of a different class than $a$. For conciseness of notation we will define all the quantities and losses for a fixed anchor $a$. The losses are then to be considered in expectation over $a$.

Let $d_{p} = d(f(a),f(p))$ and $d_{n} = d(f(a),f(n))$ be distances between positive and anchor, and negative and anchor, respectively. We want to learn the embedding $f$ such that all positive examples $p$ would be closer to the anchor than all negative examples: \begin{align}\label{triplet-constraint} d_{p} < d_{n} \ \ \ \forall p,n. \end{align}

The constraint \eqref{triplet-constraint} is violated if $d_{p} - d_{n} \geq 0$. We can use a hinge loss to penalize violations and add a margin to ensure that positives are strictly closer: \begin{align} l(a) = \sum_{p,n} \max(d_{p} - d_{n} + \alpha, 0). \end{align} Assume that we use the squared Euclidean distance for $d$. Because the feature vectors are normalized, it has range $[0,4]$, and so the maximum violation of the constraint is $4$. A reasonable range for choosing the margin $\alpha$ is $[0,0.5]$.

1. Implement the function triplet_loss that, given a batch of data points represented by their feature vectors and class labels, does the following. The first 10 entries in the batch are considered as anchors. For each anchor $a$, compute the total loss $l(a)$, where $n,p$ go over all possible valid choices in the given batch. Avoid loops in your implementation (with the exception of the loop over anchors $a$).
2. Implement the function train_triplets that performs SGD using batches of data and the triplet loss defined above. Batches of images are sampled in the standard way (images are drawn at random without replacement). Train the model on a GPU server for 100 epochs. Print the total triplet loss at the end of each epoch. Save the training history and the trained network.
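The sum over all valid pairs $(p,n)$ for one anchor can be formed by broadcasting a column of positive distances against a row of negative distances. The following is a rough sketch under several assumptions: the name triplet_loss_sketch, its signature, and the choice to exclude the anchor itself from the positives are illustrative and not taken from the template. It also assumes unit-norm features, so that the squared distance is $2 - 2\langle f(a), f(x)\rangle$.

```python
import torch


def triplet_loss_sketch(feats, labels, alpha=0.5, num_anchors=10):
    # feats: (B, d) unit-norm features, labels: (B,) class labels.
    total = feats.new_zeros(())
    idx = torch.arange(feats.shape[0], device=feats.device)
    for a in range(num_anchors):
        # Squared distances from anchor a to every item in the batch.
        d = 2.0 - 2.0 * (feats @ feats[a])
        pos = (labels == labels[a]) & (idx != a)   # same class, not the anchor
        neg = labels != labels[a]
        # All (p, n) pairs at once: entry [i, j] = d_p[i] - d_n[j] + alpha.
        viol = d[pos][:, None] - d[neg][None, :] + alpha
        total = total + torch.clamp(viol, min=0).sum()
    return total


torch.manual_seed(0)
raw = torch.randn(64, 16, requires_grad=True)      # stand-in for network output
feats = torch.nn.functional.normalize(raw, p=2, dim=1)
labels = torch.randint(0, 10, (64,))
loss = triplet_loss_sketch(feats, labels, alpha=0.5)
loss.backward()                                    # gradients reach `raw`
```

If an anchor has no positives in the batch, the pair matrix is empty and contributes zero, so no special case is needed.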

#### SmoothAP loss

If our task resembles the retrieval task, where mAP was proposed as a criterion to assess the quality of image retrieval, we can design a more suitable loss function that improves the mAP criterion directly.

In computing the Mean Average Precision (mAP) we compute the average precision (AP) for a given query (anchor) and average over all queries. Therefore, we will again assume that $a$ is fixed, design approximations to AP and consider them on average.

Let $P$ be the set of all positive examples and $N$ the set of all negative examples for the anchor $a$. It can be shown that the average precision can be expressed as \begin{align}\label{AP-pairs} {\rm AP} = 1 - \frac{1}{T} \sum_{p\in P} \frac{\sum_{n\in N} [[d_n < d_p]]}{k(p)}, \end{align}

where $T = |P|$ is the total number of positive examples and \begin{align}\label{k(p)} k(p) = \sum_{x \in P \cup N} [[d_x \leq d_p]]. \end{align}

In the expression \eqref{AP-pairs} the numerator counts the number of negative examples which have a smaller distance to the query than a positive example $p$, i.e. they will be incorrectly listed in the sorted list of retrieved items earlier than $p$. The function $k(p)$ expresses the position of $p$ in the sorted list of all examples. Effectively, the factor $1/k(p)$ gives a higher relative weight to errors at the beginning of the retrieval list and discounts errors towards the end of the retrieval list.
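The pairwise expression \eqref{AP-pairs} can be checked numerically against the ranking-based definition of AP from Part 1. The snippet below is an illustrative check on made-up data: the distances and the choice of positive indices are arbitrary, and distinct distances are used so that ties do not occur.

```python
import torch

torch.manual_seed(0)
# 20 items with distinct distances to the anchor (no ties).
d = torch.randperm(20).float()
is_pos = torch.zeros(20, dtype=torch.bool)
is_pos[[1, 4, 5, 11, 17]] = True                  # arbitrary set of positives

# AP from the sorted list: sum_k prec(k) rel(k) / T.
order = torch.argsort(d)
rel = is_pos[order].float()
ranks = torch.arange(1, 21, dtype=torch.float32)
prec = torch.cumsum(rel, dim=0) / ranks
ap_rank = (prec * rel).sum() / rel.sum()

# AP from pairs: 1 - (1/T) sum_p (#negatives closer than p) / k(p).
d_p = d[is_pos]
d_n = d[~is_pos]
neg_before = (d_n[None, :] < d_p[:, None]).sum(dim=1).float()
k_p = (d[None, :] <= d_p[:, None]).sum(dim=1).float()   # position of p in the list
ap_pairs = 1.0 - (neg_before / k_p).mean()
```

Both values coincide because the number of positives ranked at or before $p$ equals $k(p)$ minus the number of negatives ranked before $p$.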

We wish to maximize mAP, so the equivalent loss to minimize is the average over positive examples \begin{align} l(a) = \frac{1}{T} \sum_{p\in P} \frac{\sum_{n\in N} [[d_n < d_p]]}{k(p)}. \end{align}

Since $l(a)$ is composed of non-differentiable indicator functions, we cannot train the embedding $f$ to minimize it directly. We need to approximate (relax) the criterion to something differentiable.

A natural relaxation is to use the sigmoid function to smoothly approximate the step function. Namely, \begin{align} \sigma_\tau(x) = \frac{1}{1 + e^{-x/\tau}} \end{align} approaches $[[x \geq 0]]$ when $\tau \rightarrow 0$. Using this relaxation we obtain a variant of the “Smooth AP” loss \begin{align}\label{AP-sigmoid} l_\text{Smooth-AP}(a) = \sum_{p\in P} \frac{\sum_{n\in N} \sigma_\tau(d_{p} - d_{n})}{k_{\sigma}(p)}, \end{align}

where

\begin{align} k_{\sigma}(p) = \sum_{x \in P \cup N} \sigma_\tau(d_{p} - d_{x}). \end{align}

Applying smoothing to all indicator functions in \eqref{AP-pairs}, including those occurring in $k(p)$, results in the method of A. Brown et al.: Smooth-AP: Smoothing the Path Towards Large-Scale Image Retrieval (2020).
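For a single anchor, the relaxed loss \eqref{AP-sigmoid} can be sketched as follows. This is only an illustration under stated assumptions: the names sigmoid_tau and smooth_ap_anchor, the default $\tau = 0.01$ and the toy distances are not taken from the template, and the anchor itself is assumed to be excluded from the inputs.

```python
import torch


def sigmoid_tau(x, tau):
    # sigma_tau(x) = 1 / (1 + exp(-x / tau)); approaches a step as tau -> 0.
    return torch.sigmoid(x / tau)


def smooth_ap_anchor(d, is_pos, tau=0.01):
    # d: (M,) distances from the anchor to the other batch items.
    # is_pos: (M,) boolean mask of positives.
    d_p = d[is_pos]
    # s[i, j] = sigma_tau(d_p[i] - d[j]), for every item j in P and N.
    s = sigmoid_tau(d_p[:, None] - d[None, :], tau)
    num = s[:, ~is_pos].sum(dim=1)     # soft count of negatives closer than p
    k_sigma = s.sum(dim=1)             # soft position of p in the sorted list
    return (num / k_sigma).sum()


d = torch.tensor([0.1, 0.2, 0.8, 0.9], requires_grad=True)
good = smooth_ap_anchor(d, torch.tensor([True, True, False, False]))
bad = smooth_ap_anchor(d, torch.tensor([False, False, True, True]))
bad.backward()                         # the relaxed loss is differentiable in d
```

Note that $\sigma_\tau(0) = 0.5$, so the term $x = p$ contributes $0.5$ to $k_\sigma(p)$ rather than the $1$ it contributes to $k(p)$. The relaxed value therefore differs slightly from the exact one even for small $\tau$.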

1. Implement the function smooth_AP_loss that, given a batch of data points represented by their feature vectors and class labels, does the following. The first 10 entries in the batch are considered as anchors. For each anchor $a$, compute the total loss $l_\text{Smooth-AP}(a)$, where $n,p$ go over all possible valid choices in the given batch. Avoid loops in your implementation (with the exception of the loop over anchors $a$).
2. Implement the function train_smooth_AP that performs SGD using batches of data and the smoothAP loss defined above. Batches of images are sampled in the standard way (images are drawn at random without replacement). Train the model on a GPU server for 100 epochs. Print the total smoothAP loss at the end of each epoch. Save the training history and the trained network.

Report and discuss: optimization settings, a plot of the training progress using the saved history, and an evaluation of the trained networks as in Part 1 (examples of retrieval and mAP score). Show the precision-recall curves together with the one for the classification-trained network (in the same plot). The reference solution has ${\rm mAP=0.79}$ and ${\rm mAP=0.81}$ for the triplet and smoothAP loss, respectively, and the following retrieved images for the exemplar queries.

The reference precision-recall comparison: