1. Introduction

The performance of deep learning methods on vision tasks increases logarithmically with the size of the training data [1]. However, labelled data are scarce in the medical domain [2, 3]. Meta-learning is a promising way to alleviate this problem, as it can learn from few-shot images. Unlike traditional deep learning models, which are trained and tested on the same task, meta-learning models learn from many different training tasks and are tested on other target tasks, each of which may contain only a few samples. The model transfers its knowledge to target tasks that have little training data. This resembles human learning: someone who has seen bears before can learn to recognize a panda from a few examples, because pandas are bear-like. The knowledge gained from learning to recognize bears is what makes this possible. The previous tasks in meta-learning are therefore essential, and the way we define them is called task modelling.

This blog concentrates on task modelling in meta-learning and its application in the medical domain. Section 2 introduces a basic meta-learning model in detail. Section 3 classifies state-of-the-art task modelling techniques and gives examples of their application in the medical domain. Section 4 compares these techniques and discusses how meta-learning can help in medical applications. The last section concludes the blog and proposes future work.

2. Background

This section introduces a basic meta-learning model to illustrate the difference between conventional deep learning and meta-learning and explains why it can work in a few-shot scenario.

Model-agnostic meta-learning (MAML) [4] is a fundamental meta-learning method. Conventional deep learning methods produce a model that solves the target task directly. MAML instead produces an initial model that can be fine-tuned with only a few examples to reach good performance. Accordingly, the training and testing data are not individual input-output pairs. They are collections of training and testing tasks, and each task contains its own input-output pairs.

The key assumption of MAML is that the tasks in the task distribution share a common structure. A single set of initial weights can therefore be fine-tuned quickly for any of them. To find this common structure, each gradient-descent step must consider the loss averaged over many tasks, not the loss of a single task.

Figure 1 shows the MAML algorithm. In each iteration, a batch of tasks is sampled from the task pool. The dataset of each task \mathcal{T}_i consists of a support set (used for training) and a query set (used for validation). For each task \mathcal{T}_i, the gradient of the support-set loss is computed at the current weights \theta, and one or more gradient steps give task-specific adapted weights \theta_i'. Next, the loss \mathcal{L}_i is computed on the query set of each task using the corresponding adapted weights \theta_i', and these losses are averaged over all sampled tasks. The gradient of this averaged loss with respect to \theta gives a descent direction along which the post-adaptation losses of most tasks decrease.

Repeating this process many times yields initial weights from which a few gradient steps lead to a low loss on each task. For a new task sampled from the same distribution, these initial weights lie close to a good minimum, so they can be fine-tuned with the few examples in the task's support set. In the testing phase, the meta-model is first fine-tuned on the support set of each testing task. The fine-tuned model is then evaluated on the query set of that task, and the performance averaged over testing tasks measures the quality of the meta-model.


Figure 1. Model-agnostic meta learning (MAML) algorithm [4].
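The loop below is a minimal sketch of the MAML procedure. It makes several assumptions:

- Tasks are a toy family of 1-D linear-regression problems (y = a x + b, with random a and b).
- The model has two parameters (w, c).
- All function names and hyperparameters (alpha, beta, batch size) are hypothetical.

Each task's inner step is \theta_i' = \theta - \alpha \nabla_\theta \mathcal{L}^{s}_{\mathcal{T}_i}(\theta). The meta-update is \theta \leftarrow \theta - \beta \nabla_\theta \frac{1}{B}\sum_{i} \mathcal{L}^{q}_{\mathcal{T}_i}(\theta_i'). For a linear model with squared loss, the Hessian H_s of the support loss is constant. The exact second-order meta-gradient (I - \alpha H_s)\nabla\mathcal{L}^{q}(\theta_i') can therefore be written out by hand, with no autodiff library.

```python
import random


def mse(theta, xs, ys):
    """Mean squared error of the toy model y_hat = w * x + c."""
    w, c = theta
    return sum((w * x + c - y) ** 2 for x, y in zip(xs, ys)) / len(xs)


def mse_grad(theta, xs, ys):
    """Gradient of the MSE with respect to (w, c)."""
    w, c = theta
    n = len(xs)
    gw = sum(2.0 * (w * x + c - y) * x for x, y in zip(xs, ys)) / n
    gc = sum(2.0 * (w * x + c - y) for x, y in zip(xs, ys)) / n
    return [gw, gc]


def mse_hessian(xs):
    """Hessian of the MSE w.r.t. (w, c); constant in theta for a linear model."""
    n = len(xs)
    sxx = sum(x * x for x in xs) / n
    sx = sum(xs) / n
    return [[2.0 * sxx, 2.0 * sx], [2.0 * sx, 2.0]]


def sample_task(rng, k_support=5, k_query=10):
    """A task is a random line y = a * x + b, split into support and query sets."""
    a, b = rng.uniform(1.5, 2.5), rng.uniform(-0.5, 0.5)

    def draw(k):
        xs = [rng.uniform(-1.0, 1.0) for _ in range(k)]
        return xs, [a * x + b for x in xs]

    return draw(k_support), draw(k_query)


def adapt(theta, support, alpha):
    """Inner loop: one gradient step on the support set gives theta_i'."""
    g = mse_grad(theta, *support)
    return [t - alpha * gi for t, gi in zip(theta, g)]


def meta_gradient(theta, support, query, alpha):
    """Exact gradient of L_q(theta_i') w.r.t. theta: (I - alpha * H_s) @ grad L_q(theta_i')."""
    gq = mse_grad(adapt(theta, support, alpha), *query)
    h = mse_hessian(support[0])
    return [gq[r] - alpha * (h[r][0] * gq[0] + h[r][1] * gq[1]) for r in range(2)]


def maml_train(meta_steps=500, batch_size=8, alpha=0.1, beta=0.05, seed=0):
    rng = random.Random(seed)
    theta = [0.0, 0.0]  # meta-initialisation (w, c)
    for _ in range(meta_steps):
        grads = [meta_gradient(theta, *sample_task(rng), alpha) for _ in range(batch_size)]
        # Outer loop: step along the gradient of the task-averaged query loss.
        avg = [sum(g[r] for g in grads) / batch_size for r in range(2)]
        theta = [t - beta * a for t, a in zip(theta, avg)]
    return theta


def post_adaptation_loss(theta, alpha=0.1, n_tasks=200, seed=1):
    """Meta-test: fine-tune on each new task's support set, evaluate on its query set."""
    rng = random.Random(seed)
    total = 0.0
    for _ in range(n_tasks):
        support, query = sample_task(rng)
        total += mse(adapt(theta, support, alpha), *query)
    return total / n_tasks
```

To see the effect described above, call maml_train() and compare post_adaptation_loss for the learned initialisation against a zero initialisation. After the same single fine-tuning step, the meta-learned initialisation reaches a much lower query loss. Practical implementations obtain the second-order term with an autodiff framework instead of a hand-derived Hessian.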

Meta-learning is usually used for few-shot learning, and a representative example is N-way, K-shot classification. In both the meta-training and meta-testing phases, a task is built as follows:

1) N categories are sampled, and K images per category form the support set;
2) the remaining images of these categories form the query set.

The two sets together form a task. The categories used in test tasks never appear in training tasks. This lets us test whether the meta-model can transfer its knowledge to new classes, using only the few labelled examples in the test task's support set.
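A minimal sketch of how N-way K-shot tasks can be built. It assumes the data are given as a dictionary that maps each class label to a list of images (any Python objects); the function names are hypothetical. As described above, the remaining images of each sampled class form the query set. Class labels are re-indexed to 0..N-1 inside each task. The meta-training and meta-testing classes are split beforehand, so test classes never appear in training tasks.

```python
import random


def split_classes(all_classes, n_test, rng=random):
    """Hold out n_test classes for meta-testing so they never appear in training tasks."""
    classes = sorted(all_classes)
    rng.shuffle(classes)
    return classes[n_test:], classes[:n_test]


def sample_episode(images_by_class, n_way, k_shot, rng=random):
    """Build one N-way K-shot task from a dict {class label: list of images}.

    Returns (support, query) as lists of (image, episode_label) pairs; labels are
    re-indexed to 0..n_way-1, so they only have meaning inside this task.
    """
    classes = rng.sample(sorted(images_by_class), n_way)
    support, query = [], []
    for episode_label, cls in enumerate(classes):
        images = list(images_by_class[cls])
        if len(images) <= k_shot:
            raise ValueError(f"class {cls!r} needs more than {k_shot} images")
        rng.shuffle(images)
        support += [(img, episode_label) for img in images[:k_shot]]
        # The remaining images of each sampled class form the query set.
        query += [(img, episode_label) for img in images[k_shot:]]
    return support, query
```

Take 10 classes with 6 images each. Holding out 3 classes and sampling a 5-way 1-shot task from the other 7 gives 5 support images and 25 query images.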

3. Task Modelling in Meta-learning

Meta-learning learns from many tasks, so how should these tasks be defined? The following subsections present state-of-the-art methods for defining tasks that lead to good meta-learning performance.

Task2Vec: Comparing Task Similarity

An intuitive view of neural networks is that they extract features from the input data and weight those features according to how strongly they affect the prediction. To transfer knowledge from other tasks to a target task, it is therefore better to learn from tasks that rely on features similar to those of the target task. Guan and Lu formalize the importance of task similarity through task-relatedness-based generalization bounds for meta-learning [5].

Task2Vec [6] generates a vector for each task and measures the similarity of two tasks by the distance between their vectors. The key idea is that not all weights are equally crucial to the prediction. If we perturb the weights, w^{\prime}=w+\delta w, the output changes dramatically when the perturbed weights are essential. A typical way to compare the output distributions p_{w}(y \mid x) and p_{w^{\prime}}(y \mid x) is the Kullback-Leibler (KL) divergence. To second-order approximation, this is

\mathbb{E}_{x \sim \hat{p}} K L\left(p_{w^{\prime}}(y \mid x) \| p_{w}(y \mid x)\right)=\frac{1}{2} \delta w^{T} \cdot F \cdot \delta w+o\left(\delta w^{2}\right),

where F is the Fisher information matrix (FIM): 

F=\mathbb{E}_{x, y \sim \hat{p}(x) p_{w}(y \mid x)}\left[\nabla_{w} \log p_{w}(y \mid x) \nabla_{w} \log p_{w}(y \mid x)^{T}\right].

If a weight w is highly decisive for the prediction, then the feature corresponding to w is important, and the corresponding entry of F is large. The FIM can therefore encode which input features matter, and in this way it encodes the task. The FIM differs from network to network, so Task2Vec uses a single probe network pre-trained on ImageNet as a fixed feature extractor. It re-trains only the classifier layer on each given task before computing the FIM.
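A toy sketch of the Task2Vec idea. It rests on strong simplifying assumptions that are ours, not the paper's:

- inputs are 2-D;
- a fixed one-layer probe h = tanh(V x), with V equal to the identity, stands in for the ImageNet-pretrained probe network;
- a logistic classifier head u is re-trained on each task;
- the diagonal FIM is taken with respect to the probe weights V and summed per hidden unit.

The FIM definition above draws y from the model p_w(y \mid x). For a binary logistic output, the expectation over y therefore has the closed form p(1-p), so no label sampling is needed. Tasks are compared with a plain cosine distance, chosen here for simplicity. All names are hypothetical.

```python
import math
import random


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def probe(x):
    """Frozen toy probe network h = tanh(V x) with V fixed to the identity."""
    return [math.tanh(xk) for xk in x]


def train_head(xs, ys, steps=300, lr=0.5):
    """Re-train only the logistic classifier u on top of the frozen probe."""
    hs = [probe(x) for x in xs]
    u = [0.0] * len(hs[0])
    for _ in range(steps):
        grad = [0.0] * len(u)
        for h, y in zip(hs, ys):
            p = sigmoid(sum(um * hm for um, hm in zip(u, h)))
            for m, hm in enumerate(h):
                grad[m] += (p - y) * hm  # gradient of the cross-entropy loss
        u = [um - lr * g / len(xs) for um, g in zip(u, grad)]
    return u


def fisher_per_unit(u, xs):
    """Diagonal FIM w.r.t. the probe weights V, summed over each hidden unit's inputs.

    For p = sigmoid(u . tanh(V x)): d log p_w(y|x) / dV_mk = (y - p) u_m (1 - h_m^2) x_k,
    and averaging (y - p)^2 over y ~ p_w(y|x) gives p (1 - p).
    """
    fisher = [0.0] * len(u)
    for x in xs:
        h = probe(x)
        p = sigmoid(sum(um * hm for um, hm in zip(u, h)))
        sq_norm_x = sum(xk * xk for xk in x)  # sum over k of x_k^2
        for m, hm in enumerate(h):
            fisher[m] += p * (1.0 - p) * (u[m] * (1.0 - hm * hm)) ** 2 * sq_norm_x
    return [f / len(xs) for f in fisher]


def cosine_distance(fa, fb):
    dot = sum(a * b for a, b in zip(fa, fb))
    return 1.0 - dot / (math.sqrt(sum(a * a for a in fa)) * math.sqrt(sum(b * b for b in fb)))


def make_task(rule_feature, n=200, seed=0):
    """Toy binary task: the label is the sign of one input coordinate."""
    rng = random.Random(seed)
    xs = [[rng.uniform(-2.0, 2.0), rng.uniform(-2.0, 2.0)] for _ in range(n)]
    ys = [1.0 if x[rule_feature] > 0 else 0.0 for x in xs]
    return xs, ys


def embed(task):
    """Task embedding = per-unit diagonal FIM after re-training the head."""
    xs, ys = task
    return fisher_per_unit(train_head(xs, ys), xs)
```

In this toy setting, two tasks that label points by the sign of the first coordinate get nearly parallel embeddings. The larger entry of each belongs to the decisive unit. A task that uses the second coordinate ends up far away from both.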

Godau and Maier-Hein [7] used Task2Vec to compare tasks in biomedical image analysis, with the results shown in Figure 2. Distances between tasks in the same domain are significantly smaller than distances between tasks in different domains. After a series of experiments, they concluded that Task2Vec can capture relationships between tasks and can be used to select training tasks for meta-learning.

Figure 2. Comparing tasks' distance using Task2Vec [7].

Task Interpolation: Task Augmentation

Data augmentation has proven to be an effective form of regularization. Yao et al. [8] proposed meta-learning with task interpolation (MLTI) to augment tasks. Consider a pair of tasks with their support and query sets, i.e. \mathcal{T}_{i}=\left\{\mathcal{D}_{i}^{s}, \mathcal{D}_{i}^{q}\right\} and \mathcal{T}_{j}=\left\{\mathcal{D}_{j}^{s}, \mathcal{D}_{j}^{q}\right\}. MLTI selects a layer l and extracts its outputs as the hidden representations \left(\mathbf{H}_{i}^{s(q), l}, \mathbf{H}_{j}^{s(q), l}\right). It then applies Manifold Mixup [9] to these hidden representations and, separately, to the corresponding labels \left(\mathbf{Y}_{i}^{s(q)}, \mathbf{Y}_{j}^{s(q)}\right):

\begin{array}{cc} \tilde{\mathbf{H}}_{i, c r}^{s, l}=\lambda \mathbf{H}_{i}^{s, l}+(1-\lambda) \mathbf{H}_{j}^{s, l}, & \tilde{\mathbf{Y}}_{i, c r}^{s}=\lambda \mathbf{Y}_{i}^{s}+(1-\lambda) \mathbf{Y}_{j}^{s}, \\ \tilde{\mathbf{H}}_{i, c r}^{q, l}=\lambda \mathbf{H}_{i}^{q, l}+(1-\lambda) \mathbf{H}_{j}^{q, l}, & \tilde{\mathbf{Y}}_{i, c r}^{q}=\lambda \mathbf{Y}_{i}^{q}+(1-\lambda) \mathbf{Y}_{j}^{q}, \end{array}

where \lambda \in[0,1] is sampled from a Beta distribution and "cr" abbreviates "cross". MLTI then replaces the output of layer l with the interpolated support set \mathcal{D}_{i, c r}^{s}=\left(\tilde{\mathbf{H}}_{i, c r}^{s, l}, \tilde{\mathbf{Y}}_{i, c r}^{s}\right) and query set \mathcal{D}_{i, c r}^{q}=\left(\tilde{\mathbf{H}}_{i, c r}^{q, l}, \tilde{\mathbf{Y}}_{i, c r}^{q}\right), and uses them to train the layers after layer l. MLTI supports both inter-task and intra-task interpolation, i.e. i=j is also allowed. Manifold Mixup can be replaced by other interpolation strategies, e.g. CutMix [11]. The authors show theoretically and empirically that MLTI induces a regularization term in the meta-learning loss and thus leads to a better generalization bound.
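A minimal sketch of the cross-task interpolation step. It assumes the layer-l representations and one-hot labels are plain Python lists. It also assumes examples of the two tasks are paired by position, which may differ from how MLTI pairs classes and examples. As in the equations above, one \lambda is shared by the support and query sets. The symmetric Beta(\alpha, \alpha) parameterisation and all names are assumptions. The returned sets would replace the outputs of layer l when training the remaining layers.

```python
import random


def _lerp(lam, a, b):
    """Element-wise lam * a + (1 - lam) * b for two equal-length vectors."""
    return [lam * ai + (1.0 - lam) * bi for ai, bi in zip(a, b)]


def interpolate_sets(lam, set_i, set_j):
    """Mix (H, Y) of two sets, pairing examples by position (an assumption of this sketch)."""
    (h_i, y_i), (h_j, y_j) = set_i, set_j
    if len(h_i) != len(h_j):
        raise ValueError("this sketch pairs examples by position; sizes must match")
    h_cr = [_lerp(lam, a, b) for a, b in zip(h_i, h_j)]
    y_cr = [_lerp(lam, a, b) for a, b in zip(y_i, y_j)]
    return h_cr, y_cr


def mlti_cross_task(task_i, task_j, alpha=0.5, rng=random):
    """Build one interpolated task; each task is {'support': (H, Y), 'query': (H, Y)} at layer l.

    A single lambda ~ Beta(alpha, alpha) is shared by the support and query sets.
    """
    lam = rng.betavariate(alpha, alpha)
    return {
        "support": interpolate_sets(lam, task_i["support"], task_j["support"]),
        "query": interpolate_sets(lam, task_i["query"], task_j["query"]),
        "lambda": lam,
    }
```

Sharing one \lambda between the support and query sets keeps the interpolated task consistent. The model is adapted on a mixture and evaluated on the same mixture.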

Singh et al. [12] add task augmentation to the meta-learning algorithm Reptile [13]. They evaluate the model on medical datasets: the Breast Cancer Histopathological Image (BreakHis) dataset [14], the ISIC 2018 Skin Lesion dataset [15] and the Pap-smear benchmark dataset [16]. Instead of interpolating hidden representations, they interpolate input images directly. They compare the performance achieved with regular augmentation (rotation and flipping), Manifold Mixup and CutMix. CutMix cuts a rectangular patch out of one training image and pastes it into another. It then mixes the ground-truth labels in proportion to the area each image contributes, i.e. \tilde{\mathbf{Y}}_{c r}=\lambda \mathbf{Y}_{i}+(1-\lambda) \mathbf{Y}_{j}, where \lambda is the fraction of the mixed image's area that comes from image i. Their experiments show that task interpolation improves accuracy by up to about 5% compared with regular augmentation.
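A minimal sketch of CutMix on single-channel images stored as 2-D lists, an assumption made to stay dependency-free. A rectangular patch of image b is pasted into a copy of image a. The labels are mixed with \lambda equal to the fraction of the area that still comes from image a. The box sampler draws \lambda uniformly, sets the patch side lengths proportional to \sqrt{1-\lambda}, and clips the box to the image. These details and all names are assumptions, not the exact implementation used by Singh et al.

```python
import math
import random


def random_box(rows, cols, rng=random):
    """Sample a (top, left, height, width) box; clipped at the image border."""
    lam = rng.random()  # assumed: lambda ~ Uniform(0, 1)
    cut = math.sqrt(1.0 - lam)
    h, w = int(rows * cut), int(cols * cut)
    cy, cx = rng.randrange(rows), rng.randrange(cols)
    top, left = max(cy - h // 2, 0), max(cx - w // 2, 0)
    bottom, right = min(cy + h // 2, rows), min(cx + w // 2, cols)
    return top, left, bottom - top, right - left


def cutmix(img_a, img_b, y_a, y_b, box):
    """Paste the box region of img_b into a copy of img_a and mix the one-hot labels.

    lam is the fraction of the mixed image's area that still comes from img_a.
    """
    top, left, height, width = box
    rows, cols = len(img_a), len(img_a[0])
    mixed = [row[:] for row in img_a]  # do not modify the input image
    for r in range(top, top + height):
        for c in range(left, left + width):
            mixed[r][c] = img_b[r][c]
    lam = 1.0 - (height * width) / (rows * cols)
    label = [lam * a + (1.0 - lam) * b for a, b in zip(y_a, y_b)]
    return mixed, label, lam
```

For example, pasting a 2x2 patch into a 4x4 image leaves 12 of 16 pixels from image a. The label becomes 0.75 of class a and 0.25 of class b. \lambda is computed from the clipped box, not from the sampled value, so a box cut off by the border still gives a label that matches the actual pixel proportions.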

Unsupervised Meta-learning: Generating Tasks 

Most meta-learning algorithms are supervised, so using unlabelled data first requires generating pseudo-labels. Hsu et al. [17] proposed clustering the data and assigning all data in a cluster to the same class. They showed that these pseudo-labels can be used to train a meta-model. Maicas et al. apply a similar idea to medical image classification, and Figure 3 shows their entire process. They use deep clustering [18] to group the images into a set of clusters C = \{c_i \mid i =1,...,K\} based on their features. To build a task, i clusters are sampled at random and labelled as class 0. Then j clusters are sampled from the remaining K-i clusters and labelled as class 1, which forms a binary classification problem. This method can therefore generate L=\sum_{i=1}^{K-1} \sum_{j=1}^{\min (i, K-i)} \frac{\left(\begin{array}{c}K \\i\end{array}\right) \times\left(\begin{array}{c}K-i \\j\end{array}\right)}{1+\delta(i-j)} tasks, where \delta is the Kronecker delta. The denominator 1+\delta(i-j) stops the same task from being counted twice when i=j, because swapping the two classes then yields the same binary problem. These tasks are used for meta-training. The meta-training tasks and the target task are built from images of the same dataset. The meta-model therefore learns to extract features relevant to those images and can perform well after training with only a few labelled examples. They evaluated their method on a breast screening classification task using a breast DCE-MRI dataset [19]. They concluded that this approach performs significantly better than other unsupervised and supervised pre-training methods and is competitive with supervised meta-training.

Figure 3. Unsupervised task design to meta-train medical image classifiers [17].
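The task-count formula above can be checked numerically. The sketch below does three things:

1) it evaluates the formula;
2) it compares the result with a brute-force count of unordered pairs of disjoint, non-empty groups of clusters, which is what the formula counts;
3) it draws one binary task from cluster indices.

The clustering step itself is omitted: cluster ids stand in for the deep-clustering output, and all names are hypothetical. The simple sampler does not draw tasks uniformly from all L possibilities.

```python
import itertools
import math
import random


def num_binary_tasks(k):
    """Evaluate the task-count formula for K clusters (delta = Kronecker delta)."""
    total = 0
    for i in range(1, k):
        for j in range(1, min(i, k - i) + 1):
            n = math.comb(k, i) * math.comb(k - i, j)
            total += n // 2 if i == j else n  # divide by 1 + delta(i - j)
    return total


def brute_force_count(k):
    """Count unordered pairs of disjoint, non-empty cluster groups by enumeration."""
    seen = set()
    for labels in itertools.product((0, 1, 2), repeat=k):  # 2 = cluster unused
        group0 = frozenset(c for c, lab in enumerate(labels) if lab == 0)
        group1 = frozenset(c for c, lab in enumerate(labels) if lab == 1)
        if group0 and group1:
            seen.add(frozenset((group0, group1)))
    return len(seen)


def sample_binary_task(cluster_ids, rng=random):
    """Draw i clusters for class 0 and j <= min(i, K - i) other clusters for class 1."""
    k = len(cluster_ids)
    i = rng.randint(1, k - 1)
    j = rng.randint(1, min(i, k - i))
    chosen = rng.sample(cluster_ids, i + j)
    return set(chosen[:i]), set(chosen[i:])
```

For K = 4 clusters the formula gives 25 tasks. For every K the count equals the closed form (3^K - 2^{K+1} + 1)/2. This follows because each cluster goes to class 0, class 1 or neither (3^K ways), assignments that leave a class empty are excluded, and the result is halved because swapping the classes gives the same task. The number of tasks therefore grows roughly like 3^K / 2.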

Task Scheduler: Machine Learning on Task Sampling

MAML samples a batch of tasks at each iteration, and not all tasks are equally useful for improving the meta-model. Machine learning techniques can therefore be used to decide which tasks to sample.

Yao et al. proposed an adaptive task scheduler (ATS) [20] to deal with noisy data. ATS is a reinforcement learning method: the scheduler is the agent, and the meta-training algorithm is the environment. Their algorithm is shown in Figure 4. At each meta-training iteration k, ATS assigns each task \mathcal{T}_i a sampling probability based on two quantities:

1) The query loss \mathcal{L}\left(\mathcal{D}_{i}^{q} ; \theta_{i}^{(k)}\right). This is the loss on the query set \mathcal{D}_{i}^{q} of task \mathcal{T}_{i} under the parameters \theta_{i}^{(k)}, which are obtained by taking a few gradient steps from the current meta-model parameters \theta_{0}^{(k)}.
2) The gradient similarity \left\langle\nabla_{\theta_{0}^{(k)}} \mathcal{L}\left(\mathcal{D}_{i}^{s} ; \theta_{0}^{(k)}\right), \nabla_{\theta_{0}^{(k)}} \mathcal{L}\left(\mathcal{D}_{i}^{q} ; \theta_{0}^{(k)}\right)\right\rangle between the support and query sets, taken with respect to the current meta-model \theta_{0}^{(k)}.

Based on these probabilities, the meta-training algorithm samples tasks and updates the meta-model. It then computes the accuracy of the updated meta-model on validation tasks, and this accuracy is the reward used to update ATS. Their theoretical analysis shows two tendencies. ATS tends to assign high sampling probabilities to tasks with high gradient similarity, since a small generalization gap from the support set \mathcal{D}_{i}^{s} to the query set \mathcal{D}_{i}^{q} indicates low noise in the data. It tends to assign low sampling probabilities to tasks with a high query loss, since their data may be noisy.
ATS improves performance on a noisy drug discovery benchmark by up to 18% compared with state-of-the-art task schedulers.

Figure 4. Meta-training Process with ATS [20].
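A heavily simplified sketch of the scheduling loop. It is not the authors' implementation, and it assumes:

- a linear scorer over the two features (query loss and gradient similarity) in place of the ATS neural network;
- a softmax that turns scores into sampling probabilities;
- a REINFORCE-style policy-gradient update driven by the validation-accuracy reward.

The update rule, the baseline and all names are assumptions. With a negative weight on the query loss and a positive weight on the gradient similarity, the scheduler reproduces the preferences described above.

```python
import math
import random


def grad_similarity(g_support, g_query):
    """Inner product of the support-set and query-set gradients at theta_0 (flattened)."""
    return sum(a * b for a, b in zip(g_support, g_query))


def sampling_probs(features, weights):
    """Softmax over linear task scores; features[i] = (query_loss_i, grad_sim_i)."""
    scores = [sum(w * f for w, f in zip(weights, feat)) for feat in features]
    top = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def sample_tasks(probs, batch_size, rng=random):
    """Draw task indices (with replacement) according to the scheduler's probabilities."""
    return rng.choices(range(len(probs)), weights=probs, k=batch_size)


def reinforce_update(weights, features, sampled, reward, baseline, lr=0.1):
    """REINFORCE step: w += lr * (reward - baseline) * sum over sampled i of grad log p_i.

    For linear scores s_i = w . f_i, grad log p_i = f_i - sum_k p_k f_k.
    """
    probs = sampling_probs(features, weights)
    mean_feat = [sum(p * feat[d] for p, feat in zip(probs, features)) for d in range(len(weights))]
    grad = [0.0] * len(weights)
    for i in sampled:
        for d in range(len(weights)):
            grad[d] += features[i][d] - mean_feat[d]
    advantage = reward - baseline  # e.g. validation accuracy minus a running average
    return [w + lr * advantage * g for w, g in zip(weights, grad)]
```

One scheduling round would work as follows:

1) compute both features for every candidate task;
2) draw a batch with sample_tasks and run a meta-update on it;
3) measure validation accuracy and pass it as the reward to reinforce_update.

A reward above the baseline raises the probability of the tasks that were just sampled.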

ATS is not the only work that applies machine learning to task modelling. Kaddour et al. [21] and Wang et al. [22] combine active learning with meta-learning to improve data efficiency, using a probabilistic model and a Markov decision process, respectively. These examples show the potential of applying machine learning to task modelling itself.

4. Discussion

Comparing Methods

In conventional deep learning, we need to take care of two things: 1) the training data and test data should have similar feature distributions; 2) regularization is needed to avoid overfitting. In meta-learning, the training tasks play the role of the training data, so task modelling should be considered from the same two perspectives. Task2Vec indicates whether tasks are similar enough by checking whether the same features are essential for different tasks. Task interpolation augments the tasks to regularize the meta-model. Both techniques can be applied when training any meta-model and help secure its performance to some extent.

Moreover, labelled data are sometimes scarce in conventional deep learning, and unsupervised learning offers a way to generate labels. In the same way, unsupervised meta-learning enables task generation. Because all tasks are built from the same dataset, they are similar to one another. However, unsupervised meta-learning performs slightly worse than supervised meta-training. Combining supervised and unsupervised meta-learning to model similar tasks may therefore improve the model's performance.

Once a candidate task pool has been built with the above methods, a machine-learning-based task scheduler can handle task sampling. It assigns each task a probability of being used for training, so that we make full use of the information the tasks provide.

Together, these methods form a pipeline for task modelling:

1) find similar tasks to create a task pool;
2) use task augmentation to generate more tasks for better generalization;
3) if there are still too few tasks, create more with unsupervised methods;
4) during meta-training, sample tasks with a machine-learning-based task scheduler.

Meta-learning in Medical Domain

The main challenges in medical image analysis are the lack of large training datasets and labelled data, class imbalance, noisy data, and multi-modality [23]. Meta-learning can help address each of them:

- Data scarcity and class imbalance: meta-learning needs only a few labelled examples per new task, which reduces the impact of both.
- Noisy data: an intelligent task scheduler can favour tasks whose data are less noisy.
- Multi-modality: a model trained on one modality must be quickly fine-tuned for another, and fast fine-tuning on new tasks is exactly the property a meta-model is trained for.

Therefore, meta-learning is a promising solution for medical image analysis.

5. Conclusion

This blog first introduces meta-learning and then compares several state-of-the-art methods for task modelling. The properties of meta-learning make it a good solution for medical image analysis. Future work on task modelling could combine different methods into a standard pipeline covering similar-task selection, task generation and task scheduling. An intelligent algorithm might one day run this pipeline automatically. We could then dream of AIs that learn to transfer knowledge from previous tasks to new ones, and even of universal robots that are good at anything.

Reference

[1] Sun C, Shrivastava A, Singh S, et al. Revisiting unreasonable effectiveness of data in deep learning era[C]//Proceedings of the IEEE international conference on computer vision. 2017: 843-852.

[2] Piccialli F, Di Somma V, Giampaolo F, et al. A survey on deep learning in medicine: Why, how and when?[J]. Information Fusion, 2021, 66: 111-137.

[3] Litjens G, Kooi T, Bejnordi B E, et al. A survey on deep learning in medical image analysis[J]. Medical Image Analysis, 2017, 42: 60-88.

[4] Finn C, Xu K, Levine S. Probabilistic model-agnostic meta-learning[J]. Advances in Neural Information Processing Systems, 2018, 31.

[5] Guan J, Lu Z. Task Relatedness-Based Generalization Bounds for Meta Learning[C]//International Conference on Learning Representations. 2022.

[6] Achille A, Lam M, Tewari R, et al. Task2vec: Task embedding for meta-learning[C]//Proceedings of the IEEE/CVF international conference on computer vision. 2019: 6430-6439.

[7] Godau P, Maier-Hein L. Task Fingerprinting for Meta Learning in Biomedical Image Analysis[C]//International Conference on Medical Image Computing and Computer-Assisted Intervention. Springer, Cham, 2021: 436-446.

[8] Yao H, Zhang L, Finn C. Meta-Learning with Fewer Tasks through Task Interpolation[C]//International Conference on Learning Representations. 2022.

[9] Verma V, Lamb A, Beckham C, et al. Manifold mixup: Better representations by interpolating hidden states[C]//International Conference on Machine Learning. PMLR, 2019: 6438-6447.

[10] Zhang H, Cisse M, Dauphin Y N, et al. mixup: Beyond Empirical Risk Minimization[C]//International Conference on Learning Representations. 2018.

[11] Yun S, Han D, Oh S J, et al. Cutmix: Regularization strategy to train strong classifiers with localizable features[C]//Proceedings of the IEEE/CVF international conference on computer vision. 2019: 6023-6032.

[12] Singh R, Bharti V, Purohit V, et al. MetaMed: Few-shot medical image classification using gradient-based meta-learning[J]. Pattern Recognition, 2021, 120: 108111.

[13] Nichol A, Schulman J. Reptile: a scalable metalearning algorithm[J]. arXiv preprint arXiv:1803.02999, 2018, 2(3): 4.

[14] Spanhol F A, Oliveira L S, Petitjean C, et al. A dataset for breast cancer histopathological image classification[J]. IEEE Transactions on Biomedical Engineering, 2015, 63(7): 1455-1462.

[15] Zou J, Ma X, Zhong C, et al. Dermoscopic Image Analysis for ISIC Challenge 2018[J]. 2018.

[16] Jantzen J, Norup J, Dounias G, et al. Pap-smear benchmark data for pattern classification[J]. Nature inspired Smart Information Systems (NiSIS 2005), 2005: 1-9.

[17] Hsu K, Levine S, Finn C. Unsupervised Learning via Meta-Learning[C]//International Conference on Learning Representations. 2018.

[18] Caron M, Bojanowski P, Joulin A, et al. Deep clustering for unsupervised learning of visual features[C]//Proceedings of the European conference on computer vision (ECCV). 2018: 132-149.

[19] Maicas G, Bradley A P, Nascimento J C, et al. Pre and post-hoc diagnosis and interpretation of malignancy from breast DCE-MRI[J]. Medical Image Analysis, 2019, 58: 101562.

[20] Yao H, Wang Y, Wei Y, et al. Meta-learning with an adaptive task scheduler[J]. Advances in Neural Information Processing Systems, 2021, 34: 7497-7509.

[21] Kaddour J, Sæmundsson S. Probabilistic active meta-learning[J]. Advances in Neural Information Processing Systems, 2020, 33: 20813-20822.

[22] Wang B, Koppel A, Krishnamurthy V. A Markov decision process approach to active meta learning[J]. arXiv preprint arXiv:2009.04950, 2020.

[23] Litjens G, Kooi T, Bejnordi B E, et al. A survey on deep learning in medical image analysis[J]. Medical Image Analysis, 2017, 42: 60-88.
