This is a blog post about the paper ‘MixMatch: A Holistic Approach to Semi-Supervised Learning’.
The paper was written by David Berthelot, Nicholas Carlini, Ian Goodfellow, Avital Oliver, Nicolas Papernot and Colin Raffel.
Introduction and Problem Statement
By leveraging large collections of labeled data, deep neural networks can achieve human-level performance on many tasks. In practice, however, creating large, fully labeled datasets can be tedious, error-prone, and expensive, especially in medical domains where expert knowledge is required.
We can reduce the need for labels by training a model on a small labeled dataset together with a large unlabeled dataset. This setting is known as semi-supervised learning.
Semi-Supervised learning (SSL) overview
Traditionally, there have been two fundamentally different types of tasks in machine learning[1]: unsupervised learning, which trains a model on data without any labels, and supervised learning, which trains a model on a fully labeled dataset. Semi-supervised learning lies between the two, as represented in fig. 1:
Figure 1. Types of learning[2]
The general objective of any semi-supervised learning approach is to use unlabeled data as a regularizer so that the learner's performance improves. For an SSL approach to work, i.e. to make use of unlabeled data, certain assumptions need to hold.
Semi-Supervised Assumptions
The assumptions below make it possible to generalize from a finite labeled dataset to the unseen test data:
- Smoothness assumption: If two points x_1, x_2 in a high-density region are close, then so should be the corresponding outputs y_1, y_2[1].
- Cluster assumption: If points are in the same cluster, they are likely to be of the same class[1]. This is equivalent to saying that the decision boundary should lie in a low-density region.
- Manifold assumption: The high-dimensional data lie (roughly) on a low-dimensional manifold, which lets learning methods avoid the curse of dimensionality.
Most SSL approaches implement one of the above assumptions; two common examples are consistency regularization and entropy minimization:
- Consistency regularization uses data augmentation as a regularization technique: each unlabeled data point should be classified the same as its augmented versions.
- Entropy minimization is a means to implement the cluster assumption. Entropy is a measure of class overlap[8]: as class overlap decreases, the density of data points at the decision boundary gets lower.
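To make the entropy view concrete, here is a small sketch that computes the Shannon entropy of a confident and of an uncertain prediction. It assumes, for illustration only, that predictions are plain Python lists of class probabilities. Entropy minimization pushes the model's outputs on unlabeled data toward the low-entropy case.

```python
import math

def entropy(p):
    """Shannon entropy (in nats) of a discrete class distribution p."""
    # Classes with zero probability contribute 0 by convention.
    return -sum(pi * math.log(pi) for pi in p if pi > 0)

confident = [0.9, 0.05, 0.05]  # mass concentrated on one class
uncertain = [0.4, 0.3, 0.3]    # mass spread over overlapping classes

print(round(entropy(confident), 3))  # → 0.394
print(round(entropy(uncertain), 3))  # → 1.089
```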
Recent state-of-the-art SSL techniques, such as the Π-Model[4], Mean Teacher[7], Virtual Adversarial Training (VAT)[5] and Pseudo-Label[6], improve generalization to unseen data by adding one of the above-mentioned approaches, or traditional regularization, as a term in their loss.
In this paper, a new SSL technique called MixMatch is introduced. Unlike the techniques above, MixMatch unifies the dominant approaches (consistency regularization, entropy minimization, and traditional regularization) in a single algorithm that targets all of their properties at once. With this combination, MixMatch achieves state-of-the-art results on all four standard image datasets it is evaluated on.
The high-level idea of MixMatch is to guess low-entropy labels for augmented unlabeled data and to apply further regularization by using MixUp on both labeled and unlabeled data.
Methodology
MixMatch implements each of the three approaches in its own way:
- Consistency regularization - by applying data augmentation to both labeled and unlabeled data
- Entropy minimization - by label guessing and sharpening on the unlabeled data
- Traditional regularization - by MixUp, a modern regularizer that encourages the model to behave linearly between data points
Below, we explain each of these components in detail, starting with data augmentation.
Data Augmentation
Data augmentation is widely used to implement consistency regularization, for example in the Π-Model[4] and Mean Teacher[7]. MixMatch uses standard data augmentation for images, such as random horizontal flips, crops, and rotations.
In MixMatch, both labeled and unlabeled data are augmented: each labeled example x_b is augmented once, while each unlabeled example u_b is augmented K times:
\tilde x_b = \mathrm{Augment}(x_b) \\ \text{for } k = 1 \text{ to } K: \\ \quad \tilde u_{b,k} = \mathrm{Augment}(u_b) \\ \text{end for}
The K augmentations of each unlabeled example are used in the label-guessing step, described next.
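A minimal sketch of this augmentation step, assuming images are small 2D lists of pixel values. The `augment` function is hypothetical: a random horizontal flip followed by a zero-padded random crop, standing in for the standard image augmentations mentioned above rather than reproducing the authors' exact pipeline.

```python
import random

def augment(image, pad=2, rng=random):
    """Hypothetical augmentation: random horizontal flip, then a padded random crop.

    `image` is an H x W list of lists of pixel values.
    """
    h, w = len(image), len(image[0])
    if rng.random() < 0.5:  # flip left-right with probability 0.5
        image = [row[::-1] for row in image]
    # Zero-pad `pad` pixels on every side, then crop an H x W window at a random offset.
    blank = [0] * (w + 2 * pad)
    padded = [blank[:] for _ in range(pad)]
    padded += [[0] * pad + row + [0] * pad for row in image]
    padded += [blank[:] for _ in range(pad)]
    top, left = rng.randint(0, 2 * pad), rng.randint(0, 2 * pad)
    return [row[left:left + w] for row in padded[top:top + h]]

def augment_batch(labeled, unlabeled, K=2, rng=random):
    """One augmentation per labeled image x_b, K augmentations per unlabeled image u_b."""
    x_tilde = [augment(x, rng=rng) for x in labeled]
    u_tilde = [[augment(u, rng=rng) for _ in range(K)] for u in unlabeled]
    return x_tilde, u_tilde

rng = random.Random(0)
img = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
x_tilde, u_tilde = augment_batch([img], [img, img], K=2, rng=rng)
print(len(x_tilde), len(u_tilde), len(u_tilde[0]))  # → 1 2 2
```

Each unlabeled image ends up with a list of K augmented copies, which is exactly what the label-guessing step consumes.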
Label Guessing
In this step, a “guessed” label \bar q_b is produced for each unlabeled example u_b by averaging the model's predicted class distributions over the K augmentations of u_b:
\bar q_b = \frac{1}{K} \sum_{k=1}^{K} p_{model}(y \mid \tilde u_{b,k}; \theta)
Sharpening \bar q_b then yields the final guessed label q_b for the unlabeled data point.
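The averaging step is simple enough to write out directly. In this sketch, `guess_label` is a hypothetical helper, and the model outputs are made-up probability vectors standing in for p_model(y | \tilde u_{b,k}; θ) on K = 2 augmentations of one unlabeled image.

```python
def guess_label(predictions):
    """Average of K predicted class distributions for one unlabeled example u_b.

    `predictions[k]` stands in for p_model(y | u~_{b,k}; theta): a list of L
    class probabilities for the k-th augmentation.
    """
    K = len(predictions)
    num_classes = len(predictions[0])
    return [sum(p[i] for p in predictions) / K for i in range(num_classes)]

# Hypothetical model outputs for K = 2 augmentations of one unlabeled image.
preds = [[0.6, 0.3, 0.1], [0.4, 0.5, 0.1]]
q_bar = guess_label(preds)
print([round(v, 2) for v in q_bar])  # → [0.5, 0.4, 0.1]
```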
Sharpening
Figure 2. The effect of sharpening on a randomly generated distribution[8]
In MixMatch, Sharpening is applied to the guessed label \bar q_b by using the equation below:
\mathrm{Sharpen}(p, T)_i = p_i^{1/T} \Big/ \sum_{j=1}^{L} p_j^{1/T}
- p is the average prediction over the K augmentations, i.e. p = \bar q_b
- T is a temperature hyperparameter
- L is the number of classes
Fig. 2 shows that as T goes to 0, the output of Sharpen(p, T) approaches a one-hot distribution.
Unlike some other approaches, for example VAT[5], MixMatch does not add an entropy term to the loss function to achieve entropy minimization; it only applies Sharpen(p, T) to the guessed labels. Because the model is trained to match these low-entropy targets, it is encouraged to output more confident predictions on unlabeled data, which moves the decision boundary away from the data.
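The sharpening function translates directly into code. This sketch assumes \bar q_b is given as a list of class probabilities; it shows that T = 1 leaves the distribution unchanged, while a smaller T shifts mass toward the most likely class.

```python
def sharpen(p, T):
    """Sharpen(p, T)_i = p_i^(1/T) / sum_j p_j^(1/T) for a list of class probabilities p."""
    powered = [pi ** (1.0 / T) for pi in p]
    total = sum(powered)
    return [v / total for v in powered]

q_bar = [0.5, 0.4, 0.1]
print([round(v, 3) for v in sharpen(q_bar, T=1.0)])  # → [0.5, 0.4, 0.1]
print([round(v, 3) for v in sharpen(q_bar, T=0.5)])  # → [0.595, 0.381, 0.024]
```

Lowering T further (for example T = 0.1) puts more than 90% of the mass on the first class here, matching the one-hot limit described above.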
Figure 3. Diagram of the label guessing process used in MixMatch.
Fig. 3 gives a good overview of the label-guessing process in MixMatch. It shows how the steps from data augmentation to sharpening connect to produce the final guessed label.
MixUp
MixUp is a form of data-agnostic data augmentation that encourages the model to behave linearly in between training examples [9]. This is achieved by taking a convex combination of two training samples and of their corresponding labels:
\tilde x = \lambda x_i + (1-\lambda) x_j \\ \tilde y = \lambda y_i + (1-\lambda) y_j \\ \lambda \sim \mathrm{Beta}(\alpha, \alpha); \quad \alpha \in (0, \infty); \quad \lambda \in [0, 1]
α is a hyperparameter that needs to be tuned.
Figure 4: Visual comparison of different random mixing weight distributions. [10]
For small values α < 1, the values of λ drawn from the Beta distribution tend to lie close to 0 or 1, so each mixed example stays close to one of the two originals and the effect of MixUp is smaller. In the right image of Figure 4, α > 1, so the random λ values concentrate around 0.5 and the mixing is stronger.
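The effect of α can be checked numerically. This sketch uses Python's `random.betavariate` to draw λ ~ Beta(α, α) and reports how often λ falls in the middle range [0.25, 0.75], where both mixed examples contribute substantially. The range, sample size, and α values are arbitrary choices for illustration.

```python
import random

def fraction_near_half(alpha, n=10_000, seed=0):
    """Fraction of lambda ~ Beta(alpha, alpha) samples that land in [0.25, 0.75]."""
    rng = random.Random(seed)
    samples = (rng.betavariate(alpha, alpha) for _ in range(n))
    return sum(0.25 <= lam <= 0.75 for lam in samples) / n

for alpha in (0.2, 1.0, 5.0):
    print(alpha, round(fraction_near_half(alpha), 2))
```

Beta(1, 1) is the uniform distribution, so for α = 1 the fraction should be close to 0.5; smaller α gives a lower fraction (λ near 0 or 1) and larger α a higher one (λ near 0.5).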
The paper, however, slightly modifies MixUp by replacing λ with λ' = max(λ, 1 - λ):
\lambda \sim Beta(\alpha, \alpha) \\ \lambda ^{'} = max(\lambda, 1 - \lambda) \\ x ^{'}= \lambda^{'} x_i + (1-\lambda^{'}) x_j \\ y^{'} = \lambda^{'} y_i + (1-\lambda^{'}) y_j
This biases MixUp toward the first input: since λ' ≥ 0.5, x' is always closer to x_i than to x_j.
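Both the standard and the modified MixUp fit in one small function. In this sketch, inputs and labels are assumed to be flat lists of floats, and the `bias_to_first` flag is a hypothetical switch: with it set, λ is replaced by max(λ, 1 - λ) as described above; without it, the function performs standard MixUp.

```python
import random

def mixup(x_i, y_i, x_j, y_j, alpha, bias_to_first=True, rng=random):
    """MixUp of two examples (flat lists of floats) and their label vectors.

    With bias_to_first=True, lambda is replaced by max(lambda, 1 - lambda),
    as in MixMatch, so the result stays closer to (x_i, y_i).
    """
    lam = rng.betavariate(alpha, alpha)  # lambda ~ Beta(alpha, alpha)
    if bias_to_first:
        lam = max(lam, 1 - lam)          # lambda' >= 0.5
    x = [lam * a + (1 - lam) * b for a, b in zip(x_i, x_j)]
    y = [lam * a + (1 - lam) * b for a, b in zip(y_i, y_j)]
    return x, y, lam

rng = random.Random(0)
x, y, lam = mixup([1.0, 1.0], [1.0, 0.0], [0.0, 0.0], [0.0, 1.0], alpha=0.75, rng=rng)
print(lam >= 0.5, x, y)
```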
Before applying MixUp, the augmented labeled examples with their labels and the augmented unlabeled examples with their guessed labels are concatenated and shuffled; all K augmentations of u_b share the same guessed label q_b:
\tilde X = ((\tilde x_b, y_b); b \in (1, \dots, B)) \\ \tilde U = ((\tilde u_{b,k}, q_b); b \in (1, \dots, B), k \in (1, \dots, K)) \\ W = \mathrm{Shuffle}(\mathrm{Concat}(\tilde X, \tilde U))
MixUp is then applied to the labeled and to the unlabeled data, each mixed with entries of W:
X' = (\mathrm{MixUp}(\tilde X_i, W_i); i \in (1, \dots, |\tilde X|)) \\ U' = (\mathrm{MixUp}(\tilde U_i, W_{i+|\tilde X|}); i \in (1, \dots, |\tilde U|))
The authors bias MixUp toward the original example because separate losses are computed for labeled and unlabeled data. Since W can contain both labeled and unlabeled examples, taking λ' ≥ 0.5 keeps each entry of X' closer to a labeled example and each entry of U' closer to an unlabeled one, which preserves the order needed to compute the losses for X' and U' appropriately.
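The concatenate, shuffle, and mix steps can be sketched together. This is an illustration under stated assumptions, not the authors' code: `mix_batches` is a hypothetical helper, examples are (input, label) pairs of flat float lists, and the first |X| entries of the shuffled W are mixed into the labeled batch while the rest are mixed into the unlabeled batch.

```python
import random

def mix_batches(X, U, alpha=0.75, rng=random):
    """Sketch of MixMatch's batch-level MixUp.

    X: list of (x, y) pairs, augmented labeled examples with their labels.
    U: list of (u, q) pairs, augmented unlabeled examples with guessed labels.
    Inputs and labels are flat lists of floats (an assumption of this sketch).
    """
    def mix(first, second):
        lam = rng.betavariate(alpha, alpha)
        lam = max(lam, 1 - lam)  # keep the mixed pair closer to `first`
        x = [lam * a + (1 - lam) * b for a, b in zip(first[0], second[0])]
        y = [lam * a + (1 - lam) * b for a, b in zip(first[1], second[1])]
        return x, y

    W = list(X) + list(U)  # Concat(X~, U~)
    rng.shuffle(W)         # Shuffle
    X_prime = [mix(X[i], W[i]) for i in range(len(X))]
    U_prime = [mix(U[i], W[i + len(X)]) for i in range(len(U))]
    return X_prime, U_prime

X = [([1.0, 0.0], [1.0, 0.0]), ([0.0, 1.0], [0.0, 1.0])]
U = [([0.5, 0.5], [0.7, 0.3]), ([0.2, 0.8], [0.1, 0.9])]
X_prime, U_prime = mix_batches(X, U, rng=random.Random(0))
print(len(X_prime), len(U_prime))  # → 2 2
```

Because every mixed pair is dominated by its first argument, X' stays "mostly labeled" and U' "mostly unlabeled", which is what allows the two losses to be computed separately.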
Loss Function
After MixUp, separate losses are computed for the mixed labeled batch X' and for the mixed unlabeled batch U' with its guessed labels, and the two are combined into a single training loss:
X', U' = \mathrm{MixMatch}(X, U, T, K, \alpha) \\ L_x = \frac{1}{|X'|} \sum_{x, p \in X'} H(p, p_{model}(y \mid x; \theta)) \\ L_u = \frac{1}{L|U'|} \sum_{u, q \in U'} \| q - p_{model}(y \mid u; \theta) \|_2^2 \\ L = L_x + \lambda_u L_u
For labeled data, the cross-entropy H(p, q) is used; for unlabeled data, the squared L2 distance, i.e. the mean squared error over the L classes. T, K, α, and λ_u are hyperparameters, where λ_u weights the unlabeled loss.
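The two loss terms can be sketched in plain Python. Here the targets come from the mixed batches and the predictions are made-up probability vectors standing in for p_model; `mixmatch_loss` is a hypothetical helper, and the value λ_u = 10 in the usage example is arbitrary rather than a recommended setting.

```python
import math

def cross_entropy(p, q, eps=1e-12):
    """H(p, q) = -sum_i p_i log q_i, with target p and prediction q."""
    return -sum(pi * math.log(qi + eps) for pi, qi in zip(p, q))

def mixmatch_loss(X_prime, X_preds, U_prime, U_preds, lambda_u):
    """Total loss L = L_x + lambda_u * L_u.

    X_prime / U_prime: lists of (input, target) pairs after MixUp.
    X_preds / U_preds: model predictions p_model(y | . ; theta) for those inputs.
    """
    L_x = sum(cross_entropy(t, p) for (_, t), p in zip(X_prime, X_preds)) / len(X_prime)
    num_classes = len(U_prime[0][1])
    # Squared L2 distance, averaged over the batch and over the L classes.
    L_u = sum(sum((qi - pi) ** 2 for qi, pi in zip(q, p))
              for (_, q), p in zip(U_prime, U_preds)) / (num_classes * len(U_prime))
    return L_x + lambda_u * L_u, L_x, L_u

X_prime = [([0.0], [1.0, 0.0])]
X_preds = [[0.8, 0.2]]
U_prime = [([0.0], [0.5, 0.5])]
U_preds = [[0.7, 0.3]]
total, L_x, L_u = mixmatch_loss(X_prime, X_preds, U_prime, U_preds, lambda_u=10.0)
print(round(L_x, 4), round(L_u, 4), round(total, 4))  # → 0.2231 0.04 0.6231
```

The squared-error term is bounded, unlike cross-entropy, which makes it less sensitive to incorrect guessed labels.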
Results and Conclusions
Experimental Setup
The “Wide ResNet-28” model from [11] is used in all of the experiments. Four standard benchmark datasets are used for evaluation: CIFAR-10, CIFAR-100, SVHN, and STL-10. Standard practice for evaluating SSL is to treat most of the dataset as unlabeled and use only a small portion as labeled data.
CIFAR-10 results
Fig. 5 shows the error rates of MixMatch compared to other baselines. “Supervised” means MixMatch is trained with all the labels in CIFAR-10, without treating any training example as unlabeled. In this fully supervised setting, MixMatch achieves an error rate of 4.17%, whereas with only 4000 labels it achieves 6.24%.
Table 1 shows that even with very few labels (250), MixMatch achieves much lower error rates than the other baselines.
Figure 5: Error rate comparison of MixMatch to baseline methods on CIFAR-10 for a varying number of labels.
Table 1: Error rate (%) for CIFAR-10.
Ablation Study
Table 2: Ablation study. All values are error rates on CIFAR-10 with 250 or 4000 labels.
The ablation study measures how much each component of MixMatch contributes to its performance, for example by using only a single augmentation (K = 1) of the unlabeled data, or by removing sharpening (T = 1). Interpolation Consistency Training (ICT) is also emulated: it applies MixUp only to unlabeled data, uses no sharpening (T = 1), and uses an exponential moving average of the model parameters for label guessing.
The study shows that sharpening and MixUp play a very important role in the performance of MixMatch.
Privacy-Preserving Learning and Generalization
The PATE[12] framework is used for privacy-preserving learning. In Table 3, a privacy loss ε below 1 corresponds to a much stronger privacy guarantee; a lower ε here means that fewer labels were used to achieve a fixed accuracy.
MixMatch achieves higher test accuracy than VAT (95.21% vs. 91.6%) with a much lower privacy loss (ε = 0.97 vs. ε = 4.96), i.e. with fewer labels.
Method | Test accuracy | Privacy loss
---|---|---
VAT | 91.6% | ε = 4.96
MixMatch | 95.21 ± 0.17% | ε = 0.97
Table 3: Accuracy-privacy trade-off achieved by MixMatch compared to Virtual Adversarial Training (VAT) baseline on SVHN
Conclusion
In this paper, a new SSL approach called MixMatch is introduced, which combines and unifies the dominant approaches in semi-supervised learning. As a result, it outperforms current SSL methods by a significant margin across many datasets. It also achieves a better accuracy-privacy trade-off for differential privacy, since it requires significantly fewer labels than other methods to achieve similar performance.
Own Review
For readers familiar with semi-supervised learning and its recent methods, this paper is very well structured and easy to read. Even for those who are not familiar with SSL, reading this paper is a good opportunity to get into the depth of the field and really understand the methods used in it.
When comparing against other baselines, the authors are fair: they further tune the baseline methods to get even better results from them. They also released their code as open source for those who want to experiment.
As for weaknesses, MixMatch introduces many hyperparameters that need to be tuned, which can cost time and resources. Also, in the ablation study the authors do not explain in detail why some components of MixMatch, for example MixUp, play such an important role in its performance. In my opinion, they miss the chance to explain the intuition behind these components rather than just showing results.
As for future work, MixMatch could be explored in other domains, for example by testing its effectiveness on medical data, and it could be studied in the context of domain adaptation and object detection.
References
[1] O. Chapelle, B. Schölkopf, A. Zien. Semi-Supervised Learning. MIT Press, 2006.
[2] A fastai/Pytorch implementation of MixMatch