1. Introduction
Image segmentation describes the process of splitting a digital image into multiple regions or objects. The main categories are semantic and instance segmentation. Semantic segmentation differentiates between classes of objects, so all humans in one image are segmented as the same class. Instance segmentation additionally distinguishes individual objects, so every human in the image is segmented as a distinct object.
A third category is panoptic segmentation, which combines the previous two. The methods described in this blog post belong to the category of semantic segmentation.
Figure 1: Comparison of the different segmentation categories. 1: Original image; 2: Semantic segmentation; 3: Instance segmentation; 4: Panoptic segmentation [1]
In semi-supervised learning, the algorithm is trained with both labelled and unlabelled data. Most often, the amount of unlabelled data is much larger than the amount of labelled data. Shape-aware means that the algorithm is aware not only of the object's boundary but also of how far each pixel or voxel lies from that boundary.
In principle, these methods can be applied to any use case where an object needs to be segmented from the background, for example organs, parts of organs or tumours.
1.2 Problem Statement
Due to the lack of labelled data in medical image analysis, methods that utilize unlabelled data must be developed. Labelling is time-consuming and requires trained experts. In segmentation especially, annotators have to label every pixel accurately or draw precise boundaries [2], [3], [4], [5].
1.3 Assumptions for Semi-Supervised Learning
In general, several assumptions about the data must hold for semi-supervised algorithms to work. For our case, the smoothness assumption is the most important one. It states:
Given two data points x1 and x2, if they are close together in a high-density region of the input space, then their labels in the output space should be close together in a regression task and identical in a classification task. This can be applied transitively to predict labels for unlabelled data points [6], [7].
1.4 Definition
Name | Description |
---|---|
Signed Distance Map (SDM) | Assigns each pixel or voxel its signed distance to the nearest object boundary. The closer a pixel or voxel is to the boundary, the smaller the absolute value; pixels on the boundary have the value 0. Depending on the convention, pixels or voxels inside the object have a positive sign and those outside a negative sign, or the other way around. |
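To make the SDM definition concrete, here is a minimal brute-force sketch in Python with NumPy. It assumes a 2D binary mask with at least one foreground pixel, treats foreground pixels with a 4-connected background neighbour as the boundary, and uses the convention "negative inside, positive outside". The function names are hypothetical, and for real volumes a dedicated distance transform (for example `scipy.ndimage.distance_transform_edt`) would be far faster.

```python
import numpy as np


def boundary_pixels(mask: np.ndarray) -> np.ndarray:
    """Foreground pixels with at least one 4-connected background neighbour."""
    padded = np.pad(mask, 1, constant_values=False)  # outside the image counts as background
    background_neighbour = (
        ~padded[:-2, 1:-1] | ~padded[2:, 1:-1] | ~padded[1:-1, :-2] | ~padded[1:-1, 2:]
    )
    return mask & background_neighbour


def signed_distance_map(mask: np.ndarray) -> np.ndarray:
    """Brute-force SDM (assumed convention): negative inside, 0 on the boundary, positive outside."""
    mask = mask.astype(bool)
    boundary = np.argwhere(boundary_pixels(mask))  # (B, 2) boundary coordinates
    if boundary.size == 0:
        raise ValueError("mask must contain at least one foreground pixel")
    coords = np.indices(mask.shape).reshape(2, -1).T  # (N, 2) coordinates of all pixels
    # Euclidean distance from every pixel to its nearest boundary pixel
    diffs = coords[:, None, :] - boundary[None, :, :]
    dist = np.sqrt((diffs ** 2).sum(axis=-1)).min(axis=1).reshape(mask.shape)
    return np.where(mask, -dist, dist)
```

For a 3D volume the same idea applies with 6-connected neighbours and three coordinates per voxel. To match a network output that ends in tanh, the distances would additionally need to be scaled to [−1, 1]; that step is omitted here.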
2. Methodology
2.1 Common semi-supervised algorithms
Different approaches exist for semi-supervised learning. They can broadly be separated into inductive and transductive methods. Inductive methods produce a model that can make predictions for any point in the input space; after training, predictions are independent of each other. Transductive methods can only make predictions for the data points they encounter during training and do not generalize to unseen data at test time. These categories can be divided further, but here only consistency regularization, an inductive method, is of importance:
If an input is perturbed in a realistic way, the prediction should not change significantly. The model can therefore be trained to make consistent predictions for an unlabelled sample and its perturbed version [6], [7].
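Consistency regularization boils down to an extra loss term on unlabelled inputs. The sketch below is a minimal illustration in Python with NumPy, assuming additive Gaussian noise as the "realistic" perturbation and mean squared error as the distance between predictions; `toy_model` is a hypothetical stand-in for a segmentation network.

```python
import numpy as np


def consistency_loss(model, x, rng, noise_std=0.1):
    """MSE between the prediction for an unlabelled input and for a perturbed copy of it."""
    # assumed perturbation: additive Gaussian noise
    x_perturbed = x + rng.normal(0.0, noise_std, size=x.shape)
    return float(np.mean((model(x) - model(x_perturbed)) ** 2))


def toy_model(x):
    # hypothetical stand-in for a network: per-pixel sigmoid
    return 1.0 / (1.0 + np.exp(-x))
```

During training, such a term would be added to the supervised loss on the labelled data; the methods below differ mainly in which perturbations and which distance they use.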
The methods described in the next sections make use of the smoothness assumption and contrastive learning.
2.2 Semi-Supervised Shape-Aware Methods
All methods rely on the same basic structure. A 2D image or 3D volume X is fed into a neural network, which jointly predicts the signed distance map and the segmentation map; both are refined afterwards. During training, a variety of loss functions are applied, which are usually divided into supervised and unsupervised losses.
Figure 2: General Structure of the Semi-supervised methods
2.2.1 Neural Network Architecture
In general, two approaches can be found: either a V-Net architecture or a mean teacher architecture is used.
V-Net
The network is similar to U-Net [8]. It consists of an encoder and a decoder part but is designed for volumes instead of images. The encoder constructs a latent space representation of the data, from which the decoder constructs the SDM and the segmentation map. The encoder consists of multiple stages, each containing one to three convolutional layers that extract features; at the end of each stage, a down-convolution reduces the size of the resulting feature map. The decoder also consists of multiple stages, each of which first applies an up-convolution and then multiple convolutional layers. Horizontal connections link each encoder stage to the corresponding decoder stage to make up for the loss of information in the encoder part. V-Net [9] is used by [3] and [4].
Figure 3: Structure of V-Net. The left side shows the encoder part, which receives the input; the right side shows the decoder part. [9]
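The effect of the down-convolutions on the feature maps can be traced with a little shape arithmetic. This Python sketch assumes a V-Net-like configuration (five stages, 2×2×2 down-convolutions with stride 2, 16 channels in the first stage that double at every stage); these values and the 112 × 112 × 80 input patch are illustrative assumptions, not settings taken from [3] or [4].

```python
def vnet_feature_shapes(spatial, base_channels=16, num_stages=5):
    """Encoder feature shapes (channels, H, W, D) and decoder resolutions of a V-Net-like model.

    All defaults are assumptions for illustration.
    """
    factor = 2 ** (num_stages - 1)
    if any(s % factor for s in spatial):
        raise ValueError(f"every spatial size must be divisible by {factor}")
    encoder = []
    channels = base_channels
    for stage in range(num_stages):
        encoder.append((channels, *spatial))
        if stage < num_stages - 1:
            # down-convolution (2x2x2 kernel, stride 2): halves every spatial size, doubles channels
            spatial = tuple(s // 2 for s in spatial)
            channels *= 2
    # each decoder stage up-convolves back to the resolution of the matching encoder stage,
    # whose features arrive through the horizontal connection
    decoder = [shape[1:] for shape in reversed(encoder[:-1])]
    return encoder, decoder


encoder, decoder = vnet_feature_shapes((112, 112, 80))
```

For the assumed patch, the deepest feature map is 256 × 7 × 7 × 5; the horizontal connections are what give the decoder access to the finer resolutions again.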
Mean Teacher Architecture
The architecture consists of two networks with identical architecture, called the student and the teacher. During training, the same batch is given to both the student and the teacher, but random noise is added to the two inputs independently of each other. After the prediction, an additional consistency loss between the two outputs is applied. The student weights θS are then updated normally, whereas the teacher weights θT are an exponential moving average of the student weights: θT ← α·θT + (1 − α)·θS. Usually, α is set to 0.99. During testing, only the student network is used. This architecture is used by [2] and [5], where both the student and the teacher network have a V-Net architecture [10].
Figure 4: Mean-teacher architecture. Student and teacher model share the same architecture. [10]
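The teacher update θT ← α·θT + (1 − α)·θS is a one-liner per parameter. This NumPy sketch assumes the parameters are stored as a dictionary of float arrays with matching names in both networks; that storage layout is a hypothetical simplification of a real framework.

```python
import numpy as np


def ema_update(teacher, student, alpha=0.99):
    """In-place update theta_T <- alpha * theta_T + (1 - alpha) * theta_S for every parameter.

    teacher/student: dicts of float NumPy arrays with identical keys (assumed layout).
    """
    for name, theta_s in student.items():
        teacher[name] *= alpha
        teacher[name] += (1.0 - alpha) * theta_s
```

With α = 0.99, the teacher roughly averages the student weights over the last hundred or so steps, which makes its predictions a smoother consistency target than the student's own.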
2.2.2 Shape-Aware Semi-supervised 3D Semantic Segmentation for Medical Images [3]
The network uses V-Net as its backbone, which jointly predicts the SDM and the segmentation map. Additionally, a discriminator is trained to distinguish predicted SDMs from SDMs of the labelled training data. The discriminator is trained beforehand and consists of 5 convolutional layers. It receives the original input together with the SDM to be classified.
Figure 5: Overview of the architecture from SASSNet
Loss Function - Shape-Aware Semi-supervised 3D Semantic Segmentation for Medical Images [3]
The supervised loss for this network consists of 2 losses applied at different locations within the network.
- Lseg - joint cross-entropy and Dice loss
- Lsdm - MSE Loss
The unsupervised loss for this network consists of the adversarial loss La, computed by the discriminator, which enforces consistency on the model prediction.
β is a time-dependent Gaussian warm-up function that increases over the course of training. The optimization is done in an alternating manner: first the adversarial loss is maximized, and then the supervised loss is minimized. The final loss function is a linear combination of these terms.
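The warm-up weight β and the weighted sum can be sketched as follows. The ramp-up shape exp(−5(1 − t/T)²) is an assumption borrowed from common mean-teacher-style code, and the ramp-up length, the maximum weight and the weight `lambda_sdm` on Lsdm are hypothetical; the sketch only covers the segmentation network's update, not the alternating discriminator step.

```python
import math


def gaussian_rampup(step, rampup_length, max_weight=1.0):
    """Assumed warm-up: max_weight * exp(-5 * (1 - t)^2), t = step / rampup_length clipped to [0, 1]."""
    if rampup_length <= 0:
        return max_weight
    t = min(max(step, 0), rampup_length) / rampup_length
    return max_weight * math.exp(-5.0 * (1.0 - t) ** 2)


def segmentation_net_loss(l_seg, l_sdm, l_adv, step, rampup_length, lambda_sdm=1.0):
    """Linear combination for the segmentation network's update (lambda_sdm is hypothetical)."""
    beta = gaussian_rampup(step, rampup_length)
    return l_seg + lambda_sdm * l_sdm + beta * l_adv
```

Early in training β is close to exp(−5) ≈ 0.0067, so the adversarial term barely contributes while the predictions are still poor; after the ramp-up it reaches its full weight.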
2.2.3 SimCVD: Simple contrastive voxel-wise representation distillation for semi-supervised medical image segmentation [5]
The network uses a mean teacher as its backbone, where both the student and the teacher network have a V-Net architecture. Both predict a probability map, from which the segmentation map and the SDM are constructed. Finally, the predicted SDM is added to the input, and independent dropout masks are applied to both branches; a contrastive loss is then applied between the two resulting feature maps.
Figure 6: Overview of the architecture from SimCVD
Loss Function - SimCVD: Simple contrastive voxel-wise representation distillation for semi-supervised medical image segmentation [5]
The supervised loss Lsup for this network consists of 2 losses applied at different locations within the network.
- Lseg - Cross entropy and dice loss
- Lsdm - MSE loss
The unsupervised loss for this network consists of 3 losses applied at different locations within the network.
- Lpd - enforces consistency on the hidden patterns in the encoder
- Lcon - enforces consistency on the probability maps predicted by the student and teacher networks (MSE)
- Lcontrast - enforces consistency on the predicted SDMs during training
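A contrastive loss between two dropout views can be illustrated with the generic InfoNCE formulation. This NumPy sketch is not SimCVD's exact loss: treating each row as one voxel embedding, the dropout rate, and the temperature are assumptions, and [5] may sample voxels and project features differently.

```python
import numpy as np


def dropout_view(h, rng, p=0.1):
    """Apply an independent inverted-dropout mask (assumed rate p) to a feature matrix."""
    keep = rng.random(h.shape) >= p
    return h * keep / (1.0 - p)


def info_nce(z_a, z_b, temperature=0.1):
    """Generic InfoNCE: row i of z_a and row i of z_b are a positive pair, other rows are negatives."""
    z_a = z_a / (np.linalg.norm(z_a, axis=1, keepdims=True) + 1e-8)
    z_b = z_b / (np.linalg.norm(z_b, axis=1, keepdims=True) + 1e-8)
    logits = z_a @ z_b.T / temperature  # (N, N) scaled cosine similarities
    logits -= logits.max(axis=1, keepdims=True)  # numerical stability
    log_prob = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))
    return float(-np.mean(np.diag(log_prob)))
```

The loss is low when each embedding is closest to its own counterpart in the other dropout view and high when the pairing is scrambled.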
The final loss function is a linear combination of the supervised and unsupervised losses.
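Based only on the terms listed above, the combined objective can be sketched as the following weighted sum. The weights λ are placeholders; the exact weighting and any warm-up schedule used in [5] are not given in this post.

```latex
\mathcal{L} = \underbrace{\mathcal{L}_{seg} + \lambda_{sdm}\,\mathcal{L}_{sdm}}_{\mathcal{L}_{sup}}
  + \lambda_{pd}\,\mathcal{L}_{pd} + \lambda_{con}\,\mathcal{L}_{con} + \lambda_{contrast}\,\mathcal{L}_{contrast}
```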
2.2.4 Shape and boundary-aware multi-branch model for semi-supervised medical image segmentation [4]
In this paper, V-Net is used as the backbone. To obtain the segmentation map, a convolution followed by the sigmoid function is applied to the output of the network along branch A. To obtain the SDM, a convolution followed by the tanh function is applied to the output of the network along branch B.
Figure 7: Overview of the architecture from SBANet
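The two output branches can be sketched as two pointwise convolutions on shared features. This NumPy sketch assumes a 1×1×1 kernel (the kernel size is not stated here), random hypothetical weights, and a single feature volume of shape (C, H, W, D) standing in for the V-Net output.

```python
import numpy as np


def conv1x1x1(features, weight, bias):
    """Pointwise 3D convolution (assumed kernel size 1): features (C_in, H, W, D), weight (C_out, C_in)."""
    return np.einsum("oc,chwd->ohwd", weight, features) + bias[:, None, None, None]


def two_branch_heads(features, w_seg, b_seg, w_sdm, b_sdm):
    # branch A: convolution + sigmoid -> per-voxel foreground probability in (0, 1)
    psm = 1.0 / (1.0 + np.exp(-conv1x1x1(features, w_seg, b_seg)))
    # branch B: convolution + tanh -> SDM estimate in (-1, 1)
    sdm = np.tanh(conv1x1x1(features, w_sdm, b_sdm))
    return psm, sdm
```

The choice of activations fixes the output ranges: probabilities for the segmentation map and a symmetric range for the SDM, which is why SDM targets are usually scaled to [−1, 1].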
PPM - Pyramid Pooling Module
The PPM works in parallel to the network and is used to extract multi-scale features from the input data. It takes the same input as the network and applies four parallel pooling branches with different sizes, each followed by a convolution, batch normalization and a ReLU. Finally, the branches are upsampled to the original size and concatenated. The result is multiplied by (2 − |SDM|), which strengthens the contour of the segmented object: since the SDM produced by the tanh branch lies in (−1, 1), this weight approaches 2 near the boundary, where |SDM| is close to 0, and decreases towards 1 away from it.
Figure 8: Overview of the PPM from SBANet
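The pooling-and-weighting idea can be sketched on a single 2D channel with NumPy. The bin sizes (1, 2, 3, 6) are an assumption borrowed from common pyramid pooling designs, the convolution, batch normalization and ReLU of each branch are omitted, and nearest-neighbour upsampling is used; the function names are hypothetical.

```python
import numpy as np


def pool_and_upsample(x, bins):
    """Average-pool a (H, W) map into a bins x bins grid, then copy each cell back over its region."""
    out = np.empty_like(x, dtype=float)
    for rows in np.array_split(np.arange(x.shape[0]), bins):
        for cols in np.array_split(np.arange(x.shape[1]), bins):
            region = np.ix_(rows, cols)
            out[region] = x[region].mean()
    return out


def pyramid_pooling(x, bin_sizes=(1, 2, 3, 6)):
    """Stack one upsampled branch per (assumed) bin size; conv/BN/ReLU per branch are omitted."""
    if min(x.shape) < max(bin_sizes):
        raise ValueError("input must be at least as large as the largest bin size")
    return np.stack([pool_and_upsample(x, b) for b in bin_sizes])


def boundary_weighting(features, sdm):
    """Multiply features by (2 - |SDM|): weight 2 on the boundary, 1 where |SDM| = 1."""
    return features * (2.0 - np.abs(sdm))
```

Coarse bins summarise global context, fine bins keep local detail, and the final multiplication emphasises features close to the predicted contour.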
FFM - Feature Fusion Module
The input to the module has size B×C×H×W×D (batch size, channels, height, width and depth, respectively). First, channel shuffling is applied, which mixes the C channels; then the channels are divided into g groups along the channel dimension. Each group is split into two partitions: one is fed into the channel attention block and the other into the spatial attention block. Afterwards, everything is fused back together so that the output has the same size as the input to the module.
Figure 9: Overview of the FFM from SBANet
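The tensor bookkeeping of the FFM can be sketched in NumPy. This assumes a ShuffleNet-style deterministic channel shuffle (reshape, transpose, reshape), which is one common way to implement "channel shuffling" but is not confirmed by this post; the attention blocks themselves are left out (treated as identity), and all function names are hypothetical.

```python
import numpy as np


def channel_shuffle(x, groups):
    """Assumed ShuffleNet-style shuffle for x of shape (B, C, H, W, D)."""
    b, c, h, w, d = x.shape
    if c % groups:
        raise ValueError("channels must be divisible by groups")
    x = x.reshape(b, groups, c // groups, h, w, d).transpose(0, 2, 1, 3, 4, 5)
    return x.reshape(b, c, h, w, d)


def split_groups_and_partitions(x, groups):
    """Split channels into `groups` groups, then each group into two halves."""
    b, c, h, w, d = x.shape
    if c % (2 * groups):
        raise ValueError("channels must be divisible by 2 * groups")
    grouped = x.reshape(b, groups, c // groups, h, w, d)
    half = c // (2 * groups)
    channel_part = grouped[:, :, :half]  # would go to the channel attention block
    spatial_part = grouped[:, :, half:]  # would go to the spatial attention block
    return channel_part, spatial_part


def fuse(channel_out, spatial_out):
    """Concatenate both partitions per group and restore the (B, C, H, W, D) input shape."""
    b, g, half, h, w, d = channel_out.shape
    return np.concatenate([channel_out, spatial_out], axis=2).reshape(b, g * 2 * half, h, w, d)
```

Because each partition only sees part of the channels, the shuffle is what lets information flow between groups across successive modules.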
Loss Function - Shape and boundary-aware multi-branch model for semi-supervised medical image segmentation [4]
The supervised loss for this network consists of 3 losses applied at different locations within the network.
- LSDM - L1 loss and product loss between the true and predicted SDM
- LPSM - Dice loss and cross-entropy loss between the true and predicted segmentation map
- LB - boundary loss, focuses on the contour of the boundary between the true and predicted segmentation map
The unsupervised loss for this network consists of 2 losses applied at different locations within the network.
- Lcl - MSE loss, enforces consistency between PSM1, PSM3 and SDM
- Ladv - the discriminator predicts whether PSM3 is real or predicted; works the same as in [3]
In the final loss function, γ is a time-dependent Gaussian warm-up function that increases over the course of training.
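From the terms listed above, the objective can be sketched as below. Unit weights on the supervised terms, γ scaling both unsupervised terms, and the ramp-up shape with T as the ramp-up length are assumptions for illustration; the exact form is given in [4].

```latex
\mathcal{L} = \mathcal{L}_{SDM} + \mathcal{L}_{PSM} + \mathcal{L}_{B}
  + \gamma(t)\left(\mathcal{L}_{cl} + \mathcal{L}_{adv}\right),
\qquad \gamma(t) = \gamma_{\max}\, e^{-5\,(1 - t/T)^2}
```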
2.2.5 3D Graph-S2 Net: Shape-Aware Self-ensembling Network for Semi-supervised Segmentation with Bilateral Graph Convolution [2]
This paper uses a mean teacher architecture, where both the student and the teacher network have a V-Net architecture. As post-processing, the authors use 3D bilateral graph convolution, which consists of three steps: first graph projection, then bilateral graph reasoning and finally graph reprojection.
Figure 10: Overview of the architecture from 3D Graph-S2 Net
Graph Projection
The prediction X of the mean teacher model is mapped onto a set of node features H in the graph domain, where pixels with similar features are aggregated into one node that acts as an anchor. This is done simultaneously for the SDM and the segmentation map; the index a denotes the semantic segmentation map and b the geometric SDM.
Bilateral Graph Reasoning
The goal is to model intra- and inter-task relations and to diffuse information between the graph domains of the SDM and the segmentation map in order to capture co-occurrence relations over the two graphs.
Graph Reprojection
The updated features from the graph domain are mapped back to the coordinate domain.
Figure 11: overview of the post-processing step to get the SDM and segmentation map from [2].
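The three steps can be illustrated with a heavily simplified NumPy sketch in the style of generic graph-based global reasoning. It is not the formulation of [2]: the soft assignment of pixels to nodes, the attention-style intra- and inter-graph messages, and all function names are hypothetical choices made only to show the projection, reasoning and reprojection shapes.

```python
import numpy as np


def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def project(x, w_proj):
    """Graph projection: pixel features x (N, C) -> node features (K, C).

    Each node is a weighted average of pixels whose features align with w_proj[k] (hypothetical).
    """
    assignment = softmax(w_proj @ x.T, axis=1)  # (K, N), rows sum to 1
    return assignment @ x, assignment


def reason(h_a, h_b, w):
    """One round of intra-graph (a->a) and inter-graph (b->a) message passing (hypothetical)."""
    intra = softmax(h_a @ h_a.T) @ h_a
    inter = softmax(h_a @ h_b.T) @ h_b  # information diffused from graph b into graph a
    return np.maximum((intra + inter) @ w, 0.0)  # linear transform + ReLU


def reproject(h, assignment):
    """Graph reprojection: node features (K, C) -> pixel features (N, C)."""
    return assignment.T @ h
```

In this toy version, x would be the flattened features of the segmentation branch (index a) or the SDM branch (index b); running `reason` in both directions gives the bilateral exchange between the two graphs.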
Loss Function - 3D Graph-S2Net: Shape-Aware Self-ensembling Network for Semi-supervised Segmentation with Bilateral Graph Convolution [2]
The supervised loss for this network consists of 2 losses applied at different locations within the network.
- LSseg - joint cross entropy and dice loss between real and predicted segmentation map
- LSsdm - MSE loss between real and predicted SDM
The unsupervised loss for this network consists of 2 losses applied at different locations within the network.
- LCseg - MSE between the segmentation maps predicted by the student and the teacher; the student is trained, the teacher uses EMA weights
- LCsdm - MSE between the SDMs predicted by the student and the teacher; the student is trained, the teacher uses EMA weights
The final loss function is a linear combination of the supervised and unsupervised losses.
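The unsupervised part of this objective compares student and teacher outputs for both tasks. This NumPy sketch computes the two consistency terms and one possible linear combination; unit weights on the supervised terms and a single weight `lam` on the consistency terms are assumptions, and the supervised losses are passed in precomputed.

```python
import numpy as np


def mse(a, b):
    return float(np.mean((a - b) ** 2))


def consistency_losses(student_seg, student_sdm, teacher_seg, teacher_sdm):
    """LCseg and LCsdm; teacher outputs come from the EMA weights and act as fixed targets."""
    return mse(student_seg, teacher_seg), mse(student_sdm, teacher_sdm)


def total_loss(l_s_seg, l_s_sdm, l_c_seg, l_c_sdm, lam):
    """Assumed combination: unit-weight supervised terms plus lam-weighted consistency terms."""
    return l_s_seg + l_s_sdm + lam * (l_c_seg + l_c_sdm)
```

Only the student receives gradients from this loss; the teacher changes solely through the EMA update.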
3. Results
The dataset used by all the methods is the Left Atrial (LA) dataset of the 2018 Atrial Segmentation Challenge, which contains 100 3D gadolinium-enhanced MRI scans with an isotropic resolution of 0.625 × 0.625 × 0.625 mm³ [11]. Each scan contains 88 slices, and the size of a scan may vary between patients. Comparing all results, SimCVD performs best in nearly all settings, even though it has one of the simplest architectures. Comparing the results to V-Net trained only on the labelled scans shows that all methods manage to incorporate the unlabelled data in a beneficial way.
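The Dice and Jaccard columns below measure the overlap between predicted and ground-truth masks. A minimal NumPy sketch for binary masks (assuming at least one foreground voxel in either mask, otherwise the ratios are undefined):

```python
import numpy as np


def dice(pred, target):
    """Dice = 2 |P ∩ T| / (|P| + |T|) for binary masks."""
    pred, target = pred.astype(bool), target.astype(bool)
    intersection = np.logical_and(pred, target).sum()
    return 2.0 * intersection / (pred.sum() + target.sum())


def jaccard(pred, target):
    """Jaccard = |P ∩ T| / |P ∪ T| for binary masks."""
    pred, target = pred.astype(bool), target.astype(bool)
    intersection = np.logical_and(pred, target).sum()
    return intersection / np.logical_or(pred, target).sum()
```

For a single case, J = D / (2 − D) holds exactly. The table reports averages over test cases, so the averaged columns only approximately follow this relation.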
Method | # labelled scans | # unlabelled scans | Dice [%] | Jaccard [%] | ASD [voxel] | 95HD [voxel] |
---|---|---|---|---|---|---|
V-Net | 80 | 0 | 91.14 | 83.82 | 1.52 | 5.75 |
V-Net | 16 | 0 | 86.03 | 76.06 | 3.51 | 14.26 |
3D Graph-S2Net | 16 | 64 | 89.83 | 81.69 | 2.12 | 6.68 |
SASSNet | 16 | 64 | 89.27 | 80.82 | 3.13 | 8.83 |
SimCVD | 16 | 64 | 90.85 | 83.80 | 1.86 | 6.03 |
SBANet | 20 | 52 | 88.4 | 79.5 | 2.56 | 9.773 |
V-Net | 8 | 0 | 79.99 | 68.12 | 5.48 | 8.99 |
3D Graph-S2Net | 8 | 72 | 87.94 | 78.90 | 2.32 | 8.99 |
SASSNet | 8 | 72 | 86.81 | 76.92 | 3.94 | 12.54 |
SimCVD | 8 | 72 | 89.03 | 80.34 | 2.59 | 8.34 |
SBANet | 5 | 67 | 80.6 | 69.2 | 3.445 | 14.253 |
Table 1: results from performance evaluations of the different methods
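The claim that all methods benefit from the unlabelled scans can be quantified directly from Table 1. This Python snippet computes the absolute Dice gain over V-Net trained on the same number of labelled scans; SBANet is left out because it uses different splits, so no matching baseline is available.

```python
# Dice scores [%] copied from Table 1 (settings shared by the V-Net baseline and the methods)
VNET_DICE = {16: 86.03, 8: 79.99}
SEMI_SUPERVISED_DICE = {
    16: {"3D Graph-S2Net": 89.83, "SASSNet": 89.27, "SimCVD": 90.85},
    8: {"3D Graph-S2Net": 87.94, "SASSNet": 86.81, "SimCVD": 89.03},
}


def dice_gains(baseline, results):
    """Absolute Dice gain (percentage points) of each method over the labelled-only baseline."""
    return {
        n_labelled: {method: round(score - baseline[n_labelled], 2) for method, score in methods.items()}
        for n_labelled, methods in results.items()
    }


gains = dice_gains(VNET_DICE, SEMI_SUPERVISED_DICE)
for n_labelled, per_method in gains.items():
    print(n_labelled, per_method)
```

The gains grow as labelled scans become scarcer: between 3.24 and 4.82 points with 16 labelled scans, and between 6.82 and 9.04 points with 8. With 16 labelled scans, SimCVD is only 0.29 points below V-Net trained on all 80 labelled scans.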
4. Ablation Study
All of the papers performed ablation studies to show the effects of different components on the network. In all methods, using the SDM improved model performance. All performance changes below are given in terms of the Dice score.
SimCVD [5]
Removing the SDMs decreases the model performance by 0.79%, and adding the contrastive loss Lcontrast has the biggest impact on the model performance, with an improvement of 3.9%. By contrast, removing the SDM loss has almost no negative effect.
3D Graph-S2Net [2]
Introducing the SDM also shows a clear performance gain of 2.32%. The bilateral graph convolution shows a performance improvement of 2.58%.
SASSNet [3]
Introducing the Discriminator shows the clearest performance improvement with 5.69%. Using the SDMs also shows an improvement of 1.13%.
SBANet [4]
The consistency loss on PSM1, PSM3 and SDM shows the biggest improvement when using unlabelled data, with 1.08%. Overall, including unlabelled data yields a total improvement of 1.1%.
5. Own Review
In general, all papers created new approaches to incorporate the SDM and the unlabelled data, which benefitted the overall model performance. An advantage of SimCVD [5], 3D Graph-S2Net [2] and SASSNet [3] is that they use the same dataset with the same ratio between labelled and unlabelled data, which makes them easy to compare. This is missing in SBANet [4], which makes it hard to compare properly with the other methods. 3D Graph-S2Net [2] and SASSNet [3] do not show the influence of the loss functions, and none of the papers shows how different hyperparameters influence training and results. The papers also do not report the computational complexity of their methods.
References
[2] Huang, H., Zhou, N., Lin, L., Hu, H., Iwamoto, Y., Han, X. H., ... & Tong, R. (2021, September). 3D Graph-S2Net: Shape-Aware Self-ensembling Network for Semi-supervised Segmentation with Bilateral Graph Convolution. In International Conference on Medical Image Computing and Computer-Assisted Intervention (pp. 416-427). Springer, Cham.
[3] Li, S., Zhang, C., & He, X. (2020, October). Shape-aware semi-supervised 3D semantic segmentation for medical images. In International Conference on Medical Image Computing and Computer-Assisted Intervention (pp. 552-561). Springer, Cham.
[4] Liu, X., Hu, Y., Chen, J., & Li, K. (2022). Shape and boundary-aware multi-branch model for semi-supervised medical image segmentation. Computers in Biology and Medicine, 143, 105252.
[5] You, C., Zhou, Y., Zhao, R., Staib, L., & Duncan, J. S. (2022). SimCVD: Simple contrastive voxel-wise representation distillation for semi-supervised medical image segmentation. IEEE Transactions on Medical Imaging.
[6] Van Engelen, J. E., & Hoos, H. H. (2020). A survey on semi-supervised learning. Machine Learning, 109(2), 373-440.
[7] Ouali, Y., Hudelot, C., & Tami, M. (2020). An overview of deep semi-supervised learning. arXiv preprint arXiv:2006.05278.
[8] Ronneberger, O., Fischer, P., & Brox, T. (2015, October). U-net: Convolutional networks for biomedical image segmentation. In International Conference on Medical image computing and computer-assisted intervention (pp. 234-241). Springer, Cham.
[9] Milletari, F., Navab, N., & Ahmadi, S. A. (2016, October). V-net: Fully convolutional neural networks for volumetric medical image segmentation. In 2016 fourth international conference on 3D vision (3DV) (pp. 565-571). IEEE.
[10] Tarvainen, A., & Valpola, H. (2017). Mean teachers are better role models: Weight-averaged consistency targets improve semi-supervised deep learning results. Advances in neural information processing systems, 30.
[11] http://atriaseg2018.cardiacatlas.org