This is a blogpost for the paper 'Self-Supervised Relational Reasoning for Representation Learning' by Massimiliano Patacchiola and Amos Storkey.
(...In Progress...)
Introduction
Problem Statement
Manually annotating a dataset is costly, and this cost limits what deep learning can achieve in many tasks. To learn useful representations from unlabeled data (i.e., representation learning), self-supervised learning has been proposed[1][2]. Self-supervised methods commonly follow a two-stage pipeline: a self-supervised pretext task is trained on unlabeled data to learn representations, and a supervised downstream task is then trained on top of those representations[3]. The features learned in the pretext task are reused in the downstream task to solve the main problem (e.g., classification or image retrieval).
According to studies[4], human learning without supervision relies on making comparisons and establishing relations between entities in order to quantify how objects relate to each other. In this way, irrelevant perceptual features can be ignored and the learner can focus on the relevant ones. To mimic this aspect of human learning, the paper combines a relational reasoning mechanism with self-supervised learning.
Methodology
To address this problem, the paper proposes a new formulation of relational reasoning that serves as the pretext task for learning representations. Compared with previous work, a relation head is added to the pretext model to train the backbone so that it produces useful representations of the data. The trained backbone can then be used in downstream tasks such as classification and image retrieval.
Figure 1: Overview of the proposed method
Consider an unlabeled dataset D = {x_n}_{n=1}^N. From this dataset, a mini-batch B = {x_m}_{m=1}^M is sampled. K random data augmentations are applied to every mini-batch, producing K augmented mini-batches B(1), …, B(K). These are passed through the backbone, which together with the relation module forms the core of the proposed method. The backbone is a non-linear function f_theta(.) modeled as a neural network and parametrized by a weight vector theta. The forward pass produces a representation vector z_n(i) = f_theta(x_n(i)), where i denotes the i-th set of random augmentations applied to all instances. The representation vectors form K representation sets per mini-batch, written Z(1), …, Z(K). Representations are then coupled into pairs by an aggregation function. Each pair is labeled as positive or negative: a positive pair joins two augmented views of the same instance, (z_n(i), z_n(j)), while a negative pair joins views of different instances, (z_n(i), z_n'(j)) with n' ≠ n.
Since the aggregation functions chosen for the task are commutative, the two orderings of a pair are equivalent, so the duplicate is discarded and only one ordering is used.
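The pairing step can be sketched with plain indices. This is a minimal illustration, not the authors' implementation: the function name `build_pairs` is hypothetical, and drawing each negative from the neighbouring instance in the batch is an assumption made here for simplicity, since the text does not say how negatives are sampled.

```python
from itertools import combinations


def build_pairs(num_augmentations, batch_size):
    """Return index pairs ((i, n), (j, m)): view i of instance n with view j of instance m.

    Requires batch_size >= 2 so that a different instance exists for negatives.
    """
    positives, negatives = [], []
    # combinations() yields every unordered (i, j) with i < j exactly once,
    # so the mirrored ordering (j, i) is never generated.
    for i, j in combinations(range(num_augmentations), 2):
        for n in range(batch_size):
            # Positive: two different views of the same instance n.
            positives.append(((i, n), (j, n)))
            # Negative (assumed scheme): pair instance n with its neighbour in the batch.
            m = (n + 1) % batch_size
            negatives.append(((i, n), (j, m)))
    return positives, negatives


positives, negatives = build_pairs(num_augmentations=3, batch_size=4)
print(len(positives), len(negatives))
```

Under this assumed scheme the number of positive pairs is M * K * (K - 1) / 2, which grows quadratically with the number of augmentations K.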
After aggregation, the pairs are passed to the relation module, a function r_phi(.) parametrized by its own weight vector phi. It is modeled as a multi-layer perceptron (MLP) and outputs a scalar relation score y through a sigmoid activation function. The relation score can be interpreted as a probabilistic estimate of representation membership: a score close to 1 indicates that the two representations belong to the same instance, and a score close to 0 indicates that they belong to different instances.
After the relation module computes this probability, a binary cross-entropy loss between the score y_i and the target value t_i is minimized, with t_i = 1 for positive pairs and t_i = 0 for negative pairs. Each term of the loss is scaled by a scalar weight factor w_i that depends on a focusing parameter γ ≥ 0.
The factor γ puts more focus on misclassified pairs and reduces the loss for correctly classified ones. This weighting is known as focal loss[5].
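To make the weighting concrete, here is a small sketch of a focal-weighted binary cross-entropy. It assumes the weight takes the usual focal form w_i = [(1 - t_i) * y_i + t_i * (1 - y_i)]^γ, i.e. the probability mass assigned to the wrong label raised to the power γ. Any constant prefactor is left out because it is not stated in the text above, and the function name `focal_bce` and the clipping constant `eps` are choices made for this example.

```python
import math


def focal_bce(scores, targets, gamma=2.0, eps=1e-12):
    """Mean binary cross-entropy with a focal weight on each term."""
    total = 0.0
    for y, t in zip(scores, targets):
        # Clip the score so that log() never receives exactly 0.
        y = min(max(y, eps), 1.0 - eps)
        bce = -(t * math.log(y) + (1 - t) * math.log(1.0 - y))
        # Assumed focal weight: large when the prediction is far from the target.
        w = ((1 - t) * y + t * (1.0 - y)) ** gamma
        total += w * bce
    return total / len(scores)


easy = focal_bce([0.9], [1])  # confident and correct
hard = focal_bce([0.1], [1])  # confident and wrong
print(easy < hard)
```

With γ = 0 every weight equals 1 and the function reduces to the standard binary cross-entropy. Larger values of γ shrink the contribution of pairs that are already classified correctly.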
The objective is the sum of two losses, which can be interpreted as an intra-reasoning term and an inter-reasoning term. Intra-reasoning focuses on the relation between randomly augmented views of the same object (positive pairs), while inter-reasoning focuses on the relation between randomly augmented views of different objects in different scenes (negative pairs). The backbone uses both relations, that of an object to itself (intra-reasoning) and to other objects (inter-reasoning), to learn representations that serve the downstream goals.
Experiments and Results
Dataset
The authors evaluate their method on CIFAR-10, CIFAR-100, CIFAR-100-20, STL-10, tiny-ImageNet and SlimageNet. CIFAR-10 and CIFAR-100 consist of 32x32 RGB images with 10 and 100 classes, respectively. CIFAR-100-20 differs from CIFAR-100 in that the 100 classes are grouped into 20 'coarse' labels, also called superclasses. STL-10 has 10 classes of 96x96 RGB images. Unlike the CIFAR datasets, it also provides a large number of unlabeled images for unsupervised learning, drawn from a distribution that is similar to but different from that of the labeled data.
Tiny-ImageNet and SlimageNet are extracted from ImageNet. Both have a large number of classes, 200 and 1000 respectively. Both use 64x64 RGB images, which makes training on these datasets more challenging than on ImageNet, whose images have an average resolution of 482x418.
Evaluation Baselines
- Supervised Learning: The upper bound for the baselines is given by the standard supervised method. Standard data augmentation is used together with SGD optimization (a learning rate of 0.1 at the beginning, 0.01 at 50% and 0.001 at 75% of the total epochs).
- Random weights: The backbone is initialized with standard fan-in/fan-out weight initialization. After initialization, linear evaluation is performed, which optimizes only the last layer. This method serves as a lower bound, since the backbone is not trained with backpropagation.
- DeepCluster: An unsupervised learning approach for visual features using k-means algorithm and pseudo-labels[6].
- RotationNet: A self-supervised learning approach that learns by recognizing the 2D rotation (0°, 90°, 180°, 270°) applied to an image[7].
- Deep InfoMax: An unsupervised approach that learns representations by maximizing the mutual information between the input and the encoded features[8].
- SimCLR: A contrastive learning method for learning effective visual representations[9]. SimCLR is the state-of-the-art baseline in self-supervised learning.
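The step schedule described for the supervised baseline can be written as a small function. This is a sketch only: the name `step_lr` is hypothetical, and counting epochs from zero is an assumption, since the text gives only the three rates and the two switching points.

```python
def step_lr(epoch, total_epochs):
    """Learning rate for a zero-indexed epoch: 0.1, then 0.01 at 50%, then 0.001 at 75%."""
    progress = epoch / total_epochs
    if progress < 0.5:
        return 0.1
    if progress < 0.75:
        return 0.01
    return 0.001


# Print the rate at the start of training and at each switching point.
for epoch in (0, 100, 150):
    print(epoch, step_lr(epoch, total_epochs=200))
```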
Implementation
- Mini-batch size: 64 images
- Total number of augmentations K: K=16 for ResNet-32 on tiny-ImageNet, K=25 for ResNet-34 on STL-10, K=32 for the others.
- Optimizer: Adam optimizer with a learning rate of 10^-3
- Loss function: Binary cross-entropy loss with focal factor gamma=2
- Relation module: MLP with 256 hidden units, batch normalization and leaky ReLU activation, with a sigmoid function as the output unit
- Aggregation function: Concatenation, since it is the most effective option
- Augmentations (with the probability of each being applied): horizontal flip (50%), random crop-size (20%), conversion to grayscale (20%), color jitter (80%).
- Backbones: Conv-4, ResNet-8/32/56, ResNet-34
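As a rough picture of the relation module listed above, the following sketch implements a one-hidden-layer MLP with leaky ReLU and a sigmoid output on top of concatenated representations. It is dependency-free and deliberately simplified: batch normalization is omitted, the weights are random rather than trained, and the class name `RelationHead`, the leaky ReLU slope of 0.01 and the initialization range are all assumptions of this example.

```python
import math
import random


def leaky_relu(x, slope=0.01):
    return x if x > 0 else slope * x


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


class RelationHead:
    """Untrained MLP: concatenated pair -> hidden layer -> scalar score in (0, 1)."""

    def __init__(self, in_dim, hidden=256, seed=0):
        rng = random.Random(seed)
        s1 = 1.0 / math.sqrt(in_dim)
        s2 = 1.0 / math.sqrt(hidden)
        self.w1 = [[rng.uniform(-s1, s1) for _ in range(in_dim)] for _ in range(hidden)]
        self.b1 = [0.0] * hidden
        self.w2 = [rng.uniform(-s2, s2) for _ in range(hidden)]
        self.b2 = 0.0

    def __call__(self, z_a, z_b):
        # Aggregation by concatenation, as chosen in the implementation list.
        x = list(z_a) + list(z_b)
        h = [leaky_relu(sum(w * v for w, v in zip(row, x)) + b)
             for row, b in zip(self.w1, self.b1)]
        return sigmoid(sum(w * v for w, v in zip(self.w2, h)) + self.b2)


head = RelationHead(in_dim=8)
score = head([0.1, 0.2, 0.3, 0.4], [0.4, 0.3, 0.2, 0.1])
print(0.0 < score < 1.0)
```

Because concatenation doubles the input size, `in_dim` must be twice the length of a single representation vector.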
Training Methods and Results
Linear evaluation: The linear evaluation protocol is followed[10]. The backbone is trained for 200 epochs on the unlabeled dataset. Then, a linear classifier is placed on top of the backbone features and trained for 100 epochs without updating the backbone weights. The reported metric is the accuracy of this classifier on the test set.
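The idea behind linear evaluation, training only a linear classifier on fixed features, can be illustrated with a toy example. Everything here is a stand-in: `frozen_backbone` is a hypothetical fixed function rather than a trained network, the classifier is binary logistic regression instead of a multi-class layer, and the four data points are invented and reused for evaluation to keep the sketch short.

```python
import math


def frozen_backbone(x):
    # Hypothetical stand-in for a pretrained backbone. It has no trainable
    # state here, which mirrors keeping the backbone weights fixed.
    return [x[0] + x[1], x[0] - x[1]]


def train_linear_probe(features, labels, epochs=100, lr=0.5):
    """Fit a logistic-regression classifier on fixed features with gradient descent."""
    dim = len(features[0])
    n = len(features)
    w, b = [0.0] * dim, 0.0
    for _ in range(epochs):
        grad_w, grad_b = [0.0] * dim, 0.0
        for z, t in zip(features, labels):
            logit = sum(wi * zi for wi, zi in zip(w, z)) + b
            p = 1.0 / (1.0 + math.exp(-logit))
            for d in range(dim):
                grad_w[d] += (p - t) * z[d]
            grad_b += p - t
        # Only the classifier parameters w and b are updated.
        w = [wi - lr * g / n for wi, g in zip(w, grad_w)]
        b -= lr * grad_b / n
    return w, b


def accuracy(w, b, features, labels):
    correct = 0
    for z, t in zip(features, labels):
        logit = sum(wi * zi for wi, zi in zip(w, z)) + b
        correct += int((logit > 0) == bool(t))
    return correct / len(labels)


inputs = [[2, 1], [3, 0], [-2, -1], [-3, 0]]
labels = [1, 1, 0, 0]
features = [frozen_backbone(x) for x in inputs]
w, b = train_linear_probe(features, labels)
print(accuracy(w, b, features, labels))
```

Because the classifier is linear, its accuracy mostly reflects how well the fixed features already separate the classes, which is why this protocol is used to judge representation quality.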
The performance of the proposed model and of the baselines is reported in Table 1, with the best results in bold. Relational reasoning with a ResNet-32 backbone stands out on CIFAR-100 (46.2% accuracy) and tiny-ImageNet (30.5% accuracy), an improvement of +4.0% and +4.7%, respectively, over the best competitor (SimCLR).
Domain transfer: The authors test transfer learning by training the backbone on unlabeled CIFAR-10 and running linear evaluation on CIFAR-100 (10 → 100), and vice versa (100 → 10). As Table 1 shows, the relational reasoning approach performs best among the compared approaches. In the 10 → 100 case, it improves by +5.3% over the best competitor (SimCLR) and by +7.5% over the upper bound (supervised baseline).
Grain: Since the size of the dataset influences the performance of the representation, the authors train the backbone on unlabeled CIFAR-100 and run linear evaluation on labeled CIFAR-100 with its 100 classes (fine-grained) and on labeled CIFAR-100-20 with its 20 superclasses (coarse-grained). The proposed method performs best among the baselines, with an accuracy of 52.4% on CIFAR-100-20 (see Table 1).
Finetuning: The methods are trained on unlabeled STL-10 (100K images) for 300 epochs, and the pretrained backbone is then fine-tuned on the labeled STL-10 set (5K images). After fine-tuning, the model is evaluated on the test set (8K images). Relational reasoning stands out with an accuracy of 89.67% compared with the other baselines.
Table 1: Accuracy and variance (over three runs) on different benchmarks using ResNet-32.
Depth of the backbone: The impact of backbone depth on accuracy is shown in Figure 2a. The figure shows the change in accuracy when ResNet-32 is used instead of Conv-4 on CIFAR-10, CIFAR-100 and tiny-ImageNet (10, 100 and 200 classes, respectively). Notably, the advantage of relational reasoning over the other competitors grows as the complexity of the dataset increases.
Figure 2: Accuracy evaluation on different settings: (a) Depth of backbone, (b) Number of augmentations, (c) Number of available labels
Number of augmentations K: The relationship between accuracy and the parameter K is investigated in Figure 2b (linear evaluation on CIFAR-10 using a Conv-4 backbone for 100 epochs). Measured as the difference from the accuracy at K=2, the accuracy of the proposed relational learning method is positively correlated with K, in contrast to the other baselines.
Number of labels: The authors examine the accuracy of the proposed method with respect to the percentage of available labels. The accuracy is positively correlated with the number of available labels on both datasets (CIFAR-10 and CIFAR-100), see Figure 2c. The accuracy of relational learning converges to the supervised upper bound when all data is labeled.
Aggregation function: Four aggregation functions are compared: sum, mean, maximum and concatenation (see Table 2). The results show that concatenation is the most effective of the four, with an accuracy of 60.81% on CIFAR-10 and 32.36% on CIFAR-100.
Table 2: Performances of different aggregation functions.
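For reference, the four aggregation functions compared in this ablation can be written as simple operations on two representation vectors. This is an illustrative sketch with hypothetical function names. The element-wise forms for sum, mean and maximum are an assumption, since the text names the functions without defining them.

```python
def agg_sum(a, b):
    return [x + y for x, y in zip(a, b)]


def agg_mean(a, b):
    return [(x + y) / 2 for x, y in zip(a, b)]


def agg_max(a, b):
    return [max(x, y) for x, y in zip(a, b)]


def agg_concat(a, b):
    # Keeps both vectors intact, so the output is twice as long and,
    # unlike the other three, depends on the order of the arguments.
    return list(a) + list(b)


z_i, z_j = [1.0, 2.0], [3.0, 4.0]
for aggregate in (agg_sum, agg_mean, agg_max, agg_concat):
    print(aggregate.__name__, aggregate(z_i, z_j))
```

Sum, mean and maximum merge the two vectors into one of the same length, while concatenation preserves all the information from both inputs and leaves it to the relation module to compare them.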
Architecture of relation head: The authors also analyze how the structure of the relation head affects the results. Three architectures are evaluated: a dot product, an MLP followed by a dot product without an aggregation function (similar to SimCLR[11] but with the BCE-focal loss), and an MLP with an aggregation function as in the proposed method (see Table 3). The results show that accuracy is highly affected by the choice of relation head architecture, which makes it an essential component of the pipeline.
Table 3: Performance of different architectures used in place of the relation head.
Qualitative analysis: RotationNet and the proposed method are compared on an image retrieval task. 25 random queries (red frames, see Figure 3) are fed to the networks, and the most similar images in representation space are shown to the right of each query. The proposed method categorizes samples that are hard to separate (e.g., ships and planes in row 4) better than RotationNet. It is also markedly more accurate than RotationNet (see Figure 4), and its representation space is less scattered, with more apparent clusters (see Figure 5). This leads to robustness against misclassification.
Figure 3: Image retrieval task on CIFAR-10 with ResNet-32. The query is the image in red frame, followed by the most similar images in representation space.
Figure 4: Image retrieval – error analysis with confusion matrix
Figure 5: Visualization of t-SNE embeddings[12] for 10K test points in CIFAR-10 using ResNet-32 trained by (a) supervised learning, (b) relational reasoning, (c) RotationNet.
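The retrieval step used in the qualitative analysis amounts to a nearest-neighbour search in representation space. The sketch below assumes cosine similarity as the similarity measure, which the text does not specify, and uses invented two-dimensional vectors in place of real backbone features. The names `cosine` and `retrieve` are hypothetical.

```python
import math


def cosine(a, b):
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    # Treat an all-zero vector as similar to nothing.
    return dot / (norm_a * norm_b) if norm_a and norm_b else 0.0


def retrieve(query, gallery, top_k=5):
    """Return the indices of the top_k gallery vectors most similar to the query."""
    ranked = sorted(range(len(gallery)),
                    key=lambda idx: cosine(query, gallery[idx]),
                    reverse=True)
    return ranked[:top_k]


gallery = [[1.0, 0.0], [0.0, 1.0], [0.9, 0.1], [-1.0, 0.0]]
print(retrieve([1.0, 0.0], gallery, top_k=2))
```

If the representations cluster well by class, the returned neighbours tend to share the query's category, which is what the retrieval figure inspects visually.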
Conclusion
The paper proposes a representation learning approach based on comparisons that is effective, both qualitatively and quantitatively, for downstream tasks, including medical diagnostic tasks. The authors show that, using a large amount of unlabeled data, the proposed method improves visual representations, which transfer to various domains and are fine-grained, and that it outperforms other state-of-the-art methods.
Student's Review
The paper demonstrates the advantage of the proposed method over other state-of-the-art methods. It compares performance across many settings: fine- and coarse-grained data, various types of backbones, relation heads and aggregation functions, different numbers of augmentations, and so on. The qualitative analysis using t-SNE embeddings and the confusion matrix presents the results in a well-structured way and gives compact evidence of how effective the model is. Overall, the tables and figures are explained well enough for readers to understand the methods and settings used. Additionally, the authors draw attention to an important issue: they warn readers that gathering large amounts of data from the internet can lead a model to biased predictions, which may be dangerous in critical tasks.
However, because so many architectures, parameters and functions can be used, many separate training runs are performed. Most of them use a ResNet-32 backbone, with the exception of the ablation of the aggregation function (Conv-4), the fine-tuning on STL-10 (ResNet-34) and the examination of the correlation between the number of augmentations and accuracy. I see this as a weak point of the paper, because it makes the different settings harder to compare. Another issue is that the qualitative analysis compares the proposed method only with RotationNet. Since the authors describe SimCLR as the state-of-the-art method, I would have expected a final comparison of the proposed method, SimCLR and RotationNet.
References
[1] Schmidhuber, J. (1987). Evolutionary principles in self-referential learning, or on learning how to learn: The meta-meta-... hook. Diplomarbeit, Technische Universität München.
[2] Schmidhuber, J. (1990). Making the world differentiable: On using self-supervised fully recurrent neural networks for dynamic reinforcement learning and planning in non-stationary environments.
[3] Jing, L. and Tian, Y. (2020). Self-supervised visual feature learning with deep neural networks: A survey. IEEE Transactions on Pattern Analysis and Machine Intelligence.
[4] Gentner, D. and Kurtz, K. (2005). Relational categories. WK Ahn, RL Goldstone, BC Love, AB Markman, & PW Wolff (Eds.), pages 151–175.
[5] Lin, T.-Y., Goyal, P., Girshick, R., He, K., and Dollár, P. (2017). Focal loss for dense object detection. In International Conference on Computer Vision.
[6] Caron, M., Bojanowski, P., Joulin, A., and Douze, M. (2018). Deep clustering for unsupervised learning of visual features. In European Conference on Computer Vision.
[7] Gidaris, S., Singh, P., and Komodakis, N. (2018). Unsupervised representation learning by predicting image rotations. In International Conference on Learning Representations.
[8] Hjelm, R. D., Fedorov, A., Lavoie-Marchildon, S., Grewal, K., Bachman, P., Trischler, A., and Bengio, Y. (2019). Learning deep representations by mutual information estimation and maximization. In International Conference on Learning Representations.
[9] Chen, T., Kornblith, S., Norouzi, M., and Hinton, G. (2020). A simple framework for contrastive learning of visual representations. arXiv preprint arXiv:2002.05709.
[10] Kolesnikov, A., Zhai, X., and Beyer, L. (2019). Revisiting self-supervised visual representation learning. In Computer Vision and Pattern Recognition.
[11] Chen, T., Kornblith, S., Norouzi, M., and Hinton, G. (2020). A simple framework for contrastive learning of visual representations. arXiv preprint arXiv:2002.05709.
[12] Maaten, L. v. d. and Hinton, G. (2008). Visualizing data using t-sne. Journal of machine learning research, 9(Nov):2579–2605.