This blog post summarizes and reviews the paper “Exploring Simple Siamese Representation Learning” by Xinlei Chen and Kaiming He [1].

1. Introduction

1.1. Problem statement and Motivation

In the area of un-/self-supervised representation learning, Siamese network approaches have become common. These methods typically train two identical subnetworks that share the same configuration (architecture and weights), with the objective of learning a similarity function rather than class probabilities. They are commonly used to measure the similarity of two inputs by comparing their feature vectors, or “representations”.

Siamese networks are known to admit degenerate, or collapsed, solutions in which all outputs collapse to a constant. This collapse is attributed to the lack of a repulsive component in the optimization objective, and several remedies have been proposed, such as negative samples or a momentum encoder. The reviewed paper proposes SimSiam, a method simpler than the state of the art that nonetheless avoids collapsing solutions.

1.2. Related Work

Three representation learning approaches are the basis for this paper:

  1. Contrastive learning. A popular approach to un-/self-supervised representation learning that attracts positive sample pairs and repels negative sample pairs. The representative in the comparison is SimCLR [2], a simple framework that learns representations by maximizing agreement between differently augmented views of the same data example via a contrastive loss. It relies on in-batch negatives and places particular emphasis on the choice of data augmentations. Another compared model is MoCo [3], which uses a momentum encoder and a queue of negative samples.
  2. Clustering. Unsupervised representation learning methods that alternate between clustering the representations and learning to predict the cluster assignments. In these methods, the cluster centres play the role of negative prototypes. The proposed method is compared with SwAV [4], an online clustering method that solves a “swapped” prediction problem, i.e., it predicts the cluster assignment (code) of one view from the representation of another view.
  3. BYOL [5]. A method that does not use negative pairs. It consists of two networks: an online network, trained to predict the representation that a target network produces for a different augmented view of the same image, and the target network itself, which is a momentum encoder. It uses stop-gradient and the momentum encoder to prevent collapse.

Figure 1 - Comparison of SimSiam with SimCLR, BYOL and SwAV. The components in red are those missing in SimSiam, and the absence of a dashed line indicates stop-gradient.

2. Method

SimSiam, the proposed method, is architecturally simple (see Figure 1). Whereas other state-of-the-art methods add elements to the framework or the architecture to avoid collapsing solutions, such as negative pairs or a momentum encoder, SimSiam relies only on the stop-gradient operation.

From an image x, two randomly augmented views x_1 and x_2 are produced. Both views are processed by the same encoder f, whose weights are shared between the two branches. The encoder consists of a ResNet-50 [6] backbone followed by a projection MLP head: a three-layer MLP with BN applied to each FC layer, including the output layer, and hidden FC layers of dimension 2048.

The encoder output for the first view x_1 is followed by a prediction MLP h (applied to both views once the loss is symmetrized). It is a two-layer MLP with BN applied to its hidden FC layer but not to its output FC layer, and it has a bottleneck structure: input and output dimension 2048, hidden dimension 512.
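To make the layer sizes above concrete, the following minimal sketch lists the fully connected layers of the projection MLP and of the predictor h as (input, output) dimension pairs and counts their weight entries. It assumes that the ResNet-50 backbone yields a 2048-d pooled feature and ignores biases and BN parameters; the names `PROJECTOR`, `PREDICTOR`, `weight_count` and `is_chained` are hypothetical and not taken from the authors' code.

```python
# Hypothetical (in_dim, out_dim) description of the FC layers; biases and BN
# parameters are ignored. Assumption: the ResNet-50 backbone yields a 2048-d feature.
BACKBONE_DIM = 2048

# Projection MLP: 3 FC layers (BN after each; no ReLU after the output layer).
PROJECTOR = [(BACKBONE_DIM, 2048), (2048, 2048), (2048, 2048)]

# Predictor h: bottleneck 2048 -> 512 -> 2048 (BN only on the hidden layer).
PREDICTOR = [(2048, 512), (512, 2048)]


def weight_count(layers):
    """Total number of weight-matrix entries in a stack of FC layers."""
    return sum(d_in * d_out for d_in, d_out in layers)


def is_chained(layers):
    """True if every layer's input size equals the previous layer's output size."""
    return all(prev[1] == nxt[0] for prev, nxt in zip(layers, layers[1:]))


# h maps z back into the same 2048-d space, so p and z can be compared directly.
assert PREDICTOR[0][0] == PROJECTOR[-1][1] == PREDICTOR[-1][1]

print(weight_count(PROJECTOR))  # 12582912
print(weight_count(PREDICTOR))  # 2097152
```

In this simplified count, the bottleneck makes the predictor six times smaller than the projection MLP.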

The predictor output p_1 = h(f(x_1)) is compared through a similarity function with the encoder output z_2 = f(x_2) of the second view. Because of the stop-gradient, z_2 is treated as a constant in this term, so the encoder receives no gradient from z_2 on x_2. It does receive gradient from x_2 through p_2 in the second term of the loss, in which the roles of the two views are swapped.

The minimized function is the negative cosine similarity, which is equivalent to the mean squared error of l_2-normalized vectors up to a scale of 2 (and an additive constant):

D(p_1, z_2) = - \frac{p_1}{\Vert p_1 \Vert_2} \cdot \frac{z_2}{\Vert z_2 \Vert_2}
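The equivalence with the squared error of l_2-normalized vectors follows from \Vert \hat p - \hat z \Vert_2^2 = 2 - 2 \hat p \cdot \hat z = 2 + 2 D(p, z) for unit vectors \hat p and \hat z. The following minimal sketch checks this numerically on plain Python lists; the helper names are hypothetical and the vectors are arbitrary example values.

```python
import math


def l2_normalize(v):
    """Divide a vector by its l2 norm."""
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]


def neg_cosine(p, z):
    """D(p, z) = -(p / ||p||_2) . (z / ||z||_2)."""
    p_hat, z_hat = l2_normalize(p), l2_normalize(z)
    return -sum(a * b for a, b in zip(p_hat, z_hat))


def normalized_sq_error(p, z):
    """Squared l2 distance between the l2-normalized vectors."""
    p_hat, z_hat = l2_normalize(p), l2_normalize(z)
    return sum((a - b) ** 2 for a, b in zip(p_hat, z_hat))


p, z = [1.0, 2.0, 3.0], [2.0, 0.5, -1.0]
d_val = neg_cosine(p, z)
# ||p_hat - z_hat||^2 = 2 - 2 cos(p, z) = 2 + 2 D(p, z)
print(d_val, normalized_sq_error(p, z), 2 + 2 * d_val)
```

Because the two quantities differ only by a positive scale and a constant, minimizing one minimizes the other; D reaches its minimum of -1 when p and z point in the same direction.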

Using D, the following symmetrized loss is defined for each image, with stop-gradient applied to the outputs z_i; the total loss is the average over all images:

\mathcal L = \frac{1}{2} D(p_1, \text{stopgrad}(z_2)) + \frac{1}{2} D(p_2, \text{stopgrad}(z_1))
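The symmetrized loss can be sketched as follows for a batch of images, assuming plain Python lists of vectors in place of tensors. Stop-gradient does not change the value of the loss, only where gradients flow, so this value-only sketch cannot show it; in an autodiff framework it would typically be expressed by detaching z (the paper's PyTorch-style pseudocode uses `z.detach()`). The function names are hypothetical.

```python
import math


def neg_cosine(p, z):
    """Negative cosine similarity D(p, z).

    In an autodiff framework z would be detached here (stop-gradient);
    this value-only version has no gradients to stop.
    """
    dot = sum(a * b for a, b in zip(p, z))
    norm_p = math.sqrt(sum(a * a for a in p))
    norm_z = math.sqrt(sum(b * b for b in z))
    return -dot / (norm_p * norm_z)


def simsiam_loss(p1, p2, z1, z2):
    """Symmetrized loss averaged over a batch.

    p1, p2: predictor outputs for the two views (one vector per image)
    z1, z2: encoder outputs for the two views (one vector per image)
    """
    per_image = [
        0.5 * neg_cosine(a1, b2) + 0.5 * neg_cosine(a2, b1)
        for a1, a2, b1, b2 in zip(p1, p2, z1, z2)
    ]
    return sum(per_image) / len(per_image)


# One image whose predictions point exactly at the opposite view's encoding.
print(simsiam_loss([[1.0, 0.0]], [[0.0, 1.0]], [[0.0, 2.0]], [[3.0, 0.0]]))
```

The loss lies in [-1, 1]; the example above attains the minimum of -1 because each p_i is parallel to the z of the other view.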

The baseline optimization setup uses SGD with a momentum of 0.9, a learning rate of lr = base lr × BatchSize/256 with base lr = 0.05 and a cosine decay schedule, a weight decay of 0.0001, and a batch size of 512.
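A minimal sketch of this learning-rate rule, combining the linear batch-size scaling with a cosine decay to zero. It assumes the schedule is evaluated once per epoch (a per-iteration variant is equally plausible); the function name `simsiam_lr` is hypothetical.

```python
import math


def simsiam_lr(epoch, total_epochs, batch_size, base_lr=0.05):
    """Cosine-decayed learning rate with linear batch-size scaling.

    Starts at base_lr * batch_size / 256 and decays to 0 at total_epochs.
    Assumption: evaluated once per epoch.
    """
    init_lr = base_lr * batch_size / 256
    return init_lr * 0.5 * (1.0 + math.cos(math.pi * epoch / total_epochs))


# Example: batch size 512 over a 100-epoch run (run length chosen for illustration).
for e in (0, 25, 50, 75, 100):
    print(e, round(simsiam_lr(e, 100, 512), 4))
```

With the baseline batch size of 512, the initial learning rate is 0.05 × 512/256 = 0.1.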

The authors also propose a hypothesis about the optimization problem that SimSiam implicitly solves. It can be seen as an implementation of an Expectation-Maximization (EM)-like algorithm that involves two sets of variables and solves two underlying subproblems. Under this hypothesis, the stop-gradient operation is a consequence of introducing the extra set of variables.

The hypothesized optimization problem is \min_{\theta, \eta} \mathcal L(\theta, \eta), with

\mathcal L(\theta, \eta) = \mathbb E_{x, \mathcal T} \left[ \Vert \mathcal F_\theta(\mathcal T(x)) - \eta_x \Vert^2_2 \right]

where \mathcal F_\theta is a network with parameters \theta, \mathcal T is a random augmentation, x is an image, and \eta_x is the entry of the second set of variables \eta that serves as the representation of image x. As in k-means, the problem can be solved with an alternating algorithm that fixes one set of variables and solves for the other:

\theta^t \leftarrow \arg \min_\theta \mathcal L(\theta, \eta^{t-1})
\eta^t \leftarrow \arg \min_\eta \mathcal L(\theta^t, \eta)
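With \theta fixed, the \eta subproblem is an ordinary least-squares fit for each image, so its minimizer is the average of the network outputs over augmentations. The toy sketch below illustrates this for a single image; the output vectors are hypothetical stand-ins for \mathcal F_\theta(\mathcal T(x)) under three sampled augmentations, and the helper names are illustrative.

```python
def eta_step(outputs):
    """Closed-form minimizer of sum_T ||F(T(x)) - eta||^2 for one image x.

    outputs: vectors F_theta(T(x)) for several sampled augmentations T,
    with theta held fixed. The minimizer is their element-wise mean.
    """
    n = len(outputs)
    return [sum(col) / n for col in zip(*outputs)]


def sq_loss(outputs, eta):
    """Sum of squared distances between each output and eta."""
    return sum(sum((a - b) ** 2 for a, b in zip(o, eta)) for o in outputs)


# Hypothetical network outputs for three augmentations of one image.
outs = [[1.0, 0.0], [3.0, 2.0], [2.0, 1.0]]
eta = eta_step(outs)
print(eta, sq_loss(outs, eta))  # [2.0, 1.0] 4.0
```

Any other choice of \eta_x, for example [2.1, 1.0], gives a strictly larger loss.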

The \theta subproblem can be solved with SGD while treating \eta^{t-1} as a constant, so stopping the gradient into \eta^{t-1} arises naturally. The \eta subproblem decouples into an independent problem for each \eta_x. Symmetrisation is left out of the hypothesis because, beyond a small accuracy gain, it does not play a central role.
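Following the paper's argument, the two steps can be chained explicitly. The \eta step has the closed-form solution below (an expectation over augmentations); approximating that expectation with a single sampled augmentation \mathcal T' and substituting it into the next \theta step yields a Siamese objective in which the second branch \mathcal F_{\theta^t} is a constant, i.e. a stop-gradient:

\eta^t_x \leftarrow \mathbb E_{\mathcal T} \left[ \mathcal F_{\theta^t}(\mathcal T(x)) \right]
\eta^t_x \leftarrow \mathcal F_{\theta^t}(\mathcal T'(x))
\theta^{t+1} \leftarrow \arg \min_\theta \mathbb E_{x, \mathcal T} \left[ \Vert \mathcal F_\theta(\mathcal T(x)) - \mathcal F_{\theta^t}(\mathcal T'(x)) \Vert^2_2 \right]

The predictor h does not appear in this derivation; the authors further hypothesize that h helps compensate for the single-sample approximation of the expectation.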

3. Experiments and Results

To identify which components prevent collapsing solutions, the authors carry out an empirical study of the different architectural and optimization choices. They also compare SimSiam with the related methods introduced above in several settings.

3.1. Why do SimSiam’s solutions not collapse?

Figure 2 shows that without stop-gradient the optimizer quickly finds a degenerate solution: the training loss reaches its minimum possible value of -1 and the accuracy is only 0.1%, whereas with stop-gradient the model trains well. To show that the degenerate solution is caused by collapse, the authors plot, in the middle panel, the per-channel standard deviation of the l_2-normalized output \frac{z}{\Vert z \Vert_2}, which is zero if the output collapses to a constant vector. They also note that if z follows a zero-mean isotropic Gaussian distribution, this standard deviation is \frac{1}{\sqrt{d}}, where d is the output dimension; with stop-gradient, the measured value stays close to it.

Figure 2 - SimSiam with vs. without stop-gradient. Left plot: training loss. Middle plot: per-channel std of the l_2-normalized output, plotted as the average std over all channels. Right plot: validation accuracy of a kNN classifier. Table: ImageNet linear evaluation (“w/ stop-grad” is mean±std over 5 trials).
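The collapse diagnostic described above can be reproduced on synthetic data. The sketch below computes the per-channel std of l_2-normalized outputs, averaged over channels, once for samples from a zero-mean isotropic Gaussian and once for a collapsed output that is the same constant vector for every input. The dimension, sample count and seed are arbitrary choices for illustration, not values from the paper.

```python
import math
import random


def l2_normalize(v):
    """Divide a vector by its l2 norm."""
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]


def mean_channel_std(outputs):
    """Per-channel std of the l2-normalized outputs, averaged over channels."""
    normed = [l2_normalize(z) for z in outputs]
    n, dim = len(normed), len(normed[0])
    total = 0.0
    for j in range(dim):
        col = [z[j] for z in normed]
        mean = sum(col) / n
        total += math.sqrt(sum((c - mean) ** 2 for c in col) / n)
    return total / dim


rng = random.Random(0)
dim = 128
gaussian = [[rng.gauss(0.0, 1.0) for _ in range(dim)] for _ in range(2000)]
collapsed = [[0.3] * dim for _ in range(2000)]  # every input maps to the same z

print(mean_channel_std(gaussian), 1 / math.sqrt(dim))  # close to each other
print(mean_channel_std(collapsed))                     # numerically zero
```

For the Gaussian case each normalized coordinate has mean 0 and variance exactly 1/d by symmetry, which is where the \frac{1}{\sqrt{d}} reference value comes from.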

For the predictor MLP h, the model has been tested without a predictor (h as the identity mapping), which does not work. Keeping h fixed at its random initialization does not work either, so the predictor needs to be trained. Training h with a constant (non-decayed) learning rate works at least as well as the cosine-decayed schedule.

Regarding batch size, while other methods (SimCLR and SwAV) require large batches and specialized optimizers such as LARS [7], Table 1 shows that SimSiam works well across a wide range of batch sizes. With a batch size of 4096 the accuracy obtained with plain SGD drops somewhat, but the results show that neither a specialized optimizer nor a large batch is necessary for preventing collapsing solutions.

Table 1 – Effect of batch sizes

Tables 2 and 3 show that neither the choice of similarity function nor symmetrisation is what prevents collapse. With a cross-entropy similarity, D(p_1, z_2) = -softmax(z_2) \cdot \log softmax(p_1), the model converges to a reasonable result without collapsing, and symmetrisation only boosts accuracy without affecting collapse. Different BN configurations have also been tested; they do not appear to be responsible for preventing collapse, and the configuration described in Section 2 performs best.

Table 2 – Similarity function comparison

Table 3 – Types of symmetrisation comparison
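The cross-entropy variant of D from Table 2 can be sketched as follows, assuming the softmax is taken over the channel dimension of a single output vector. The helper names are hypothetical, and only the loss value is computed (z plays the role of the stop-gradient target).

```python
import math


def log_softmax(v):
    """Numerically stable log-softmax over the channel dimension."""
    m = max(v)
    log_sum = m + math.log(sum(math.exp(x - m) for x in v))
    return [x - log_sum for x in v]


def cross_entropy_similarity(p, z):
    """D(p, z) = -softmax(z) . log softmax(p), with z as the target."""
    log_q = log_softmax(p)
    target = [math.exp(lz) for lz in log_softmax(z)]
    return -sum(t * lq for t, lq in zip(target, log_q))


p, z = [0.0, 1.0, 2.0], [2.0, 0.0, -1.0]
# By Gibbs' inequality, D(p, z) is smallest when softmax(p) == softmax(z),
# where it equals the entropy of softmax(z).
print(cross_entropy_similarity(p, z), cross_entropy_similarity(z, z))
```

Unlike the negative cosine similarity, whose minimum is always -1, the minimum of this loss depends on the target: it equals the entropy of softmax(z_2).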

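The authors conclude that stop-gradient plays an essential role in preventing collapsing solutions, whereas batch size, the similarity function, symmetrisation and BN placement do not; as shown above, a trainable predictor is also required.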

3.2. Comparison

Compared with the more complex methods mentioned earlier, all pre-trained on ImageNet [8] with a ResNet-50 backbone, SimSiam achieves competitive linear evaluation accuracy, outperforming some of them and coming close to the others. The results are shown in Table 4.

Table 4 – Comparison on ImageNet linear classification. All are based on ResNet-50 pre-trained with two 224×224 views

It has also been shown that SimSiam performs on par with other methods when transferred to other tasks, such as object detection on PASCAL VOC and object detection and instance segmentation on COCO, as seen in Table 5.

Table 5 – Transfer Learning

4. Conclusion

It has been shown that a simple Siamese architecture, without negative pairs, large batches or a momentum encoder and relying on stop-gradient to avoid collapsing solutions, can perform on par with or even better than more complex models.

5. Student’s Review

5.1. Strengths

The main strength is the simplicity of the resulting model and the insightful comparison with state-of-the-art models. Deep Learning often tends towards large and complex models, and this paper shows how several advanced models can be analysed and reduced to a basic common core that works well.

5.2. Weaknesses

The main weakness is that collapse prevention has only been demonstrated empirically; the EM-like explanation remains a hypothesis without a theoretical proof.

Another weakness is that pre-training has only been carried out on ImageNet. It would be interesting to see the performance when pre-training on other image datasets, as well as on other data modalities such as text.

5.3. Suggestion for Future Work

Siamese networks are also widely used in NLP for representation learning (embeddings) and similarity learning. It would be interesting to see whether these findings carry over to NLP.

5.4. Application to medical data

The resulting model could serve as a basic model or baseline in medical applications that require data representations. One example is computing image representations to search an image bank for similar images.

References

[1] Chen, X., & He, K. (2021). Exploring simple siamese representation learning. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (pp. 15750-15758). 

[2] Chen, T., Kornblith, S., Norouzi, M., & Hinton, G. (2020). A simple framework for contrastive learning of visual representations. In International conference on machine learning (pp. 1597-1607). PMLR.

[3] He, K., Fan, H., Wu, Y., Xie, S., & Girshick, R. (2020). Momentum contrast for unsupervised visual representation learning. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (pp. 9729-9738).

[4] Caron, M., Misra, I., Mairal, J., Goyal, P., Bojanowski, P., & Joulin, A. (2020). Unsupervised learning of visual features by contrasting cluster assignments. arXiv:2006.09882.

[5] Grill, J. B., Strub, F., Altché, F., Tallec, C., Richemond, P. H., Buchatskaya, E., ... & Valko, M. (2020). Bootstrap your own latent: A new approach to self-supervised learning. arXiv:2006.07733.

[6] He, K., Zhang, X., Ren, S., & Sun, J. (2016). Deep residual learning for image recognition. In Proceedings of the IEEE conference on computer vision and pattern recognition (pp. 770-778).

[7] You, Y., Gitman, I., & Ginsburg, B. (2017). Large batch training of convolutional networks. arXiv:1708.03888.

[8] Deng, J., Dong, W., Socher, R., Li, L. J., Li, K., & Fei-Fei, L. (2009). Imagenet: A large-scale hierarchical image database. In 2009 IEEE conference on computer vision and pattern recognition (pp. 248-255).

[9] Chopra, S., Hadsell, R., & LeCun, Y. (2005, June). Learning a similarity metric discriminatively, with application to face verification. In 2005 IEEE Computer Society Conference on Computer Vision and Pattern Recognition (CVPR'05) (Vol. 1, pp. 539-546). IEEE.
