Abstract
In this blog post, we explore transfer learning techniques that leverage prior knowledge acquired by 2D networks and apply it to 3D networks, with a specific focus on transformer-based models. We examine how 2D knowledge can be effectively transferred to enhance the performance and capabilities of 3D models, shedding light on the potential benefits of this approach.
Introduction
Motivation
Extensive research has led to significant advancements in 2D networks across various computer vision tasks such as image classification, semantic segmentation, and object detection. With the development of novel pre-training techniques, these 2D networks have achieved remarkable results. Given the availability of large and diverse 2D datasets and the wealth of semantic and geometric knowledge they capture, a natural question arises: why not leverage this knowledge to enhance the performance of 3D models and facilitate their training? By incorporating the insights gained from 2D representations into 3D models, we can potentially improve their understanding and processing of 3D representations, bridging the gap between the 2D and 3D domains and unlocking new possibilities and advancements in the field of 3D computer vision.
Problem Statement
With the rising prevalence of 3D medical imaging modalities such as MRI and CT, researchers are actively working on the development of learning-based models that can automate image interpretation workflows. The primary objective is to alleviate the workload of radiologists while simultaneously enhancing patient outcomes through faster and more precise diagnoses.
Due to the success of 2D transformer-based models in various vision tasks, there has been a surge of research exploring 3D transformer-based approaches, which allow global feature reasoning by capturing long-range dependencies, and offer high interpretability and scalability. Unfortunately, there is a scarcity of available 3D training data, due to the high cost and data protection concerns associated with acquiring medical images. Moreover, 3D transformers are computationally expensive models because of their complex architecture and the large number of parameters. Training them from scratch can require even more computational resources, as the model needs to learn complex semantic and geometric representations from random initialization.
One way to address this is to benefit from the large and diverse 2D labeled datasets as well as the extensively researched and refined 2D networks. Indeed, the knowledge derived from well-trained 2D networks is highly informative and can serve as a valuable signal for 3D networks to learn good initial weights, eliminating the need for additional labeled data and sophisticated architectures.
Methodology and Evaluation
Division into 2D Slices
The most straightforward approach is to take pre-trained 2D models and fine-tune them on 3D datasets in a slice-by-slice fashion, similar to [4] and [5]. In other words, we split the 3D image volume into 2D slices and process them separately, then aggregate the model predictions to get one prediction for the full 3D image. This method has its limitations: the model may lose the spatial context and interdependencies present in the 3D image, which can result in a limited understanding of the overall structure and the relationships between different slices, and in suboptimal performance compared to methods that explicitly model the 3D nature of the data.
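To make this concrete, here is a minimal sketch of such a slice-by-slice pipeline (the model, tensor shapes, and function name are my assumptions, not taken from [4] or [5]):

```python
import torch

def predict_volume_slicewise(model_2d, volume):
    """Run a 2D model on every axial slice of a 3D volume and stack the outputs.

    `volume` has shape (C, D, H, W); the hypothetical `model_2d` maps
    (B, C, H, W) -> (B, num_classes, H, W).
    """
    slices = volume.permute(1, 0, 2, 3)        # (D, C, H, W): treat depth as the batch axis
    with torch.no_grad():
        logits = model_2d(slices)              # independent per-slice predictions
    return logits.permute(1, 0, 2, 3)          # (num_classes, D, H, W): re-stacked volume
```

In practice, the aggregation step could also include post-processing that smooths predictions across neighboring slices.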
Weight Inflation
Let's start by examining the structure of a conventional 2D vision transformer. Please keep in mind that, for the sake of simplicity, we will disregard the channel dimension.
As visualized in Figure1 below, the 2D vision transformer divides the input image into equally sized patches. These patches are then flattened into vectors, which are further transformed into a sequence of embeddings using a linear projection. After incorporating positional encoding, the resulting tokens are fed into the transformer encoder. This process enables the vision transformer to process and analyze the image data, capturing both local and global dependencies for effective representation learning.
The 3D vision transformer (presented in Figure2), on the other hand, receives 3D input images, which are split into 3D cubic patches. The tokenization process remains the same; the only differences are the length of the input sequence and the size of the weights required for mapping the input vectors into the embedding space. Since the encoder itself is unchanged, we can initialize it with the pre-trained 2D weights. However, we still need to determine how to effectively utilize the 2D knowledge within the tokenization layer.
Figure1: 2D vision transformer: tokenization of the image patches
Figure2: 3D vision transformer: tokenization of the image patches
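In code, the only structural difference between the two tokenization layers is the shape of the embedding weights. A minimal sketch with assumed sizes (not taken from any of the papers):

```python
import torch.nn as nn

# Patch embedding as a strided convolution; all sizes here are illustrative assumptions.
p, K, dim = 16, 5, 768                                     # patch size, patch depth, embedding dim
embed_2d = nn.Conv2d(3, dim, kernel_size=p, stride=p)      # weight shape: (dim, 3, p, p)
embed_3d = nn.Conv3d(1, dim, kernel_size=(K, p, p), stride=(K, p, p))  # weight shape: (dim, 1, K, p, p)
```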
Adapting pre-trained Vision Transformers from 2D to 3D through Weight Inflation Improves Medical Image Segmentation [1]
Inspired by I3D [15], which transfers CNNs pre-trained on 2D images to 3D video inputs by inflating the convolutional weights along the temporal axis, the authors of [1] suggest a simple yet effective way to benefit from 2D pre-trained weights in 3D medical image segmentation. The idea rests on two observations: videos can be considered 3D data whose third dimension is time instead of depth, and the partitioning and embedding layers of the vision transformer are equivalent to a convolutional layer with a p×p kernel and stride p.
Method (Figure3)
1) First, a 2D vision transformer is pre-trained on natural images via a combination of self-supervised and supervised learning.
2) The 3D voxel images are sliced into small windows along the depth axis, and each window is forwarded to the embedding layer as a sequence of cubic patches.
3) In the 3D vision transformer, the encoder weights are initially copied from its 2D counterpart, while the weights of the embedding layer are inflated.
4) For the segmentation task, the UPerNet Decoder segments only the center slice of each input window of the 3D volume. Finally, the predictions of all windows are aggregated to obtain the final segmentation results.
Figure3: Overview of the approach presented in [1]
The objective of weight inflation is to determine the appropriate initial weights for the additional dimension (of size K) based on the pre-trained 2D weights. The paper discusses two variants of weight inflation:
Figure4: Weight inflation methods in [1]
- Average Inflation assumes that input slices within a certain range of depth exhibit similarity. This approach involves copying the weights K times along the depth axis and dividing them by K to maintain the mean and variance of the input to the transformer encoder.
- Centering Inflation relies solely on information from the central slice and gradually incorporates information from neighboring slices. It assigns the 2D weights to the center-most slice and sets the weights for the remaining slices to zero.
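Both variants boil down to a simple reshaping of the pre-trained embedding weights. A minimal sketch, with my own naming rather than the authors' code:

```python
import torch

def inflate_patch_embedding(w2d, K, mode="center"):
    """Sketch of the two inflation variants described in [1].

    w2d: pre-trained 2D patch-embedding weights of shape (dim, C, p, p).
    Returns 3D weights of shape (dim, C, K, p, p).
    """
    if mode == "average":
        return w2d.unsqueeze(2).repeat(1, 1, K, 1, 1) / K          # copy K times, rescale by 1/K
    elif mode == "center":
        w3d = torch.zeros(w2d.size(0), w2d.size(1), K, w2d.size(2), w2d.size(3))
        w3d[:, :, K // 2] = w2d                                    # 2D weights on the central slice only
        return w3d
    raise ValueError(f"unknown mode: {mode}")
```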
Which Inflation Technique Is Better?
The paper presents an ablation study that compares weight inflation against random initialization and confirms that weight inflation significantly outperforms random initialization. Additionally, it highlights that centering inflation is the preferred approach based on their experiments.
How Is The 3D Window Input Segmented?
Furthermore, the study reveals that predicting only the segmentation mask of the center slice of each input window yields better results compared to segmenting all the slices within the window. This could be attributed to noise or irrelevant information that the model encounters when processing more slices.
How Much Depth Information Do We Need?
Some experiments were conducted to investigate how much depth information is actually needed. It turned out that increasing the size of the input 3D windows leads to a decrease in performance in terms of the Dice Similarity Coefficient and a potential risk of overfitting. The findings suggest that a context of 5 consecutive slices in the 3D volume is sufficient to capture the information needed for an accurate segmentation; information from a larger number of slices may introduce noise or unnecessary complexity, degrading the performance of the 2D backbone.
How Is The Discrepancy in Channel Size Addressed?
Typically, natural images used for pre-training have 3 RGB channels, while 3D medical images can have one or multiple channels. This difference is addressed by summing the weights over the color channels to obtain weights for grayscale images; if weights for more than one channel are required, average inflation is applied along the channel axis.
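A minimal sketch of this channel handling (all tensor sizes are assumptions):

```python
import torch

# Collapse RGB pre-trained patch-embedding weights to the modality's channel count.
w_rgb = torch.randn(768, 3, 16, 16)        # stand-in for pre-trained 2D weights (dim, 3, p, p)
w_gray = w_rgb.sum(dim=1, keepdim=True)    # (dim, 1, p, p): sum over the RGB channels

C = 4                                       # hypothetical multi-channel 3D modality
w_multi = w_gray.repeat(1, C, 1, 1) / C     # average inflation along the channel axis
```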
Results
Importance of Transfer Learning:
The experiments show an improvement of 11.18% in the performance of the model with pre-trained weights but no depth information included, compared to the same model initialized randomly.
Importance of Depth Information:
The baseline UNETR adapts the transformer architecture to 3D input but doesn't use transfer learning. Compared to a similar model that doesn't include depth information, it shows a 6.07% performance boost. This can be explained by improved spatial awareness, which leads to more consistent predictions throughout the image volume.
=> Combining both transfer learning and depth information achieves state-of-the-art performance
Figure5: Comparison of models that use transfer learning and/or incorporate depth information in [1]
Transfer Effectiveness of Models Pre-Trained with Different Sources and Objectives:
The paper also investigates different pre-training strategies by varying the source data and the objective types (Figure6). Interestingly, pre-training on medical images performs worse than pre-training on natural images. This could be attributed to inter-class similarities in human anatomy, which make it challenging to learn distinct visual features, and to the small size of the medical dataset compared to the vast amounts of natural images.
Moreover, we notice that the performance improves with additional supervision from image or video labels.
Figure6: Comparison of different pre-training techniques in [1]
Personal Opinion
Strengths 💪
- Easy to implement
- Generalizable, stable, robust: SOTA performance on most of the 11 tested datasets, using a single set of hyperparameters for all experiments
- Tested on different pre-trained transformers
- Ablation study for different inflation settings
Future Work 🤔
Since the main purpose is to show the importance of weight inflation rather than to achieve state-of-the-art performance, future research directions could involve investigating pre-training techniques using diverse data sources, varied objectives, and varying amounts of data. Additionally, there is a need to find optimal hyperparameters for different imaging modalities, which would enhance the performance and generalizability of the method. Furthermore, it is important to develop efficient self-supervised learning methods tailored specifically for 3D medical images in order to cope with data scarcity.
COVID Detection and Severity Prediction with 3D-ConvNeXt and Custom Pretrainings [2]
This paper is the least relevant one. It focuses on enhancing the classification of the severity of lung damage and the detection of SARS-CoV-2 infections based on 3D lung CT scans (the COV19-CT-DB dataset), rather than analyzing the effect of leveraging knowledge from 2D models.
Method (Figure7)
Unlike the previous paper [1], this paper uses a CNN-based network, namely a 3D version of ConvNeXt. The backbone receives an input tensor T and performs computations on it.
To generate a classification label for the severity prediction task, the output features from the last block are passed through a task-specific classification head, which processes the features to produce class probabilities.
On the other hand, segmentation masks are generated by upsampling and concatenating the output features from each block of the ConvNeXt. The concatenated features are then further processed by a segmentation head.
Figure7: Overview of the approach in [2]
With the aim of improving the model's ability to handle three-dimensional CT data, 4 pre-training techniques are suggested. Only one of them relates to leveraging 2D knowledge; the others use 3D segmentation as a pretext task and pseudo-labels for self-supervision:
- ImageNet model: initialize the backbone with inflated grayscale ImageNet pre-trained weights
- Segmentation model: initialize with ImageNet model → pre-train on a dataset for lung-lesion segmentation in CT scans
- Segmia model: Initialize with ImageNet model → generate pseudo-labels for COV19-CT-DB with segmentation model → pre-train using the pseudo-labels
- Multitask model: Initialize with ImageNet model → generate pseudo-labels for the STOIC dataset with the segmentation model → pre-train on STOIC for classification (real labels) and segmentation (pseudo-labels)
Interesting Findings:
By pre-training the model on lung lesion segmentation (pretext task), we can enhance its ability to localize damaged regions and predict the severity class of lung damage (downstream task). This advantageous approach enables the utilization of large-scale models even when working with limited medical data.
Furthermore, pretraining the model on a classification task using the labelled STOIC dataset, despite the dataset having different categories than COV19-CT-DB, facilitates the development of a comprehensive understanding of the classification task.
Weight Inflation Techniques:
The paper presents 3 weight inflation techniques:
- Full Inflation is basically the previously mentioned average inflation: the 2D weights are copied along the new dimension.
- 1G Inflation is similar to centering inflation, but multiplies the 2D weights with Gaussian weights, so that the largest weights are assigned to the center-most slice and the weights approach zero towards the outer slices.
- 2G Inflation is a new inflation idea that sums up the 2D weights multiplied by Gaussian weights along 2 axes (see the sketch after Figure8).
Figure8: Weight inflation methods in [2]
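To make the Gaussian variants concrete, here is a minimal sketch of a 1G-style inflation (the value of sigma and the lack of normalization are my assumptions; 2G would apply a second Gaussian weighting along an additional axis):

```python
import torch

def gaussian_1g_inflation(w2d, K, sigma=1.0):
    """Sketch of 1G-style inflation as I read it from [2]: scale the 2D kernel by a
    Gaussian over the depth axis so that the central slice receives the largest weight.

    w2d: (dim, C, p, p) pre-trained weights; returns (dim, C, K, p, p).
    """
    z = torch.arange(K, dtype=torch.float32) - (K - 1) / 2     # centered depth coordinates
    g = torch.exp(-0.5 * (z / sigma) ** 2)                     # Gaussian peaking at the center
    return w2d.unsqueeze(2) * g.view(1, 1, K, 1, 1)
```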
According to the reported results, the performance is best for 2G inflation in terms of the cross-validation metrics. Consequently, we assume that combining information from multiple dimensions is beneficial. This raises the question of whether a 3G inflation could yield even better results.
Results
The table in Figure9 showcases the comparison between the 4 pre-training approaches. The highlighted F1 scores prove again that initializing with 2D pre-trained weights is better than random initialization, even for non-transformer-based models.
Personal Opinion
I chose to include this paper because it presents additional inflation techniques and validates the significance of 2D-3D knowledge transfer for various 3D architectures and downstream tasks.
An intriguing finding worth highlighting is that incorporating a pre-text task similar to the downstream task alongside a segmentation task leads to superior pre-training outcomes.
Figure9: Comparison of different pre-training techniques in [2]
Knowledge Distillation
Can we solve 3D Vision Tasks Starting from a 2D Vision Transformer? [3]
Figure10: Overview of the approach in [3]
The authors suggest a unified transformer-based architecture in which the backbone, namely the vision transformer ViT, is fixed, while they vary the tokenizer, which converts the input data into embeddings of the same dimension, and the head, which is tailored to a specific downstream task. The challenge is to reduce the data modality gap (spatially scattered point clouds and 3D voxels carry richer semantic and geometric meaning than 2D images) and the task knowledge gap (3D visual understanding tasks demand a greater level of semantic understanding than 2D tasks).
Method
In the context of the 3D case, the authors mention 3 types of voxel tokenizers, depicted in Figure11:
- The Naive Inflation converts the sequence of 3D cubic patches in XYZ coordinate ordering into embeddings of the same size as the 2D embeddings. However, it is important not to confuse this with weight inflation, since it does not rely on any 2D pre-trained weights; instead, a separate embedding function maps the 3D patches into embedding vectors.
- The 2D Projection treats the data as if it were in a 2D scenario. This involves compressing the information from the third dimension into a single token instead of using a sequence of tokens: for tokenized cubes originating from the same xy coordinates, the approach takes the average of their values. An advantage of this method is that it yields the same sequence length as in the 2D case (see the sketch after Figure11).
- The Group Embedding is employed to preserve richer semantic meaning compared to the 2D projection. In this approach, the depth dimension is treated as a word group, and the embeddings of slices along the depth axis are consolidated into a single semantic token. To accomplish this, a 1D transformer encoder, such as the TimeSformer, is utilized.
Figure11: Tokenization of 3D voxel images in [3]
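As an illustration, here is a minimal sketch of the 2D projection idea (the depth-major token ordering is an assumption of mine):

```python
import torch

def project_2d_tokens(cube_tokens, grid_d):
    """Sketch of the 2D-projection tokenizer as described in [3]: average the embeddings
    of cubic patches that share the same xy location.

    cube_tokens: (B, grid_d * grid_h * grid_w, dim) patch embeddings, depth-major order assumed.
    """
    B, N, dim = cube_tokens.shape
    tokens = cube_tokens.view(B, grid_d, N // grid_d, dim)   # (B, D, H*W, dim)
    return tokens.mean(dim=1)                                # (B, H*W, dim): 2D-length sequence
```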
Teacher-Student Knowledge Distillation
Two methods are employed to leverage 2D prior knowledge:
1) The transformer backbone inherits the weights from a model pre-trained on 2D images, but a small learning rate should be used in the first few epochs to avoid forgetting the gained knowledge, which motivates the second method.
2) Retrospecting from 2D cases by generalization: we benefit from a frozen teacher transformer pre-trained on ImageNet. When training the 3D student transformer on a mini-batch of 3D data, we pass a mini-batch of the ImageNet validation set to both models and compare their outputs. The student transformer uses the tokenizer and head of the teacher for the 2D classification task. In addition to the 3D task loss, a KL-divergence term is computed to measure how well the model retains the 2D knowledge (see the sketch below).
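A minimal sketch of one such training step (the loss weighting alpha and the forward_2d helper are assumptions, not the paper's exact formulation):

```python
import torch
import torch.nn.functional as F

def distillation_step(student, teacher, x3d, y3d, x2d, task_loss_fn, alpha=0.1):
    """Sketch of the 'retrospection' idea in [3]: next to the 3D task loss, add a KL term
    between the student's and the frozen teacher's predictions on 2D images."""
    loss_3d = task_loss_fn(student(x3d), y3d)                 # regular 3D task loss

    with torch.no_grad():
        teacher_logits = teacher(x2d)                         # frozen 2D teacher on ImageNet images
    student_logits = student.forward_2d(x2d)                  # student reuses the 2D tokenizer/head

    kl = F.kl_div(F.log_softmax(student_logits, dim=-1),
                  F.softmax(teacher_logits, dim=-1),
                  reduction="batchmean")
    return loss_3d + alpha * kl
```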
Figure12: Knowledge-distillation in [3]
Results
3D Object Classification :
Despite not achieving state-of-the-art performance and falling short of surpassing all baseline models, the results of the Simple3D-Former remain promising. It outperforms most of the CNN-based baselines and shows competitive performance with transformer-like models. This is particularly noteworthy considering that the Simple3D-Former has fewer parameters and does not heavily customize the core network to handle 3D data.
Figure13: 3D classification results in [3]
3D Point Cloud Segmentation :
Although the Simple3D-Former lacks complex geometric-aware designs and relies solely on prior knowledge from the ImageNet classification task, it still demonstrates relatively good performance in 3D point cloud segmentation compared to the baselines.
Figure14: 3D point cloud segmentation results in [3]
Is Transfer Learning a Performance Boost?
Leveraging pre-trained knowledge from 2D models provides a significant improvement in accuracy. The teacher-student knowledge distillation was more beneficial than simply inheriting the 2D weights, but combining both yields the best performance.
The significance of 2D prior knowledge was also evaluated through experiments that involved using only a subset of the available 3D training data. This approach was adopted to mitigate computational costs and explore the impact of varying the batch size of the source task images. The findings revealed that increasing the amount of 3D data used during training led to improved accuracy. However, further enhancements in performance could be achieved by increasing the size of the 2D data used in knowledge distillation.
Personal Opinion
Strengths 💪
- A unified backbone
- Versatile, easy to deploy
- First to extend 2D Vision Transformers (ViT) to their 3D counterparts
Limitations 👎
- Not tested for medical applications
- More focus on sparse point clouds, rather than voxel images
- No weight inflation for embedding layers
- Doesn’t test different pre-trained models
- Not state-of-the-art performance for the majority of the experiments
Future Work 🤔
- Refine the 2D backbone to better handle 3D data
- Joint pre-training of the universal transformer
Overview of Other Interesting Approaches
Most transfer learning methods from the 2D domain to the 3D domain are limited to some computer vision tasks because they rely on mappings between 2D projections and 3D point clouds, which may not be readily available in the medical field. I'd like to briefly introduce four interesting methods, which might be applicable for 3D medical image processing.
The first method involves contrastive pixel-to-point knowledge transfer. The second method suggests a cross-model masked autoencoding. The following approach uses RGB natural images as pseudo 3D images for pre-training 3D networks. Lastly, weight inflation with a sparse embedding for joint training is introduced.
Contrastive Pixel-to-Point Knowledge Transfer [13]
The proposed method leverages contrastive pixel-to-point knowledge transfer to align pixel-level and point-level features in the same embedding space.
As illustrated in Figure15, using a pre-trained 2D neural network (ResNet), we learn initial weights for the 3D network (Sparse Residual 3D U-Net34) from unlabeled datasets. Then, through supervised fine-tuning, we optimize the 3D network for various 3D downstream tasks.
For the pre-training, the input data consists of paired RGB and depth images. Given the camera intrinsic parameters, a back-projection function π is used to generate a single-view point cloud in which each point carries the RGB values of its corresponding pixel. The 2D images are fed into the pre-trained, frozen 2D network, which outputs a 2D feature map, while the 3D point cloud is forwarded to the 3D network we want to pre-train, which outputs a feature vector for each 3D point. Then the learnable feature projection layers g2D and g3D map both 2D and 3D features into a shared latent space, and the back-projection is applied again to align each pixel feature vector with its corresponding point feature vector.
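A minimal InfoNCE-style sketch of such a pixel-to-point contrastive loss (a simplification of [13], not their exact objective):

```python
import torch
import torch.nn.functional as F

def pixel_to_point_info_nce(feat_2d, feat_3d, temperature=0.07):
    """The i-th pixel feature and the i-th point feature form a positive pair;
    all other points in the batch act as negatives.

    feat_2d, feat_3d: (N, d) features after the projection layers g2D and g3D.
    """
    feat_2d = F.normalize(feat_2d, dim=-1)
    feat_3d = F.normalize(feat_3d, dim=-1)
    logits = feat_2d @ feat_3d.t() / temperature               # (N, N) similarity matrix
    targets = torch.arange(feat_2d.size(0), device=logits.device)
    return F.cross_entropy(logits, targets)                    # match each pixel to its own point
```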
Interesting Findings:
- The suggested method is better than the previously presented knowledge distillation method because it doesn’t focus on minimizing the distance between global features or classification logits, but compares local features.
- Both encoder and decoder pre-training is advantageous.
- Self-supervised pre-training is more helpful when the training dataset is smaller.
- A good pre-training can compensate for the backbone size limitation.
Figure15: Overview of the approach in [13]
Advancing 3D Medical Image Analysis with Variable Dimension Transform based Supervised 3D Pre-training [10]
In 3D medical imaging, researchers typically combine every 3 adjacent slices to create pseudo-RGB images, allowing them to utilize pre-trained 2D networks for 3D tasks. This process preserves rich 3D structural information in the 2D space through color encoding. By reverse-reformulating natural images into pseudo-3D inputs, a large-scale 3D dataset can be constructed for fully supervised 3D pre-training, taking advantage of the extensive annotations available in the natural image domain.
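A minimal sketch of the slice-grouping step (non-overlapping grouping is my simplification; a sliding window over adjacent slices is equally possible):

```python
import torch

def slices_to_pseudo_rgb(volume):
    """Group every 3 adjacent slices of a (D, H, W) volume into the R, G, B channels
    of a pseudo-color image so that 2D ImageNet-style weights can be reused."""
    D = volume.shape[0] - volume.shape[0] % 3                  # drop the remainder slices
    return volume[:D].reshape(D // 3, 3, *volume.shape[1:])    # (D//3, 3, H, W)
```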
Interesting Findings:
- Better results if we freeze the weights of the stem layers and the first ResBlock; shallow features learned from the pseudo-3D data can be well generalized to medical data without any further fine-tuning
- When choosing a suitable pre-trained model, it is important to consider the similarity between the pretext task and the target task.
- In the case of small-scale datasets, the impact of pre-trained weights on the final performance is more significant.
Figure16: Overview of the approach in [10]
Rethinking Video ViTs: Sparse Video Tubes for Joint Image and Video Learning [11]
The idea of this paper is to use sparse tubes to generate tokens from 2D and 3D patches. The authors vary the tube shapes and the spatio-temporal strides in order to reduce the complexity (generating shorter embedding sequences). They also modify the offsets, i.e. the coordinates of the starting point of each tokenization tube, in order to reduce the overlap between the tubes and thus maximize the captured information. Making the kernel size smaller means we want to capture small local details, while extending a tube along the depth axis means we want to include the semantic expression of more adjacent slices (a long action in the case of videos).
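To make the tube idea concrete, here is a minimal sketch with made-up shapes; the actual tube configurations in [11] differ:

```python
import torch.nn as nn

# Several 3D convolutions with different kernel shapes and strides act as "tubes";
# their output tokens are concatenated into one sequence (all sizes are illustrative).
dim = 768
tube_image = nn.Conv3d(3, dim, kernel_size=(1, 16, 16), stride=(32, 16, 16))   # image-like tube
tube_long  = nn.Conv3d(3, dim, kernel_size=(8, 8, 8),   stride=(16, 32, 32))   # temporally long tube
# Different offsets can be emulated by cropping the input before each convolution,
# e.g. x[:, :, 2:, 4:, 4:], so that the tubes overlap less and cover more of the clip.
```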
The pipeline (Figure17) can be summarized as follows:
Initially, a small ViT encoder is trained simultaneously on images and videos. Subsequently, the tokenization tubes obtained from this training process are transferred to a new large Image pre-trained ViT model. The final layers of this ViT model are then fine-tuned using small video datasets.
Strengths:
- It enables a versatile visual backbone that can easily adapt a ViT architecture to handle videos.
- It supports simultaneous understanding of both images and videos → allows interaction between both data modalities
- It provides a scalable solution for video understanding, capable of leveraging pre-trained ViT models, including large ones.
Figure17: Overview of the approach in [11]
Joint-MAE: 2D-3D Joint Masked Autoencoders for 3D Point Cloud Pre-training [14]
Once again, this approach affirms that multi-modality learning achieves more robust performance than single-modality learning.
First, the 3D point cloud is perspectively projected from a random view in order to acquire a 2D image. Subsequently, two hierarchical 2D-3D embedding modules map the input data into tokens of the same size. These tokens are randomly masked with a high ratio (75%), and only the visible tokens are concatenated and fed to the joint encoder. Two learnable modality tokens, M2D and M3D, are added to the 2D and 3D tokens respectively, in order to differentiate between the two modalities in the attention mechanism. Next, a joint decoder, comprising one modality-shared and two modality-specific components, reconstructs the hidden point cloud coordinates and image pixels.
Figure18: Overview of the approach in [14]
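A minimal sketch of the random masking step (the exact sampling strategy in [14] may differ):

```python
import torch

def random_mask(tokens, ratio=0.75):
    """Keep a random 25% of the tokens (75% masking ratio) and return their indices
    so that the decoder can later restore the ordering."""
    B, N, dim = tokens.shape
    n_keep = int(N * (1 - ratio))
    keep = torch.rand(B, N, device=tokens.device).argsort(dim=1)[:, :n_keep]   # visible indices
    visible = torch.gather(tokens, 1, keep.unsqueeze(-1).expand(-1, -1, dim))
    return visible, keep
```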
In order to enhance the interaction between 2D-3D representations, the authors introduce local-aligned attention mechanisms within the joint encoder, together with a cross-reconstruction loss for self-supervision:
- In local-aligned attention, apart from self-attention within the same modality, attention scores are also calculated between 2D and 3D tokens that are geometrically correlated.
- The cross reconstruction loss measures the similarity between the 2D projection of the reconstructed point cloud and the reconstructed 2D depth map. It is added to the separate 2D and 3D reconstruction losses.
Conclusion
In conclusion, the challenges posed by the data modality gap and the task knowledge gap make it hard for models relying on pre-training simple backbones and leveraging 2D prior knowledge to achieve state-of-the-art performance, compared to 3D models that extensively exploit the structural and geometric information present in 3D data. The latter tailor their architectures specifically to the unique demands of 3D vision tasks and fully utilize the richness of 3D data. Therefore, it may be worthwhile to explore better pre-training techniques for 3D backbones in order to mitigate the scarcity of 3D datasets.
Reference
[1] Zhang, Yuhui, et al. "Adapting Pre-trained Vision Transformers from 2D to 3D through Weight Inflation Improves Medical Image Segmentation." Machine Learning for Health. PMLR, 2022.
[2] Kienzle, Daniel, et al. "COVID detection and severity prediction with 3D-ConvNeXt and custom pretrainings." European Conference on Computer Vision. Cham: Springer Nature Switzerland, 2022.
[3] Wang, Y., Fan, Z., Chen, T., Fan, H., & Wang, Z. (2022). Can We Solve 3D Vision Tasks Starting from A 2D Vision Transformer?. arXiv preprint arXiv:2209.07026.
[4] Chen, Jieneng, et al. "Transunet: Transformers make strong encoders for medical image segmentation." arXiv preprint arXiv:2102.04306 (2021).
[5] Cao, Hu, et al. "Swin-unet: Unet-like pure transformer for medical image segmentation." European conference on computer vision. Cham: Springer Nature Switzerland, 2022.
[6] Jun, Eunji, et al. "Medical transformer: Universal brain encoder for 3D MRI analysis." arXiv preprint arXiv:2104.13633 (2021).
[7] Shamshad, Fahad, et al. "Transformers in medical imaging: A survey." Medical Image Analysis (2023): 102802.
[8] Guo, Ziyu, Xianzhi Li, and Pheng Ann Heng. "Joint-mae: 2d-3d joint masked autoencoders for 3d point cloud pre-training." arXiv preprint arXiv:2302.14007 (2023).
[9] Girdhar, Rohit, et al. "Distinit: Learning video representations without a single labeled video." Proceedings of the IEEE/CVF International Conference on Computer Vision. 2019.
[10] Zhang, Shu, et al. "Advancing 3D medical image analysis with variable dimension transform based supervised 3D pre-training." Neurocomputing 529 (2023): 11-22.
[11] Piergiovanni, A. J., Weicheng Kuo, and Anelia Angelova. "Rethinking video vits: Sparse video tubes for joint image and video learning." Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition. 2023.
[12] Yu, Ping-Chung, Cheng Sun, and Min Sun. "Data efficient 3d learner via knowledge transferred from 2d model." European Conference on Computer Vision. Cham: Springer Nature Switzerland, 2022.
[13] Liu, Yueh-Cheng, et al. "Learning from 2d: Contrastive pixel-to-point knowledge transfer for 3d pretraining." arXiv preprint arXiv:2104.04687 (2021).
[14] Guo, Ziyu, Xianzhi Li, and Pheng Ann Heng. "Joint-mae: 2d-3d joint masked autoencoders for 3d point cloud pre-training." arXiv preprint arXiv:2302.14007 (2023).
[15] Carreira, Joao, and Andrew Zisserman. "Quo vadis, action recognition? a new model and the kinetics dataset." proceedings of the IEEE Conference on Computer Vision and Pattern Recognition. 2017.