Blog post written by: Taha Ahmed

Based on: Peng, W., Adeli, E., Bosschieter, T., Park, S. H., Zhao, Q., & Pohl, K. M. (2023, September 7). Generating realistic brain MRIs via a conditional diffusion probabilistic model. arXiv.org. https://arxiv.org/abs/2212.08034

1. Introduction

As machine learning research and development advance, so do its benefits to the medical industry. One promising use case is medical imaging, where machine learning techniques are applied to MRI images for a wide variety of tasks, including but not limited to tumor detection, disease classification and disease progression monitoring.


However, running machine learning algorithms on MRI data faces some hurdles, namely:

  1. Limited Data Availability: ML algorithms, especially those involving deep neural networks, require large amounts of labelled and annotated training data, which is not as readily available in medical imaging as in other fields.
  2. Data Diversity: We also require our training data to represent the diverse conditions found in real life. Real-world MRI datasets may suffer from class imbalances, where certain conditions or characteristics are underrepresented.
  3. Privacy Considerations: Accessing and sharing real patient MRI data for research purposes can raise ethical concerns, as patients may not have given explicit consent for their data to be processed in this manner.

To address these hurdles, we turn to synthetic MRIs: artificially generated MRIs, produced by deep learning pipelines, that don't represent any real patient. We add these to our real datasets to boost their volume and diversity. You may be thinking: if we generate these synthetic images with deep learning pipelines, won't we run into the same problem of needing real MRI data for the training set? That is correct, and it is why generating synthetic MRIs is difficult. However, a number of techniques have emerged over the past decade that attempt to solve these problems in their own ways. We'll take a brief look at them before describing a new approach that aims to be best in class.

2. Approaches


[Figure: typical architecture of each approach]

Conditional Models
  • Description: These models generate images while conditioning on some input. In our case, we can generate high-quality synthetic MRIs by conditioning on real MRIs of the same patient.
  • Pros: These models produce high-quality synthetic MRIs that closely match the distribution of those present in the training data.
  • Cons: They require a very large dataset to produce good results, which we lack in our case. Furthermore, the outputs don't generalize well to real MRIs whose distribution is not represented in the training set, so they fail to diversify.

Unconditional Models (GANs)
  • Description: These models generate images from random noise vectors instead of conditioning on other MRI slices.
  • Pros: These models do not depend on a conditional input and hence don't require as much training data to produce realistic images.
  • Cons: They tend to be very computationally expensive and are prone to mode collapse, i.e. they memorize a narrow, discrete distribution that fools the discriminator quite well and keep producing outputs from that distribution.

Diffusion Models
  • Description: These models learn a fine-grained mapping between the data distribution and Gaussian noise as a gradual process: small amounts of noise are added to an image over thousands of timesteps, and a neural network is then trained to reverse the process by predicting the noise that was added.
  • Pros: These models produce the highest-quality outputs and are more stable than GANs thanks to the gradual noising and denoising methodology.
  • Cons: They are even more expensive than GANs, since producing one synthetic MRI requires gradually reversing the noise vector over thousands of steps of the reverse diffusion process. Producing a full 3D MRI this way would take a large amount of computational resources.

As we can see, each of the existing approaches comes with its own set of pros and cons. The conditional diffusion probabilistic model (cDPM) proposed in the paper aims to improve upon them by building on the structure of a diffusion model, and to deliver the following:

  • High-quality synthetic 3D MRIs that look more realistic than those from GAN-based approaches.
  • A distribution closer to that of real MRIs, hence maintaining the diversity of real-world samples.
  • Greater computational efficiency than a full 3D diffusion model.

Let's now look at how the cDPM works and how it delivers on these promises.

3. Conditional Diffusion Probabilistic Model (cDPM)

The cDPM works in two stages: the training stage and the generation stage. The training stage runs the forward diffusion process to generate training samples, on which we train our denoising neural network. The generation stage runs the reverse diffusion process iteratively to produce a full synthetic 3D MRI.

3.1. Training Stage

3.1.1. The Data

We start the training process with some data processing. To understand the training data, we first need to understand what exactly a 3D MRI is. A 3D MRI is a collection of 2D MRI slices acquired by the scanner at different depths and angles, which can be reconstructed together to form a full 3D volumetric image (Fig. 5). Much like a video, which is simply a sequence of 2D image frames, the 2D slices of an MRI are interdependent because together they describe one 3D volume, and we leverage those dependencies when training our model.
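To make the "stack of slices" picture concrete, here is a minimal sketch assuming a volume is held in memory as a plain nested list indexed as volume[slice][row][column]. The helper names make_toy_volume and get_slice are hypothetical; 128 slices matches the volumes used in this post, while the 4×4 slice size is made up to keep the example small.

```python
# Toy 3D volume stored as a list of 2D slices: volume[z][y][x].
# Sizes and values are illustrative only; real MRIs are loaded from image files.


def make_toy_volume(depth, height, width):
    """Build a toy volume where every voxel in slice z holds the value z."""
    return [[[float(z) for _ in range(width)] for _ in range(height)]
            for z in range(depth)]


def get_slice(volume, index):
    """Return the 2D slice at position `index` along the slice (depth) axis."""
    return volume[index]


volume = make_toy_volume(depth=128, height=4, width=4)
print(len(volume), len(volume[0]), len(volume[0][0]))  # 128 4 4
print(get_slice(volume, 7)[0][0])                      # 7.0
```

Indexing along the first axis is what gives each slice the index that the rest of the method relies on.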

3.1.2. Sampling

We start by taking a 3D MRI from our training set, which consists of 128 slices, and assign each slice an index. To ensure that slice indexing remains consistent across training examples, we perform affine registration to a template. After assigning indexes, we randomly sample some of these slices into two sets: a conditional set C and a target set P. The number of slices sampled into these sets is a hyperparameter that can be tuned to the available computational resources; in this experiment, we impose the limit len(C) + len(P) ≤ 20, i.e. the two sets together may contain at most 20 slices. In Fig. 6, we can see that we've sampled slice 2 into our conditional set and slices 46, and 7 into our target set. The only other restriction is that the conditional set may be empty, but the target set must always contain at least one slice, because the forward diffusion process runs on the target set.
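The sampling step can be sketched as follows. This is a minimal illustration, not the authors' code: the function name sample_slice_sets is hypothetical, slices are indexed from 0 here, and drawing the set sizes uniformly is an assumption, since the post only states the constraints (at most 20 slices in total, P non-empty, C possibly empty).

```python
import random

NUM_SLICES = 128  # slices per registered volume (from the post)
MAX_TOTAL = 20    # len(C) + len(P) <= 20 (from the post)


def sample_slice_sets(rng, num_slices=NUM_SLICES, max_total=MAX_TOTAL):
    """Randomly split slice indexes into a conditional set C and a target set P.

    Assumption: the total size and the split point are drawn uniformly; the
    exact sampling distribution used in the paper is not described here.
    """
    total = rng.randint(1, max_total)              # total slices, at least one
    chosen = rng.sample(range(num_slices), total)  # distinct slice indexes
    num_target = rng.randint(1, total)             # P must be non-empty
    target = sorted(chosen[:num_target])
    conditional = sorted(chosen[num_target:])      # C may be empty
    return conditional, target


rng = random.Random(0)
C, P = sample_slice_sets(rng)
print("C =", C)
print("P =", P)
```

Because every training example gets a fresh draw, the network sees many different combinations of conditional and target slices.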

3.1.3. Forward Diffusion Process

We then stack the slices in our target set into a 3D sub-volume X^P, which we run through the forward diffusion process, i.e. we gradually add noise to the volume until it is indistinguishable from random noise. This generates noisy training samples X^P_t at timestep t, which we use to train our denoising neural network. However, whereas a standard diffusion model would simply take the noisy sample X_t and the timestep t as input, here we feed the network some extra information.
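In a standard DDPM, the noisy sample at any timestep can be drawn in one shot as X_t = sqrt(ᾱ_t)·X_0 + sqrt(1 − ᾱ_t)·ε with ε ~ N(0, I). The sketch below applies that to a flattened target sub-volume. The post does not state the authors' noise schedule, so T = 1000 and a linear β schedule from 1e-4 to 0.02 are assumptions (common DDPM defaults), and q_sample is a hypothetical name.

```python
import math
import random

T = 1000  # assumed number of diffusion timesteps (a common DDPM default)

# Assumed linear noise schedule: beta_1 = 1e-4 up to beta_T = 0.02.
betas = [1e-4 + (0.02 - 1e-4) * i / (T - 1) for i in range(T)]

# alpha_bars[t - 1] = product of (1 - beta_s) for s = 1..t,
# i.e. how much of the original signal survives after t noising steps.
alpha_bars = []
running = 1.0
for beta in betas:
    running *= 1.0 - beta
    alpha_bars.append(running)


def q_sample(x0, t, rng):
    """Draw X_t ~ q(X_t | X_0) in closed form for a flattened sub-volume x0.

    t is 1-indexed (1..T). Returns the noisy sample and the noise used,
    which later serves as the regression target for the network.
    """
    ab = alpha_bars[t - 1]
    eps = [rng.gauss(0.0, 1.0) for _ in x0]
    xt = [math.sqrt(ab) * x + math.sqrt(1.0 - ab) * e for x, e in zip(x0, eps)]
    return xt, eps


rng = random.Random(0)
x0 = [0.5] * 8  # stand-in for the voxels of the target sub-volume X^P
for t in (1, 250, 1000):
    xt, _ = q_sample(x0, t, rng)
    print(t, round(alpha_bars[t - 1], 5), [round(v, 2) for v in xt[:3]])
```

As t approaches T, ᾱ_t approaches zero and X_t becomes almost pure noise, which is the "indistinguishable from random noise" end of the process.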

3.1.4. Network Architecture

The network architecture follows that of a standard diffusion model: a U-Net composed of encoder-decoder blocks, with multi-head self-attention blocks added to the decoder. The inputs to the network are the following:

3.1.4.1. Timestep t

We pass the timestep t to the time embedding module, which encodes it as

Embedding(t) = [\cos(t\,w^{-2d/Q}), \sin(t\,w^{-2d/Q})]

where t is the timestep, w is set to 10000 to control the minimum frequency of the embedding, Q is the number of dimensions of the embedding, and d = 1, 2, …, Q/2 is the index of the embedding dimension. This embedding is passed to each layer of the model.
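The embedding formula translates directly into code. This is a minimal sketch, assuming the output lists the Q/2 cosine terms first and the Q/2 sine terms second (the bracket notation does not fix the ordering, and some implementations interleave them); time_embedding is a hypothetical name.

```python
import math


def time_embedding(t, Q, w=10000.0):
    """Sinusoidal timestep embedding: [cos(t * w^(-2d/Q)), sin(t * w^(-2d/Q))].

    d runs over 1..Q/2, so the result has Q entries: the Q/2 cosine terms
    followed by the Q/2 sine terms.
    """
    if Q % 2 != 0:
        raise ValueError("Q must be even")
    freqs = [w ** (-2.0 * d / Q) for d in range(1, Q // 2 + 1)]
    return [math.cos(t * f) for f in freqs] + [math.sin(t * f) for f in freqs]


emb = time_embedding(t=250, Q=8)
print(len(emb))  # 8
```

The largest d gives the smallest frequency, w^(-1) = 1/10000, which is why w controls the minimum frequency of the embedding.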

3.1.4.2. The input volume

The 3D sub-volume that the network takes as input contains both the conditional slices and the target slices that went through the forward diffusion process. The slices are sorted by index, and the entire sub-volume is passed through the network, which predicts the noise added to the target slices. Since the network sees at most 20 slices at a time (per the limit above), this process is less computationally expensive than processing a full 128-slice volume.
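Assembling that input can be sketched as merging the two sets and sorting by slice index. This is an illustration under stated assumptions: assemble_input is a hypothetical name, slices are stored in dicts keyed by index, and the returned is_target mask (marking which positions hold noisy target slices) is an addition for clarity rather than something the post describes.

```python
def assemble_input(cond_slices, target_slices):
    """Merge clean conditional slices and noisy target slices into one sub-volume.

    cond_slices / target_slices: dicts mapping slice index -> 2D slice.
    Returns the slice indexes in ascending order, the stacked slices in that
    order, and a mask marking which positions hold (noisy) target slices.
    """
    overlap = set(cond_slices) & set(target_slices)
    if overlap:
        raise ValueError(f"slice indexes in both sets: {sorted(overlap)}")
    merged = {**cond_slices, **target_slices}
    order = sorted(merged)
    volume = [merged[i] for i in order]
    is_target = [i in target_slices for i in order]
    return order, volume, is_target


cond = {2: [[0.0, 0.0]]}                             # clean conditional slice 2
noisy_targets = {7: [[0.9, -0.4]], 4: [[0.3, 1.2]]}  # noisy target slices
order, volume, is_target = assemble_input(cond, noisy_targets)
print(order)      # [2, 4, 7]
print(is_target)  # [False, True, True]
```

Keeping the slices in index order preserves their spatial arrangement within the sub-volume.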

3.1.4.3. Slice Indexes

We also pass the indexes of the conditional and target slices to the attention module as positional encodings, which it uses when computing attention weights. By training on different combinations of conditional and target slices, the model learns the spatial distances between slices and captures both short- and long-range dependencies, depending on how close together or far apart their indexes are.
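The post does not spell out how the indexes enter the attention module. As a hedged illustration, the sketch below assumes they are turned into sinusoidal encodings of the same form as the timestep embedding, and looks only at the positional part of single-head attention scores (no image features, no learned projections). index_encoding and positional_attention are hypothetical names. The property it shows is why such encodings help: the dot product of two encodings depends only on the distance between the slice indexes.

```python
import math


def index_encoding(index, dim=16, w=10000.0):
    """Assumed sinusoidal encoding of a slice index (same form as the timestep embedding)."""
    freqs = [w ** (-2.0 * d / dim) for d in range(1, dim // 2 + 1)]
    return [math.cos(index * f) for f in freqs] + [math.sin(index * f) for f in freqs]


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def positional_attention(query_index, key_indexes, dim=16):
    """Softmax attention weights computed from the positional term alone.

    A real attention block would also use learned projections of image
    features; they are omitted here to isolate the effect of slice indexes.
    """
    q = index_encoding(query_index, dim)
    scores = [dot(q, index_encoding(k, dim)) / math.sqrt(dim) for k in key_indexes]
    top = max(scores)
    exps = [math.exp(s - top) for s in scores]  # subtract max for numerical stability
    total = sum(exps)
    return [e / total for e in exps]


# The encoding dot product is a sum of cos(f_d * (i - j)), so it depends only on i - j.
print(math.isclose(dot(index_encoding(10), index_encoding(13)),
                   dot(index_encoding(50), index_encoding(53))))  # True
print([round(w, 3) for w in positional_attention(10, [10, 11, 60])])
```

In this toy setting, a query at slice 10 gives more weight to slice 11 than to slice 60; in the trained model, learned features decide how that positional information is used.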

Putting it all together, the loss function the model minimizes is the squared L2 norm of the difference between the predicted noise and the actual noise added to the target slices:

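Loss(\theta) := \mathbb{E}_{X_0 \sim q,\ \epsilon \sim \mathcal{N}(0, I),\ |C| + |P| \le \tau_{max},\ t} \left[ \lVert \epsilon - \epsilon_\theta(X^P_t, X^C, C, P, t) \rVert_2^2 \right]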

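This loss is the same as that of a standard diffusion model, except for the added inputs: the clean conditional slices X^C and the slice indexes C and P. The expectation also ranges over the choice of C and P, with their total size capped at τ_max (20 slices in this experiment).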
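In practice, the expectation is estimated by averaging the bracketed term over random draws of X_0, ε, C, P and t. The sketch below computes one such draw. It is a minimal illustration: toy_model is a hypothetical stand-in for the U-Net ε_θ that always predicts zero noise, slices are flattened into plain lists, and alpha_bar (the cumulative noise-schedule value for step t) is passed in directly rather than taken from a specific schedule.

```python
import math
import random


def noise_prediction_loss(eps, eps_pred):
    """Squared L2 norm ||eps - eps_pred||_2^2, summed over all target voxels."""
    return sum((e - p) ** 2 for e, p in zip(eps, eps_pred))


def toy_model(xt_target, x_cond, cond_idx, target_idx, t):
    """Hypothetical stand-in for eps_theta (the real model is the U-Net): predicts zero noise."""
    return [0.0] * len(xt_target)


def training_loss(x0_target, x_cond, cond_idx, target_idx, t, alpha_bar, model, rng):
    """One Monte Carlo sample of the loss for a single draw of (X_0, eps, C, P, t).

    Only the target voxels are noised; the conditional slices x_cond stay clean.
    """
    eps = [rng.gauss(0.0, 1.0) for _ in x0_target]
    xt = [math.sqrt(alpha_bar) * x + math.sqrt(1.0 - alpha_bar) * e
          for x, e in zip(x0_target, eps)]
    eps_pred = model(xt, x_cond, cond_idx, target_idx, t)
    return noise_prediction_loss(eps, eps_pred)


loss = training_loss(x0_target=[0.2] * 4, x_cond=[0.7] * 2, cond_idx=[2],
                     target_idx=[4, 7], t=500, alpha_bar=0.5,
                     model=toy_model, rng=random.Random(3))
print(loss)
```

With a zero-predicting model the loss is just the squared norm of the sampled noise; training pushes ε_θ toward the true ε so this value shrinks.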

3.2. Generation Stage

Once the network is trained, we can run the reverse diffusion process to generate new synthetic MRIs. This process is iterative and is divided into N stages, where N = ceil(128 / len(P)). For example, if we feed the model 10 target slices per stage, it takes ceil(12.8) = 13 stages to generate the 128 slices that make up one synthetic 3D MRI.
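The stage count is simple arithmetic, sketched below together with one possible split of the 128 indexes into per-stage target chunks. Generating the chunks consecutively from index 0 upward is an assumption for illustration (the post does not fix the order), and num_stages and stage_targets are hypothetical names.

```python
import math


def num_stages(total_slices=128, targets_per_stage=10):
    """N = ceil(total_slices / len(P))."""
    return math.ceil(total_slices / targets_per_stage)


def stage_targets(total_slices=128, targets_per_stage=10):
    """Consecutive target-index chunks, one per stage (the last may be shorter)."""
    return [list(range(start, min(start + targets_per_stage, total_slices)))
            for start in range(0, total_slices, targets_per_stage)]


print(num_stages(128, 10))       # 13
print(num_stages(128, 20))       # 7
print(len(stage_targets()[-1]))  # 8
```

With 10 targets per stage, the first 12 stages cover 120 slices and the 13th stage fills in the remaining 8.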

In stage 1, we pass the model no conditional slices: only pure noise as the target slices, together with the target slice indexes. Starting unconditionally ensures that the outputs are not limited in diversity by a fixed conditional starting point. Once the model has denoised these slices, we use the resulting synthetic slices as the conditional slices in the next stage and pass in len(P) new noisy target slices. The indexes of the conditional slices are set to the indexes of the target slices from the previous stage, and the target slice indexes are set to the next set of indexes we want to generate. Because each stage takes the previously generated slices into account, the model produces a more coherent, high-quality 3D MRI volume.
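The stage-by-stage loop can be sketched as follows. This is a simplified illustration, not the authors' implementation: generate_volume is a hypothetical name, denoise is a hypothetical callback standing in for the complete reverse diffusion run of the trained network within one stage, each "slice" is a single scalar for brevity, and target chunks are taken consecutively from index 0.

```python
import random


def generate_volume(denoise, total_slices=128, targets_per_stage=10, seed=0):
    """Generate a full volume stage by stage, conditioning on the previous stage.

    denoise(noisy_targets, target_idx, cond_slices, cond_idx) is a hypothetical
    stand-in for the full reverse diffusion run of the trained network; it must
    return one denoised slice per target index.
    """
    rng = random.Random(seed)
    generated = {}
    cond_idx, cond_slices = [], []  # stage 1 is unconditional: C is empty
    for start in range(0, total_slices, targets_per_stage):
        target_idx = list(range(start, min(start + targets_per_stage, total_slices)))
        # Pure Gaussian noise as the starting target "slices" (one scalar each, for brevity).
        noisy = [rng.gauss(0.0, 1.0) for _ in target_idx]
        new_slices = denoise(noisy, target_idx, cond_slices, cond_idx)
        generated.update(zip(target_idx, new_slices))
        # The slices generated in this stage condition the next stage.
        cond_idx, cond_slices = target_idx, new_slices
    return [generated[i] for i in range(total_slices)]


calls = []


def toy_denoise(noisy, target_idx, cond_slices, cond_idx):
    """Toy denoiser: records what it was given and returns each slice's index."""
    calls.append((list(cond_idx), list(target_idx)))
    return [float(i) for i in target_idx]


volume = generate_volume(toy_denoise)
print(len(volume), len(calls))  # 128 13
print(calls[0][0])              # []
print(calls[1][0][:3])          # [0, 1, 2]
```

The recorded calls show the two properties described above: stage 1 receives no conditional slices, and every later stage is conditioned on the slices produced by the stage before it.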


4. Results

4.1. Qualitative Results

Looking at the synthetic MRIs produced by the cDPM in Fig. 11, we can see that the scans look realistic, with no blurring and clear grey matter boundaries. The 5 samples also show sufficient diversity, with no 2 samples sharing the same shape and structure. When comparing a synthetic MRI from the cDPM with those produced by GAN-based methods in Fig. 10, the cDPM MRI looks far sharper, with less blurring and very clear boundaries, delivering on one of the initial promises of generating higher-quality MRIs than GAN-based methods. The synthetic MRI produced by the 3D-DPM is sharper in one of the 3 views, but it doesn't produce a valid result in the other 2 views. According to the researchers, this is because, given the same computational resources, the 3D-DPM could not generate a 128-slice MRI in one go, as that was far too computationally expensive; instead, it could only produce a 32-slice MRI, which lacks the information needed to reproduce a full 3D MRI. The cDPM thus also delivers on another initial promise: being more computationally efficient than a standard 3D diffusion model.

4.2. Quantitative Results


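| Method | MS-SSIM (closest to 79.2) | MMD ↓ | FID-A ↓ | FID-C ↓ | FID-S ↓ |
|---|---|---|---|---|---|
| 3D-VAE-GAN | 88.3 | 5.15 | 320 | 247 | 398 |
| 3D-GAN-GP | 81.0 | 15.7 | 141 | 127 | 281 |
| 3D-α-WGAN | 82.6 | 13.2 | 121 | 116 | 193 |
| CCE-GAN | 81.5 | 3.54 | 69.4 | 869 | 191 |
| HA-GAN | 36.8 | 226 | 477 | 1090 | 554 |
| 3D-DPM | 79.7 | 15.2 | 188 | - | - |
| cDPM | 78.6 | 3.14 | 32.4 | 45.8 | 91.1 |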

We generate 200 synthetic MRIs with each method and compute the following 3 metrics:

  1. Multi-scale structural similarity (MS-SSIM): a measure of diversity within the 200 samples, where 0 means perfectly dissimilar and 100 means perfectly similar. The baseline is the diversity observed among the real MRIs, which score 79.2, so the closer a method gets to 79.2, the better it matches real-world diversity.
  2. Maximum mean discrepancy (MMD): a measure of how closely the distribution of the synthetic MRIs matches that of the real MRIs (lower is better).
  3. Fréchet Inception Distance (FID): a measure of similarity between features extracted from the real MRIs and from the synthetic MRIs (lower is better). We compute this separately for each of the 3 views (axial, coronal and sagittal: FID-A, FID-C and FID-S).
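To build intuition for the two distribution metrics, here is a small sketch on toy data. The paper computes them on features extracted from the MRIs; the feature extractor, the kernel and its bandwidth are not given in this post, so the code assumes a Gaussian kernel with sigma = 1 and uses tiny made-up feature vectors. FID itself fits multivariate Gaussians to network features; frechet_1d only shows the one-dimensional case of the same Fréchet distance formula. All function names are hypothetical, and MS-SSIM is not sketched here.

```python
import math


def rbf_kernel(x, y, sigma=1.0):
    """Gaussian (RBF) kernel between two feature vectors."""
    d2 = sum((a - b) ** 2 for a, b in zip(x, y))
    return math.exp(-d2 / (2.0 * sigma ** 2))


def mmd2(xs, ys, sigma=1.0):
    """Biased estimate of the squared MMD between two samples of feature vectors."""
    kxx = sum(rbf_kernel(a, b, sigma) for a in xs for b in xs) / len(xs) ** 2
    kyy = sum(rbf_kernel(a, b, sigma) for a in ys for b in ys) / len(ys) ** 2
    kxy = sum(rbf_kernel(a, b, sigma) for a in xs for b in ys) / (len(xs) * len(ys))
    return kxx + kyy - 2.0 * kxy


def frechet_1d(xs, ys):
    """Frechet distance between 1D Gaussians fitted to two samples:
    (mu_x - mu_y)^2 + (sigma_x - sigma_y)^2, the 1D case of the FID formula."""
    def mean_std(v):
        m = sum(v) / len(v)
        return m, math.sqrt(sum((a - m) ** 2 for a in v) / len(v))
    mx, sx = mean_std(xs)
    my, sy = mean_std(ys)
    return (mx - my) ** 2 + (sx - sy) ** 2


real = [[0.0, 0.1], [0.2, 0.0], [0.1, 0.3]]   # toy "real" feature vectors
close = [[0.1, 0.1], [0.0, 0.2], [0.2, 0.2]]  # toy synthetic set near the real one
far = [[2.0, 2.1], [2.2, 1.9], [2.1, 2.3]]    # toy synthetic set far from it
print(mmd2(real, close) < mmd2(real, far))    # True
print(frechet_1d([0.0, 0.0], [1.0, 1.0]))     # 1.0
```

Both quantities are zero when the two samples are identical and grow as the synthetic distribution drifts away from the real one, which is why lower is better for MMD and FID.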

The results lead to the following key insights:

  • The diffusion models have the closest MS-SSIM to the real samples with the 3D-DPM just barely edging out the cDPM.
  • The cDPM produces synthetic MRIs that most closely match the distribution and features of the real MRIs, per the MMD and FID scores.
  • Among GAN-based methods, the CCE-GAN produces synthetic MRIs that most closely match the distribution of real MRIs, on account of it being a conditional model.
  • The 3D-DPM has no FID-C or FID-S scores, as it only produces a 32-slice MRI and is missing the slices needed for these views.
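The first two insights can be checked directly against the table with a few lines of Python. This is a quick sanity check, not part of the paper's evaluation: the dictionaries simply copy the MS-SSIM and MMD columns of the results table, rank methods by their distance from the real-data MS-SSIM of 79.2, and pick the lowest MMD.

```python
# Values copied from the results table.
REAL_MS_SSIM = 79.2
ms_ssim = {"3D-VAE-GAN": 88.3, "3D-GAN-GP": 81.0, "3D-α-WGAN": 82.6,
           "CCE-GAN": 81.5, "HA-GAN": 36.8, "3D-DPM": 79.7, "cDPM": 78.6}
mmd = {"3D-VAE-GAN": 5.15, "3D-GAN-GP": 15.7, "3D-α-WGAN": 13.2,
       "CCE-GAN": 3.54, "HA-GAN": 226, "3D-DPM": 15.2, "cDPM": 3.14}

# Distance of each method's MS-SSIM from the real-data value (smaller is better).
gaps = {m: round(abs(v - REAL_MS_SSIM), 1) for m, v in ms_ssim.items()}
ranking = sorted(gaps, key=gaps.get)
print(ranking[:2], gaps[ranking[0]], gaps[ranking[1]])  # ['3D-DPM', 'cDPM'] 0.5 0.6
print(min(mmd, key=mmd.get))                            # cDPM
```

The 3D-DPM is 0.5 away from the real-data MS-SSIM and the cDPM 0.6 away, matching the "barely edging out" observation, while the cDPM has the lowest MMD overall.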

5. Conclusions

Looking back at our original promises, we can see that the cDPM delivers on all of them. It produces higher-quality 3D MRIs than GAN-based methods, with a distribution that closely follows that of the real MRI samples, and it does so while being more computationally efficient than a standard 3D-DPM, thanks to its ability to train on smaller sub-volumes while still producing a full 128-slice MRI through its iterative generation.

6. References

[1] Khaliki, M.Z., Başarslan, M.S. Brain tumor detection from images and comparison with transfer learning methods and 3-layer CNN. Sci Rep 14, 2664 (2024). https://doi.org/10.1038/s41598-024-52823-9

[2] https://www.iguazio.com/glossary/diffusion-models/

[3] Zhang H, Shinomiya Y, Yoshida S. 3D MRI Reconstruction Based on 2D Generative Adversarial Network Super-Resolution. Sensors. 2021; 21(9):2978. https://doi.org/10.3390/s21092978

[4] Peng, W., Adeli, E., Zhao, Q., & Pohl, K. M. (2022, December 15). Generating realistic 3D brain MRIs using a conditional diffusion probabilistic model. arXiv.org. https://arxiv.org/abs/2212.08034v1

7. ChatGPT prompts used

  1. Explain why we need synthetic MRI's in the first place
  2. What are some common use cases of machine learning with MRI's
  3. What are conditional models
  4. What is the MS-SSIM
  5. What is the difference between the MMD and FID