This is a blog post about the paper 'Search for Better Students to Learn Distilled Knowledge' by Jindong Gu and Volker Tresp.
Introduction
Problem Statement and Motivation
Deep learning methods have come to dominate many computer vision tasks, including image classification and object detection. These impressive breakthroughs raise a natural question: why are deep neural networks so powerful? One important reason is over-parameterization. However, over-parameterization produces large models that require a huge amount of computation and time for inference, which hinders their deployment in time-sensitive or memory-limited environments.
This naturally raises another question: how can we speed up inference? Numerous works aim to accelerate it; two prominent approaches are network pruning [1] and knowledge distillation [2]. Network pruning removes redundant features or connections, reducing both model size and inference time. Knowledge distillation compresses a large model into a smaller one, and the distilled small model typically performs better than the same model trained from scratch. In this paper, the authors combine the two: they regard network pruning as an architecture search method and use it to search for a student model that can learn the distilled knowledge better.
Network Pruning
In network pruning, unimportant weights or connections are removed to reduce FLOPs and the number of parameters. Figure 1 [1] illustrates filter pruning. The second part of the figure is the kernel matrix of a convolutional layer: each small square stands for a 2D convolutional kernel, and each column represents one 3D convolutional filter. Applying one filter (one column) to the 3D input feature map produces a single 2D output feature map. Consequently, pruning a filter removes its corresponding output feature map and, with it, the kernels in the next layer that operate on that feature map. As the next kernel matrix shows, all remaining filters in the next layer simply ignore the removed feature map.
Figure 1: Pruning a filter results in removal of its corresponding feature map and related kernels in the next layer.
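The bookkeeping in Figure 1 can be sketched in a few lines of Python. This is a minimal, hypothetical illustration rather than the code of [1]: weights are plain nested lists shaped as (output filters, input channels, kernel rows, kernel columns), and filters are ranked by the sum of their absolute weights, the criterion used in [1].

```python
from typing import List, Tuple

Kernel = List[List[float]]  # one 2D k x k kernel (a small square in Figure 1)
Filter = List[Kernel]       # one 3D filter: one kernel per input channel (a column)
Layer = List[Filter]        # one filter per output feature map


def filter_l1_norm(f: Filter) -> float:
    """Sum of absolute weights of a 3D filter (the ranking score assumed here)."""
    return sum(abs(w) for kernel in f for row in kernel for w in row)


def prune_filter(layer: Layer, next_layer: Layer, j: int) -> Tuple[Layer, Layer]:
    """Remove filter j and, since feature map j disappears with it, the j-th
    input kernel of every filter in the next layer."""
    pruned_layer = [f for idx, f in enumerate(layer) if idx != j]
    pruned_next = [[k for idx, k in enumerate(f) if idx != j] for f in next_layer]
    return pruned_layer, pruned_next


def prune_weakest_filter(layer: Layer, next_layer: Layer) -> Tuple[Layer, Layer]:
    """Prune the filter with the smallest L1 norm."""
    scores = [filter_l1_norm(f) for f in layer]
    j = min(range(len(scores)), key=scores.__getitem__)
    return prune_filter(layer, next_layer, j)


# Toy example with 1x1 kernels: layer i has 3 filters over 2 input channels,
# layer i+1 has 2 filters over the 3 feature maps produced by layer i.
layer_i = [[[[1.0]], [[-2.0]]], [[[0.1]], [[0.0]]], [[[0.5]], [[0.5]]]]
layer_next = [[[[1.0]], [[2.0]], [[3.0]]], [[[4.0]], [[5.0]], [[6.0]]]]
new_i, new_next = prune_weakest_filter(layer_i, layer_next)
print(len(new_i), [len(f) for f in new_next])  # 2 [2, 2]
```

Note that removing one filter shrinks two layers: the current layer loses an output channel, and every filter of the next layer loses one input kernel.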
Knowledge Distillation
Pioneered by [3] and [2], knowledge distillation aims to compress, or distill, the knowledge of a powerful teacher network into a smaller student network, which usually has fewer layers or parameters. Typically, a student trained with distilled knowledge performs better than the same network trained from scratch.
Specifically, distillation consists of three steps. First, a well-performing teacher network, such as DenseNet [4], is trained with the ordinary cross-entropy loss at temperature $\tau = 1$, where $a_t$ denotes the logits of the teacher network and $\tau$ is the temperature:
\[ L_{CE} = CE(\mathrm{softmax}(a_t / \tau), \text{targets}) \]
Second, soft targets are generated by running inference with the teacher at a raised temperature, i.e., $\tau > 1$, on the training set or a separate transfer set. Dividing the logits by a larger $\tau$ yields a softer probability distribution over classes: the higher $\tau$, the softer the distribution.
\[ \text{soft\_targets} = \mathrm{softmax}(a_t / \tau) \]
Finally, the student network, with logits $a_s$, is trained on the soft targets at the same high temperature, i.e., by minimizing the following cross-entropy loss:
\[ L_{CE} = CE(\mathrm{softmax}(a_s / \tau), \text{soft\_targets}) \]
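The three steps can be sketched in plain Python for a single example. This is a hedged illustration, not the authors' code: logits are lists of floats, targets may be one-hot or soft, and the function names are hypothetical.

```python
import math
from typing import List


def softmax_with_temperature(logits: List[float], tau: float = 1.0) -> List[float]:
    """softmax(a / tau); subtracting the max first keeps exp() numerically stable."""
    scaled = [a / tau for a in logits]
    m = max(scaled)
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def cross_entropy(probs: List[float], targets: List[float]) -> float:
    """CE = -sum_c targets_c * log(probs_c); targets may be one-hot or soft."""
    return -sum(q * math.log(p) for p, q in zip(probs, targets) if q > 0)


def teacher_loss(teacher_logits: List[float], one_hot: List[float]) -> float:
    """Step 1: ordinary cross-entropy at tau = 1."""
    return cross_entropy(softmax_with_temperature(teacher_logits, 1.0), one_hot)


def student_loss(student_logits: List[float], teacher_logits: List[float],
                 tau: float = 4.0) -> float:
    """Steps 2 and 3: soft targets from the teacher at temperature tau, then
    cross-entropy against the student's output at the same temperature."""
    soft_targets = softmax_with_temperature(teacher_logits, tau)
    return cross_entropy(softmax_with_temperature(student_logits, tau), soft_targets)


teacher_logits = [8.0, 2.0, 0.5]
print(max(softmax_with_temperature(teacher_logits, 1.0)))  # about 0.997
print(max(softmax_with_temperature(teacher_logits, 4.0)))  # about 0.73: softer
```

Raising $\tau$ from 1 to 4 moves probability mass from the top class to the others, which is exactly the extra information the soft targets carry for the student.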
Methodology
The student architecture for distillation is usually configured manually or chosen from state-of-the-art (SOTA) architectures at a smaller size, which might not be well suited to learning distilled knowledge. How can we find a customized student that learns the distilled knowledge better? The authors propose to combine network pruning and knowledge distillation into one step: network pruning serves as the approach for searching students, and the search process is aware of knowledge distillation, so that the student it finally finds learns the distilled knowledge better.
Student Architecture Search
In student architecture search, a gate is assigned to each channel: the feature map of that channel is multiplied by a scaling factor $g$. After optimization, channels with an open gate ($g \neq 0$) are retained, while channels with a closed gate ($g = 0$) can be removed safely, as shown in Figure 2.
Figure 2: The illustration of the proposed approach
For a layer with $K$ channels in the teacher network, $\mathbf{g}$ becomes a $K$-element vector. The method can easily be applied to SOTA neural networks such as DenseNet [4]. Figure 3 illustrates it on the third layer of a dense block.
Figure 3: An example of student architecture search on the third layer of a dense block.
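The gating mechanism can be sketched as follows, assuming feature maps are nested Python lists of shape (K, H, W) and using a toy 1x1 "convolution" (a weighted sum over channels) as a stand-in for the next layer; all names are hypothetical. The final check illustrates why a channel whose gate is exactly zero can be removed without changing the next layer's output.

```python
from typing import List

FeatureMap = List[List[float]]  # one H x W channel


def apply_gates(channels: List[FeatureMap], g: List[float]) -> List[FeatureMap]:
    """Multiply the feature map of channel k by its gate g[k] (K channels, K gates)."""
    assert len(channels) == len(g), "one gate per channel"
    return [[[g_k * x for x in row] for row in fmap] for fmap, g_k in zip(channels, g)]


def retained_channels(g: List[float]) -> List[int]:
    """Channels with an open gate (g != 0) are kept; closed ones can be removed."""
    return [k for k, g_k in enumerate(g) if g_k != 0.0]


def pointwise_sum(channels: List[FeatureMap], weights: List[float]) -> FeatureMap:
    """Toy stand-in for the next layer: a 1x1 convolution with one output channel."""
    h, w = len(channels[0]), len(channels[0][0])
    return [[sum(wk * c[i][j] for wk, c in zip(weights, channels)) for j in range(w)]
            for i in range(h)]


x = [[[1.0, 2.0]], [[3.0, 4.0]], [[5.0, 6.0]]]  # K = 3 channels of size 1 x 2
g = [1.0, 0.0, 0.5]                              # the gate of channel 1 is closed
next_w = [0.2, 0.7, -1.0]

gated = apply_gates(x, g)
keep = retained_channels(g)                      # [0, 2]
full = pointwise_sum(gated, next_w)
pruned = pointwise_sum([gated[k] for k in keep], [next_w[k] for k in keep])
print(keep, full == pruned)                      # [0, 2] True
```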
To find a student that learns the distilled knowledge better, the authors propose a distillation-aware loss function to guide the search. Given an input $\mathbf{x}_i$, the softened output of the teacher model $f_t(\mathbf{x}_i)$, and the softened output $f_s(\mathbf{x}_i, \mathbf{w}, \mathbf{g})$ of the model constructed by adding gates, the loss is defined as:
\[ \min_{\mathbf{w}, \mathbf{g}} \frac{1}{N} \sum_{i=1}^N \mathrm{KL}\big(f_s(\mathbf{x}_i, \mathbf{w}, \mathbf{g}), f_t(\mathbf{x}_i)\big) + \lambda_1 \|\mathbf{w}\|_2 + \lambda_2 \sum_{j=1}^M \alpha_j \|g_j\|_1 \]
The first term is the distillation term, which matches the softened outputs. The second term is ordinary weight regularization. The third term is an L1 norm on the gates; L1 regularization enforces sparsity, which is exactly what is needed to reduce the model size. $\alpha_j$ is a scaling factor on gate $g_j$, defined as $\alpha_j = \frac{F_j}{\max_{k=1}^{M} F_k}$, where $F_j$ is the number of FLOPs saved when $g_j$ is closed and $M$ is the total number of gates. This scaling makes the loss focus more on gates whose removal brings a larger reduction in FLOPs.
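The FLOPs-based scaling and the gate penalty can be computed as below. This is a sketch that treats each $g_j$ as a scalar gate, as in the formula above; the function names are hypothetical and the FLOPs values in the example are made up for illustration.

```python
from typing import List


def flops_scaling(saved_flops: List[float]) -> List[float]:
    """alpha_j = F_j / max_k F_k, where F_j is the FLOPs saved by closing gate j."""
    f_max = max(saved_flops)
    return [f / f_max for f in saved_flops]


def gate_penalty(g: List[float], saved_flops: List[float], lambda2: float) -> float:
    """Third loss term: lambda2 * sum_j alpha_j * |g_j|."""
    alpha = flops_scaling(saved_flops)
    return lambda2 * sum(a * abs(g_j) for a, g_j in zip(alpha, g))


saved = [4e6, 1e6, 2e6]   # hypothetical FLOPs saved by closing each of M = 3 gates
print(flops_scaling(saved))                          # [1.0, 0.25, 0.5]
print(gate_penalty([0.5, -1.0, 0.0], saved, 1e-3))   # about 7.5e-4
```

Because $\alpha_j \le 1$, with equality for the most expensive gate, the same gate magnitude is penalized most where closing it saves the most computation.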
Optimization
Proximal Gradient Descent
$\mathbf{w}$ can be updated with ordinary gradient descent. For $\mathbf{g}$, however, the third term of the loss is not differentiable, so it cannot be optimized with ordinary gradient descent. Instead, $\mathbf{g}$ is updated with a proximal method [5]. The proximal operator is defined as:
\[ \mathbf{prox}_{h,t}(x) = \arg\min_z \left( \frac{1}{2t}\|z - x\|_2^2 + h(z) \right) \]
This is the proximal operator of a function $h$ with parameter $t$, where $h$ need not be differentiable. For $h(x) = \lambda\|x\|_1$, the proximal operator is $S_{\lambda t}$, the soft-thresholding operator, defined element-wise as:
\[ [S_{\lambda}(x)]_i = \begin{cases} x_i - \lambda, & \text{if } x_i > \lambda \\ 0, & \text{if } -\lambda \leq x_i \leq \lambda \\ x_i + \lambda, & \text{if } x_i < -\lambda \end{cases} \qquad i = 1, \dots, n \]
Because the non-differentiable term $h(x)$ is usually simple, its proximal operator, such as the soft-thresholding operator above, can often be derived analytically. Equipped with the proximal operator, one can optimize a composite function $f(x) = g(x) + h(x)$, where $g(x)$ is differentiable but $h(x)$ may not be. The update steps are:
\[ x^{(k)} = \mathbf{prox}_{h,t_k}\left(x^{(k-1)} - t_k \nabla g(x^{(k-1)})\right), \quad k = 1, 2, 3, \dots \]
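A minimal pure-Python sketch of soft-thresholding and the proximal gradient update, applied to a toy problem rather than a CNN: with $g(x) = \frac{1}{2}\|x - b\|_2^2$ and $h(x) = \lambda\|x\|_1$, the minimizer is known in closed form to be $S_\lambda(b)$, which makes the result easy to check. The toy objective, the fixed step size $t$, and the function names are assumptions for illustration.

```python
from typing import Callable, List

Vector = List[float]


def soft_threshold(x: Vector, thresh: float) -> Vector:
    """Element-wise S_thresh(x): shrink each entry toward 0, zeroing |x_i| <= thresh."""
    out = []
    for xi in x:
        if xi > thresh:
            out.append(xi - thresh)
        elif xi < -thresh:
            out.append(xi + thresh)
        else:
            out.append(0.0)
    return out


def proximal_gradient_descent(grad_g: Callable[[Vector], Vector], x0: Vector,
                              lam: float, t: float, num_steps: int) -> Vector:
    """Minimize g(x) + lam * ||x||_1 via x <- S_{lam * t}(x - t * grad g(x))."""
    x = list(x0)
    for _ in range(num_steps):
        grad = grad_g(x)
        x = soft_threshold([xi - t * gi for xi, gi in zip(x, grad)], lam * t)
    return x


b = [3.0, 0.4, -2.5]


def grad_g(x: Vector) -> Vector:
    """Gradient of the toy smooth term 0.5 * ||x - b||^2."""
    return [xi - bi for xi, bi in zip(x, b)]


x_hat = proximal_gradient_descent(grad_g, [0.0, 0.0, 0.0], lam=1.0, t=0.5, num_steps=100)
print(x_hat)  # close to soft_threshold(b, 1.0) = [2.0, 0.0, -1.5]
```

The middle coordinate lands exactly on zero. That is the point of the proximal step: a plain subgradient step on the L1 term would typically oscillate around zero instead, whereas here closed gates can be identified exactly.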
Note that gradients are evaluated only for the differentiable function $g(x)$ (not to be confused with the gate vector $\mathbf{g}$). For the distillation-aware loss function, $g(x)$ is the KL-divergence term, $g(x) = \frac{1}{N}\sum_{i=1}^{N}\mathrm{KL}\big(f_s(\mathbf{x}_i, \mathbf{w}, \mathbf{g}), f_t(\mathbf{x}_i)\big)$, and $h(x)$ is the L1 regularization term, $h(x) = \lambda_2\sum_{j=1}^M \alpha_j\|g_j\|_1$. In this way, the gates $\mathbf{g}$ can be optimized.
Accelerated Proximal Gradient Descent
There is also an accelerated version of proximal gradient descent (APG), an adaptation of Nesterov accelerated gradient (NAG) [6]. Choose an initial point $x^{(0)} = x^{(-1)} \in \mathbb{R}^n$ and repeat for $k = 1, 2, 3, \dots$:
\begin{align*} v &= x^{(k-1)} + \frac{k-2}{k+1}\left(x^{(k-1)} - x^{(k-2)}\right) \\ x^{(k)} &= \mathbf{prox}_{h,t_k}\left(v - t_k \nabla g(v)\right) \end{align*}
The momentum term speeds up convergence. However, evaluating the gradient of $g(x)$ at the extrapolated point $v$ requires an additional forward and backward pass, which is costly when $g(x)$ is defined by a large CNN.
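The accelerated variant can be sketched on the same kind of toy problem, $g(x) = \frac{1}{2}\|x - b\|_2^2$ with an L1 term; the sketch is self-contained, so it repeats a compact soft-thresholding helper. The comment marks where the gradient is taken at the extrapolated point $v$, which is the step that would need the extra forward-backward pass for a CNN. The toy objective, fixed step size, and names are assumptions for illustration.

```python
from typing import Callable, List

Vector = List[float]


def soft_threshold(x: Vector, thresh: float) -> Vector:
    """Element-wise soft-thresholding S_thresh(x)."""
    return [xi - thresh if xi > thresh else xi + thresh if xi < -thresh else 0.0
            for xi in x]


def accelerated_proximal_gradient(grad_g: Callable[[Vector], Vector], x0: Vector,
                                  lam: float, t: float, num_steps: int) -> Vector:
    """APG for g(x) + lam * ||x||_1 with a fixed step size t."""
    x_prev = list(x0)  # x^(k-2); initialised so that x^(-1) = x^(0)
    x = list(x0)       # x^(k-1)
    for k in range(1, num_steps + 1):
        beta = (k - 2) / (k + 1)
        v = [xi + beta * (xi - xp) for xi, xp in zip(x, x_prev)]
        grad = grad_g(v)  # gradient at v, not at x: the extra pass for a CNN
        x_prev, x = x, soft_threshold([vi - t * gi for vi, gi in zip(v, grad)], lam * t)
    return x


b = [3.0, 0.4, -2.5]


def grad_g(x: Vector) -> Vector:
    """Gradient of the toy smooth term 0.5 * ||x - b||^2."""
    return [xi - bi for xi, bi in zip(x, b)]


x_hat = accelerated_proximal_gradient(grad_g, [0.0, 0.0, 0.0], lam=1.0, t=0.5, num_steps=200)
print(x_hat)  # close to [2.0, 0.0, -1.5]
```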
Modified NAG and Application on APG
Sutskever et al. [6] and Bengio et al. [7] reformulate Nesterov accelerated gradient in a way that avoids the additional forward-backward pass; details are in [7] and [8]. Following their derivation, the optimization steps for the gates become:
\begin{align*} \mathbf{z}^{(k)} &= \mathbf{g}^{(k-1)} - \eta \cdot \nabla l(\mathbf{g}^{(k-1)}) \\ \mathbf{v}^{(k)} &= S_{\lambda_2\eta}(\mathbf{z}^{(k)}) - \mathbf{g}^{(k-1)} + \mu^{(k-1)} \cdot \mathbf{v}^{(k-1)} \\ \mathbf{g}^{(k)} &= S_{\lambda_2\eta}(\mathbf{z}^{(k)}) + \mu^{(k)} \cdot \mathbf{v}^{(k)} \end{align*}
where $l(\mathbf{g}) = \frac{1}{N}\sum_{i=1}^{N}\mathrm{KL}\big(f_s(\mathbf{x}_i, \mathbf{w}, \mathbf{g}), f_t(\mathbf{x}_i)\big)$, $\eta$ is the gradient step size, and $\mu$ is the momentum coefficient with $\mu^{(k-1)} = \frac{k-2}{k+1}$.
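The reformulated update can be sketched on a toy problem, with $l(\mathbf{g}) = \frac{1}{2}\|\mathbf{g} - \mathbf{b}\|_2^2$ standing in for the KL term. The gradient is taken only at the current point $\mathbf{g}^{(k-1)}$, so no extra pass is needed. Assumptions: $\mathbf{v}^{(0)} = \mathbf{0}$, a fixed step size $\eta$, $\mu^{(k)} = \frac{k-1}{k+2}$ (the stated schedule shifted by one index), and a single threshold $\lambda_2\eta$ exactly as written in the equations above (per-gate weights $\alpha_j$ would turn it into per-gate thresholds $\lambda_2\eta\alpha_j$); the names are hypothetical.

```python
from typing import Callable, List

Vector = List[float]


def soft_threshold(x: Vector, thresh: float) -> Vector:
    """Element-wise soft-thresholding S_thresh(x)."""
    return [xi - thresh if xi > thresh else xi + thresh if xi < -thresh else 0.0
            for xi in x]


def modified_apg(grad_l: Callable[[Vector], Vector], g0: Vector,
                 lambda2: float, eta: float, num_steps: int) -> Vector:
    """Update the gates with one gradient evaluation per step, at g^(k-1)."""
    g = list(g0)
    v = [0.0] * len(g0)                       # assumed v^(0) = 0
    for k in range(1, num_steps + 1):
        mu_prev = (k - 2) / (k + 1)           # mu^(k-1)
        mu_curr = (k - 1) / (k + 2)           # mu^(k)
        z = [gi - eta * di for gi, di in zip(g, grad_l(g))]
        s = soft_threshold(z, lambda2 * eta)  # S_{lambda2 * eta}(z^(k))
        v = [si - gi + mu_prev * vi for si, gi, vi in zip(s, g, v)]
        g = [si + mu_curr * vi for si, vi in zip(s, v)]
    return g


b = [3.0, 0.4, -2.5]


def grad_l(g: Vector) -> Vector:
    """Gradient of the toy stand-in l(g) = 0.5 * ||g - b||^2."""
    return [gi - bi for gi, bi in zip(g, b)]


g_hat = modified_apg(grad_l, [0.0, 0.0, 0.0], lambda2=1.0, eta=0.5, num_steps=200)
print(g_hat)  # close to [2.0, 0.0, -1.5], with the middle gate exactly closed
```

Substituting the second equation into the third shows that, from $k = 2$ on, $\mathbf{v}^{(k)}$ is the difference between successive thresholded iterates, so $\mathbf{g}^{(k)}$ coincides with the extrapolated point $v$ of plain APG; the momentum is carried in $\mathbf{v}$ rather than paid for with a second gradient evaluation.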
Experiments
Dataset & Baseline Models
CIFAR10 and CIFAR100 [8] are used for training and evaluation. DenseNet (L=100, k=12), which requires 296M FLOPs, serves as the teacher model in all experiments. Several student models are chosen to verify the effectiveness of the proposed method; they are either configured manually or taken from existing popular architectures. All student architectures are specified to have around 90M FLOPs or 0.8M parameters.
Training Strategies
For training, the weight decay for $\mathbf{w}$ is set to $10^{-4}$, and all models are trained for 300 epochs with batch size 128. The learning rate starts at 0.1 and is multiplied by 0.1 at the 150th and 255th epochs. During student search, $\lambda_2 = 10^{-3}$. When students are trained under distillation, $\tau = 4$ and $\lambda = 0.1$. The search is stopped early once a student network with 90M FLOPs is found.
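The reported settings can be collected into a small configuration, together with the step learning-rate schedule and the early-stopping check. This is a hypothetical sketch: the key names are made up, epochs are assumed to be counted from 0, and $\lambda$ is recorded as reported without assuming how it enters the distillation loss.

```python
TRAIN_CONFIG = {
    "epochs": 300,
    "batch_size": 128,
    "weight_decay": 1e-4,         # applied to the weights w
    "base_lr": 0.1,
    "lr_milestones": (150, 255),  # lr is multiplied by lr_gamma at these epochs
    "lr_gamma": 0.1,
    "search_lambda2": 1e-3,       # gate sparsity weight during student search
    "target_flops": 90e6,         # search stops early once the student reaches this
    "kd_tau": 4.0,                # temperature when training students under distillation
    "kd_lambda": 0.1,             # the lambda reported for distillation training
}


def learning_rate(epoch: int, cfg: dict = TRAIN_CONFIG) -> float:
    """Step schedule: base_lr * gamma ** (number of milestones already reached)."""
    passed = sum(epoch >= m for m in cfg["lr_milestones"])
    return cfg["base_lr"] * cfg["lr_gamma"] ** passed


def search_should_stop(student_flops: float, cfg: dict = TRAIN_CONFIG) -> bool:
    """Early stopping of the search once the pruned student meets the FLOPs budget."""
    return student_flops <= cfg["target_flops"]


print([learning_rate(e) for e in (0, 149, 150, 254, 255, 299)])
```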
Results Interpretation
Results
Figure 4: The performance of all students on the two datasets. Different colors indicate different students.
Cross marks denote models trained under distillation, while circles denote models trained from scratch.
All student models are trained with the teacher DenseNet(100, 12).
As Figure 4 (a) shows, the found architecture contains fewer parameters and achieves better results on both CIFAR10 and CIFAR100. Figure 4 (b) compares test accuracy on CIFAR10 and CIFAR100 in terms of FLOPs. On CIFAR10 the result is strong: the found student achieves better performance with the fewest FLOPs. On CIFAR100, however, although it still has the fewest FLOPs, its accuracy is surpassed by some other students, which in turn perform poorly on CIFAR10. The authors therefore conclude that no single student architecture wins on all datasets.
Ablation study on loss function
In this section, the distillation-aware loss function is dissected to investigate how its individual components influence the found student and its performance. Two questions are addressed. First, is it important to use KL-divergence to match the softened outputs of the teacher and the student, i.e., does the distillation term contribute to the search? Second, is the scaling on the gates important? To answer the first question, the distillation term of the proposed loss (KD) is replaced by a cross-entropy term without raised temperature, resulting in a network pruning loss (NP). In the first row, KD and NOKD denote students trained with and without distillation, respectively.
Figure 5: Performance of models selected under different loss functions on CIFAR10 and CIFAR100 (test error, %)
Figure 5 answers both questions. Comparing the second row with the last row shows that the matching term is important, and comparing the 1st row with the 2nd, or the 3rd with the 4th, shows that the scaling factor $\alpha$ indeed plays a role in the proposed loss function.
To study the relation between the softened-output matching term and student performance, the authors do not drive the matching term to zero but instead force it toward $kl_0$, a manually defined value, and vary $kl_0$ to see how it affects accuracy. The loss function becomes the following, where $R_1$ and $R_2$ denote the two regularization terms of the original loss:
\[ \min_{\mathbf{w}, \mathbf{g}} \frac{1}{N} \sum_{i=1}^N \left| \mathrm{KL}\big(f_s(\mathbf{x}_i, \mathbf{w}, \mathbf{g}), f_t(\mathbf{x}_i)\big) - kl_0 \right| + R_1 + R_2 \]
Figure 6: The correlation between KL-divergence and the distillation performance of the selected student
As Figure 6 shows, test accuracy drops as $kl_0$ increases. The correlation between $kl_0$ and test error is 0.88, suggesting that the KL-divergence matching term largely determines what makes a good student.
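The modified matching term can be sketched for a batch of softened outputs given as probability lists. It follows the argument order $\mathrm{KL}(f_s, f_t)$ used in the loss; the function names and example distributions are hypothetical, and the regularizers $R_1$, $R_2$ are omitted.

```python
import math
from typing import List


def kl_divergence(p: List[float], q: List[float]) -> float:
    """KL(p || q) = sum_c p_c * log(p_c / q_c); classes with p_c = 0 contribute 0."""
    return sum(pc * math.log(pc / qc) for pc, qc in zip(p, q) if pc > 0)


def kl_target_term(student: List[List[float]], teacher: List[List[float]],
                   kl0: float) -> float:
    """(1/N) * sum_i |KL(f_s(x_i), f_t(x_i)) - kl0|, which is minimized when each
    per-example KL equals kl0 rather than 0."""
    n = len(student)
    return sum(abs(kl_divergence(s, t) - kl0) for s, t in zip(student, teacher)) / n


teacher = [[0.9, 0.1], [0.2, 0.8]]
student = [[0.5, 0.5], [0.2, 0.8]]
for kl0 in (0.0, 0.1, 0.5):
    print(kl0, kl_target_term(student, teacher, kl0))
```

With $kl_0 = 0$ this reduces to the original distillation term; larger $kl_0$ deliberately keeps the student's output away from the teacher's, which is what the experiment varies.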
Sensitivity Analysis
$\lambda_2$ in the loss controls how long it takes to reach a student network with the specified FLOPs. This analysis records how many epochs are needed to find a student architecture with 90M FLOPs. As the green line in Figure 7 shows, the larger $\lambda_2$, the less time it takes to find a student with the specified FLOPs. However, the test error increases as well. A large $\lambda_2$ forces the search to drop more feature maps and connections, yet in the later stages of the network almost all feature maps are necessary because they produce more representative features. A high value of $\lambda_2$ therefore reaches the desired student quickly but hurts its performance.
Figure 7: Sensitivity analysis of $\lambda_2$ on trained epochs and performance.
Visualization of Students
Here is a visualization of the found student architecture for the three dense blocks of a network. Each small colored square represents a connection, and its color intensity indicates the drop ratio, i.e., the number of removed channels divided by the number of channels before removal. Interestingly, almost no connections are removed in the third dense block. This is because the feature maps generated by the later part of the network are more representative and can express complicated patterns.
Figure 8: Visualization of a found student with three dense blocks, each dense block is represented by one picture
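The drop ratio shown in the visualization can be computed directly from the gates of a connection. A minimal sketch, assuming each connection in a dense block, keyed here as a hypothetical (source layer, target layer) pair, has its own gate vector over the channels it carries; the data below is made up for illustration.

```python
from typing import Dict, List, Tuple


def drop_ratio(g: List[float]) -> float:
    """Removed channels (closed gates, g == 0) divided by channels before removal."""
    return sum(1 for g_k in g if g_k == 0.0) / len(g)


def block_drop_ratios(gates: Dict[Tuple[int, int], List[float]]) -> Dict[Tuple[int, int], float]:
    """Map each connection (source layer, target layer) to its drop ratio,
    i.e. the color intensity of one square in the visualization."""
    return {conn: drop_ratio(g) for conn, g in gates.items()}


toy_block = {
    (0, 1): [0.0, 0.3, 0.0, 1.2],   # half of the channels dropped
    (0, 2): [0.0, 0.0, 0.0, 0.0],   # connection removed entirely
    (1, 2): [0.7, 0.1, 0.4, 0.9],   # nothing dropped
}
print(block_drop_ratios(toy_block))  # {(0, 1): 0.5, (0, 2): 1.0, (1, 2): 0.0}
```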
Conclusions
The authors propose an approach to search for a student that learns the distilled knowledge better. They regard network pruning as the search method and design a distillation-aware loss function to guide the search process. The found student performs better than manually configured students or students taken from existing mainstream architectures, and it does so even when trained from scratch.
Own View
This work has many strengths. Network pruning and knowledge distillation are integrated into one step, so the method benefits from both. A novel loss is proposed for learning a student architecture; because this loss is aware of the distillation process, the searched student learns distilled knowledge better. The optimization also uses a modified version of NAG combined with a proximal operator, which accelerates convergence. Last but not least, treating network pruning as a form of architecture search saves far more resources and time than mainstream neural architecture search (NAS) approaches, which can require hundreds of GPUs running for many days.
However, the weakness is also apparent: the space of possible students is limited by the teacher model. Since network pruning is applied to the teacher to search for a student, the found student is always a subgraph of the teacher. Because the teacher is pre-selected, its set of subgraphs is fixed as well. Therefore, the proposed method lacks flexibility in searching architectures.
For future work, an obvious direction, addressing the weakness above, is to search for more flexible students. Neural architecture search (NAS) allows finding more flexible architectures; one possible approach is to use NAS for the architecture search while designing a new loss that limits the model size. Another direction worth pursuing is on the distillation side. This work only uses the softened label information, but more knowledge could be distilled from the teacher, such as knowledge about its architecture, which might be helpful.
References
[1] Li, Hao, et al. "Pruning filters for efficient convnets." arXiv preprint arXiv:1608.08710 (2016).
[2] Hinton, Geoffrey, Oriol Vinyals, and Jeff Dean. "Distilling the knowledge in a neural network." arXiv preprint arXiv:1503.02531 (2015).
[3] Buciluǎ, Cristian, Rich Caruana, and Alexandru Niculescu-Mizil. "Model compression." Proceedings of the 12th ACM SIGKDD international conference on Knowledge discovery and data mining. 2006.
[4] Huang, Gao, et al. "Densely connected convolutional networks." Proceedings of the IEEE conference on computer vision and pattern recognition. 2017.
[5] Parikh, Neal, and Stephen Boyd. "Proximal algorithms." Foundations and Trends in optimization 1.3 (2014): 127-239.
[6] Sutskever, Ilya, et al. "On the importance of initialization and momentum in deep learning." International conference on machine learning. 2013.
[7] Bengio, Yoshua, Nicolas Boulanger-Lewandowski, and Razvan Pascanu. "Advances in optimizing recurrent networks." 2013 IEEE International Conference on Acoustics, Speech and Signal Processing. IEEE, 2013.
[8] Huang, Zehao, and Naiyan Wang. "Data-driven sparse structure selection for deep neural networks." Proceedings of the European conference on computer vision (ECCV). 2018.