SimCLR - Contrastive Learning of Visual Representations
February 2, 2021
Supervised and unsupervised learning techniques both have limitations in real-world applications. Supervised learning, for example, requires large amounts of labeled data, and labeling data is labor-intensive, expensive, and time-consuming.
On the other hand, unsupervised learning (which uses data without labels) has often failed to provide meaningful information for solving real-life problems.
So, how do we overcome these challenges?
What is self-supervised learning?
We can overcome these challenges by using self-supervised learning. It is a subset of unsupervised learning that aims to mimic how humans and animals learn. It automatically generates a supervisory signal from the data itself, which is then used to solve tasks. For example, it can help label a dataset or learn representations of data without any human annotation. It is important to note that, unlike clustering-based unsupervised methods, it does not organize data into clusters or groupings.
It has been applied in areas such as natural language processing, computer vision, robotics, and reinforcement learning.
Self-supervised learning helps create data-efficient AI systems.
To follow this article, you should have a basic understanding of Machine Learning (ML) and Deep Learning (DL). If you are new to ML and DL, please read my previously published article on the differences between Artificial Intelligence, Machine Learning, and Deep Learning to learn more.
What is contrastive learning?
Contrastive learning is a very active area of machine learning research. It is a self-supervised approach in which a model learns to tell similar and dissimilar things apart: representations of similar inputs are pulled together, while representations of dissimilar inputs are pushed apart. By applying this method, one can train a machine learning model to recognize similarities between images. For example, given an image of a horse, such a model can find the matching animal in a gallery of other photos.
SimCLR is a framework developed by Google that demonstrated how effective contrastive learning can be for learning visual representations. It is a high-impact work that does away with the specialized architectures and memory banks typically used in contrastive learning. It shows that strong augmentations of unlabeled training data, a standard ResNet-50 architecture, and a small neural network (the projection head) are all you need to achieve state-of-the-art results. For such a simple approach, the results are truly mind-blowing.
Throughout this article, the paper published by Google, A Simple Framework for Contrastive Learning of Visual Representations, will be our main reference. Unlike many papers, it also offers practical tips for making the most out of contrastive learning, such as using a large batch size, training for more epochs, and increasing the network’s width.
Let’s look at the details of the SimCLR framework and the results presented in the paper.
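The horse-and-gallery example can be sketched in a few lines. This is a minimal illustration that assumes a model has already produced one embedding vector per image: the three-dimensional vectors below are made-up toy values, and the helper names `cosine_similarity` and `most_similar` are hypothetical. Cosine similarity is used because it is also the similarity measure SimCLR uses in its loss.

```python
import math

def cosine_similarity(u, v):
    """Cosine of the angle between vectors u and v (1.0 means same direction)."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)

def most_similar(query, gallery):
    """Return the gallery key whose embedding is most similar to the query."""
    return max(gallery, key=lambda name: cosine_similarity(query, gallery[name]))

# Hypothetical embeddings from some already-trained encoder (toy values).
gallery = {
    "horse_in_field": [0.9, 0.1, 0.2],
    "cat_on_sofa": [0.1, 0.8, 0.3],
    "dog_in_park": [0.2, 0.3, 0.9],
}
query_horse = [0.85, 0.15, 0.25]  # embedding of a new horse photo

print(most_similar(query_horse, gallery))  # horse_in_field
```

Retrieval only works this well if training has placed similar images close together in embedding space, which is exactly what contrastive learning aims for.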
An overview of the SimCLR framework
The major components of the SimCLR framework are:
- Data augmentation
- A base encoder $f(x)$
- A projection head $g(h)$
- A contrastive loss function
Data augmentation
The SimCLR framework starts by sampling example images from the original dataset. Each example is transformed by random data augmentations into two correlated views of the same image.
While previous approaches to contrastive learning introduced architectural changes, SimCLR argues that random cropping of the target image provides enough context for contrastive learning. Cropping enables the network to learn to contrast global and local views of an image, as well as adjacent views of the same image.
For example, consider the image of a dog below, with its global and local views.
Having learned to contrast global and local views, the network can also contrast adjacent views of the same image, as shown below.
The paper also describes a systematic study of compositions of data augmentations, for example, combining cropping with other techniques such as blur, color distortion, and noise. This is shown below.
The results showed that the combination of random cropping and color distortion stood out in terms of accuracy. The final augmentation policy therefore uses random cropping (with flip and resize), color distortion, and Gaussian blur, and drops the other augmentation techniques.
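The augmentation policy described above can be sketched as follows. This is a toy, dependency-free illustration, not the paper’s implementation: the "image" is a small grayscale grid of floats, color distortion is approximated by random brightness and contrast changes (hue and saturation would need color channels), Gaussian blur is approximated by a 3×3 binomial kernel, and all function names and parameter values (`min_frac`, `strength`, `blur_prob`) are hypothetical choices for illustration. Applying the same random pipeline twice yields two correlated views of one image.

```python
import random

def random_resized_crop(img, out_size, rng, min_frac=0.5):
    """Crop a random square region and resize it to out_size x out_size (nearest neighbor)."""
    h, w = len(img), len(img[0])
    side = max(1, int(min(h, w) * rng.uniform(min_frac, 1.0)))
    top = rng.randint(0, h - side)
    left = rng.randint(0, w - side)
    return [
        [img[top + (r * side) // out_size][left + (c * side) // out_size]
         for c in range(out_size)]
        for r in range(out_size)
    ]

def random_flip(img, rng, p=0.5):
    """Horizontally flip the image with probability p (always returns a copy)."""
    return [row[::-1] for row in img] if rng.random() < p else [row[:] for row in img]

def color_jitter(img, rng, strength=0.5):
    """Simplified stand-in for color distortion: random brightness and contrast, clamped to [0, 1]."""
    brightness = rng.uniform(1 - strength, 1 + strength)
    contrast = rng.uniform(1 - strength, 1 + strength)
    pixels = [p for row in img for p in row]
    mean = sum(pixels) / len(pixels)
    return [
        [min(1.0, max(0.0, ((p - mean) * contrast + mean) * brightness)) for p in row]
        for row in img
    ]

def blur_3x3(img):
    """Blur with a 3x3 binomial kernel (a small Gaussian approximation); edges are clamped."""
    kernel = [[1, 2, 1], [2, 4, 2], [1, 2, 1]]  # weights sum to 16
    h, w = len(img), len(img[0])
    out = []
    for r in range(h):
        row = []
        for c in range(w):
            acc = 0.0
            for dr in (-1, 0, 1):
                for dc in (-1, 0, 1):
                    rr = min(h - 1, max(0, r + dr))
                    cc = min(w - 1, max(0, c + dc))
                    acc += kernel[dr + 1][dc + 1] * img[rr][cc]
            row.append(acc / 16)
        out.append(row)
    return out

def augment(img, rng, out_size=4, blur_prob=0.5):
    """One stochastic augmentation: crop and resize, flip, color jitter, then maybe blur."""
    view = random_resized_crop(img, out_size, rng)
    view = random_flip(view, rng)
    view = color_jitter(view, rng)
    if rng.random() < blur_prob:
        view = blur_3x3(view)
    return view

def two_views(img, rng):
    """Apply the random augmentation twice to get a positive pair (two views of one image)."""
    return augment(img, rng), augment(img, rng)

rng = random.Random(0)
# A toy 8x8 grayscale "image" with a bright square in the top-left corner.
image = [[1.0 if r < 4 and c < 4 else 0.2 for c in range(8)] for r in range(8)]
view_1, view_2 = two_views(image, rng)
print(len(view_1), len(view_1[0]))  # 4 4
```

Each call to `augment` draws fresh random parameters, so the two views differ in crop, orientation, brightness, and sharpness while still showing the same content. In practice, image libraries provide tuned versions of these transforms.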
A base encoder $f(x)$
The base encoder $f(x)$ is a Convolutional Neural Network (CNN) based on the ResNet architecture; the paper’s main experiments use a standard ResNet-50. It extracts image representation vectors from the augmented images produced by the data augmentation module. The output of this extraction, taken after the average pooling layer, is the embedding $h$.
A projection head $g(h)$
The projection head $g(h)$ is a small multi-layer perceptron (MLP) with two fully connected layers and a ReLU non-linearity between them. It takes the embedding $h$ from the base encoder as its input and produces a projection $z$. This module’s role is to map the image representations to the latent space where the contrastive loss is applied.
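A minimal sketch of such a head, assuming plain Python lists for vectors and toy dimensions: it computes $z = W_2 \, \mathrm{ReLU}(W_1 h)$, the form described in the paper, but omits bias terms for brevity and uses randomly initialized weights rather than trained ones. The class and helper names are hypothetical.

```python
import random

def matvec(matrix, vec):
    """Multiply a matrix (a list of rows) by a vector."""
    return [sum(w * x for w, x in zip(row, vec)) for row in matrix]

def relu(vec):
    """Element-wise ReLU non-linearity."""
    return [max(0.0, x) for x in vec]

class ProjectionHead:
    """Two fully connected layers with a ReLU in between: z = W2 . ReLU(W1 . h).

    Bias terms are omitted to keep the sketch short.
    """

    def __init__(self, in_dim, hidden_dim, out_dim, seed=0):
        rng = random.Random(seed)
        # Small random weights; a real model would learn these by minimizing the contrastive loss.
        self.w1 = [[rng.gauss(0, 0.1) for _ in range(in_dim)] for _ in range(hidden_dim)]
        self.w2 = [[rng.gauss(0, 0.1) for _ in range(hidden_dim)] for _ in range(out_dim)]

    def __call__(self, h):
        return matvec(self.w2, relu(matvec(self.w1, h)))

# Toy sizes: an 8-dimensional embedding h mapped to a 4-dimensional projection z.
head = ProjectionHead(in_dim=8, hidden_dim=8, out_dim=4)
h = [0.5, -1.0, 0.3, 0.0, 2.0, -0.7, 1.1, 0.4]  # stand-in for a base encoder output
z = head(h)
print(len(z))  # 4
```

The contrastive loss is computed on $z$, while $h$ is the representation kept for downstream tasks once training is done.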
The contrastive loss function: normalized temperature-scaled cross-entropy (NT-Xent) loss
The contrastive loss function is a modified version of the cross-entropy loss function, which is the most widely used loss function for supervised learning of deep classification models. The function is shown below.
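For reference, the loss for a positive pair of views $(i, j)$, as defined in the SimCLR paper, can be written as follows. Here $N$ is the number of images in the batch (so there are $2N$ augmented views), $\tau$ is a temperature parameter, and $\mathbb{1}_{[k \neq i]}$ equals 1 when $k \neq i$ and 0 otherwise:

$$
\ell_{i,j} = -\log \frac{\exp\left(\operatorname{sim}(z_i, z_j)/\tau\right)}{\sum_{k=1}^{2N} \mathbb{1}_{[k \neq i]} \exp\left(\operatorname{sim}(z_i, z_k)/\tau\right)},
\qquad
\operatorname{sim}(u, v) = \frac{u^\top v}{\lVert u \rVert \, \lVert v \rVert}
$$

The total loss averages $\ell_{i,j}$ over all positive pairs in the batch, counting both $(i, j)$ and $(j, i)$.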
The contrastive loss function encourages the projections $z_i$ and $z_j$ of two augmented views of the same image (for example, two views of a picture of a cat) to be similar. In other words, they should attract.
In contrast, $z_i$ should be pushed away from (repel) the projections $z_k$ of every other image in the batch. For example, the representations of a dog and a cat should repel each other.
That, in a nutshell, is a simplified view of what the contrastive loss function does.
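The attract-and-repel behavior can be checked with a small NT-Xent sketch in plain Python. It assumes the $2N$ projections are stored in one list ordered so that `z[2k]` and `z[2k + 1]` are the two views of the same image; that ordering, the function names, and the default temperature of 0.5 are illustrative assumptions. Averaging over all $2N$ anchors counts each positive pair in both directions, matching the total loss described above.

```python
import math

def cosine_similarity(u, v):
    """sim(u, v) = u.v / (|u| |v|)."""
    dot = sum(a * b for a, b in zip(u, v))
    return dot / (math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v)))

def nt_xent_loss(z, temperature=0.5):
    """NT-Xent loss averaged over all 2N anchors.

    Assumes z holds 2N projections where z[2k] and z[2k + 1] are the two
    augmented views of the same image (a positive pair).
    """
    n_views = len(z)
    if n_views % 2 != 0:
        raise ValueError("expected an even number of views (two per image)")
    total = 0.0
    for i in range(n_views):
        j = i + 1 if i % 2 == 0 else i - 1  # index of i's positive partner
        # Temperature-scaled similarities between anchor i and every view k != i.
        logits = {k: cosine_similarity(z[i], z[k]) / temperature
                  for k in range(n_views) if k != i}
        # Log-sum-exp with max subtraction for numerical stability.
        m = max(logits.values())
        log_denominator = m + math.log(sum(math.exp(v - m) for v in logits.values()))
        total += -(logits[j] - log_denominator)
    return total / n_views

# Two images, two views each: views of the same image point in similar directions.
aligned = [[1.0, 0.0], [0.9, 0.1], [0.0, 1.0], [0.1, 0.9]]
# The same vectors, but now each positive pair mixes views of different images.
misaligned = [[1.0, 0.0], [0.0, 1.0], [0.9, 0.1], [0.1, 0.9]]
print(nt_xent_loss(aligned) < nt_xent_loss(misaligned))  # True
```

When the positive pairs agree, the loss is low; when each anchor is closer to a negative than to its partner, the loss rises, and minimizing it would pull the partners together and push the negatives apart.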
Results of SimCLR
- One of the key findings of this paper is that self-supervised learning algorithms benefit more from scaling up than supervised learning algorithms.
- The experiments used batch sizes ranging from 256 to 8,192. Accuracy kept increasing as the batch size and the number of training epochs increased, as shown below.
- An experiment with ResNet-50 at three different widths (1×, 2×, and 4×) showed a significant gain in accuracy as the model width increased. These results are shown below.
- The main results from the paper demonstrate that SimCLR outperforms other state-of-the-art methods, both with a standard ResNet-50 and with other architectures. From the results below, a linear classifier trained on the representations learned by SimCLR (with a 4× wider ResNet-50) achieves 76.5% top-1 accuracy. This is a 7% relative improvement over the previous state of the art, and it matches the performance of a supervised ResNet-50.
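One reason the paper gives for the benefit of large batches is that every other view in the batch serves as a negative example: with $N$ images there are $2N$ views, so each anchor is contrasted against $2(N-1)$ negatives. The short sketch below just does this arithmetic for the batch sizes used in the experiments; the function name is hypothetical.

```python
def negatives_per_anchor(batch_size):
    """Each of the 2N views is contrasted with every view except itself and its positive partner."""
    return 2 * (batch_size - 1)

for n in [256, 512, 1024, 2048, 4096, 8192]:
    print(n, negatives_per_anchor(n))
```

Going from a batch of 256 (510 negatives per anchor) to 8,192 (16,382 negatives per anchor) gives each positive pair far more examples to be contrasted against.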
Summary of the SimCLR Framework
- Use random crop (with flip and resize), color distortion, and Gaussian blur, as this combination worked best among the data augmentations studied in the paper.
- Use a large batch size whenever possible, especially if you have enough GPU compute power.
- Train your model for more epochs to achieve better results. The paper’s main results come from models trained for 1,000 epochs.
- A non-linear projection head $g(h)$ is important for getting good representations.
- The framework learns representations by maximizing agreement between differently augmented views of the same data example via a contrastive loss in the latent space.
- Increasing model depth and width benefits contrastive learning more than it benefits supervised learning.
- The non-linear projection head $g(h)$ improves the quality of the representation $h$ that comes before it; after training, $h$ (not $z$) is used for downstream tasks.
- During training, parameter updates cause the representations of correlated views (two views of the same image) to attract each other, while the representations of non-correlated views repel each other.
References
- Advancing Self-Supervised and Semi-Supervised Learning with SimCLR
- A Simple Framework for Contrastive Learning of Visual Representations
- Contrastive Representation Learning: A Framework and Review
- Supervised Contrastive Learning
Peer Review Contributions by: Lalithnarayan C
About the author
Willies Ogola
Willies Ogola is pursuing his Master’s in Computer Science at Hubei University of Technology, China. His research focuses on Artificial Intelligence and Embedded Systems. He likes researching in his free time and is passionate about technology.