Abstract: This blog is about representation learning with generative models. It consists of four sections. First, I give a brief introduction to representation learning and an overview of classical generative models. Then, in the model part, I present three generative models that learn meaningful representations: EBM-GFlowNets, Flamingo VLM and ChatGPT, and go into more detail on their motivations, methodologies and results. Next, I compare the three models and discuss how they could be applied in the medical field. Finally, a brief review summarizes the blog.

Author: Mei Sun 


1. Introduction

Nowadays, a good model usually depends on a good representation of its data. However, when dealing with a large amount of raw data, manually designing features or labeling the data can be prohibitively expensive, and different tasks usually need different features of the same dataset. What if a model could learn a meaningful and compact representation for downstream tasks on its own? Generative models offer one way to achieve this. In the following two subsections, I first introduce the concepts of representation learning and generative models; Section 2 then covers the three selected models in more detail.

1.1 Representation Learning


Representation learning, also called feature learning, is a subfield of machine learning that focuses on discovering meaningful representations of raw data. Concretely, a good representation makes it easier to extract useful information when building classifiers or other predictors [1]. Ideally, the model discovers the underlying structure of the data automatically. The goal of representation learning is therefore to find a mapping from the raw input data to a feature space (often lower-dimensional than the input) in which the data is easier to model and can be reused for different downstream tasks. This is typically done by training a neural network whose learned parameters capture the latent structure of the data. This blog focuses on using generative models to learn meaningful and compact representations.

1.2 Generative models


A generative model is a classic type of machine learning model that aims to generate new samples similar to the training data, as Figure 1 shows.

                                                 

Figure 1: Generative model

By learning the distribution of the training data, the model captures the underlying probability distribution of the data, which allows it to randomly generate new samples that are as close as possible to the training data. There are many different types of generative models.

Usually, a generative model P_θ is trained by maximizing the likelihood P_θ(x) of the training data with respect to its parameters θ. Based on how they model the true data density, the common classical generative models can be roughly classified as shown in Figure 2.



Figure 2: Generative Models classification
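As a minimal sketch of maximum-likelihood training, the code below fits the simplest possible generative model, a one-dimensional Gaussian P_θ with θ = (μ, σ), to toy data and then samples new points from it. The Gaussian family, the toy data and all function names are assumptions chosen for illustration; the models in this blog use far richer density families, but the recipe (fit θ by maximizing the likelihood, then sample) is the same.

```python
import math
import random


def fit_gaussian_mle(data):
    """Closed-form maximum-likelihood estimate of (mu, sigma) for a 1D Gaussian."""
    n = len(data)
    mu = sum(data) / n
    var = sum((x - mu) ** 2 for x in data) / n  # the MLE divides by n, not n - 1
    return mu, math.sqrt(var)


def log_likelihood(data, mu, sigma):
    """Average log P_theta(x) over the data, with theta = (mu, sigma)."""
    return sum(
        -0.5 * math.log(2 * math.pi * sigma ** 2) - (x - mu) ** 2 / (2 * sigma ** 2)
        for x in data
    ) / len(data)


def sample(mu, sigma, k, rng):
    """Generate k new samples from the fitted model."""
    return [rng.gauss(mu, sigma) for _ in range(k)]


rng = random.Random(0)
train = [rng.gauss(2.0, 0.5) for _ in range(5000)]  # toy "true" data
mu, sigma = fit_gaussian_mle(train)
new_points = sample(mu, sigma, 10, rng)
print(f"mu={mu:.2f}, sigma={sigma:.2f}, avg log-likelihood={log_likelihood(train, mu, sigma):.3f}")
```

Any other choice of (μ, σ) gives a lower average log-likelihood on the training data, which is exactly what "maximizing the likelihood" means.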

2. Models

This blog covers three generative models: EBM-GFlowNets, Flamingo VLM and ChatGPT. All three learn representations of data largely in an unsupervised (self-supervised) manner, which means that they mostly do not require labeled data to learn useful features. In addition, generative models like these can be used for a variety of tasks, such as image synthesis, text generation and anomaly detection. The three models are discussed in more detail below.

2.1 EBM-GFlowNets


EBM-GFlowNets (EB-GFN) [3] is a novel probabilistic modeling algorithm for high-dimensional discrete data. The authors jointly train a GFlowNet with an energy function, so that the GFlowNet learns to sample from the energy distribution while the energy function is learned with approximate maximum likelihood estimation (MLE), using negative samples drawn from the GFlowNet [3].

2.1.1 Problem Statement 


The energy-based model (EBM) builds on the Gibbs distribution from statistical physics. With an energy function parameterized by a neural network, it can be written as follows:

P_\phi(x) = \frac{1}{Z} \exp(-E_\phi(x)), \quad \text{where} \quad Z = \int \exp(-E_\phi(x))\,dx \ \text{(continuous } x\text{)} \quad \text{or} \quad Z = \sum_{x} \exp(-E_\phi(x)) \ \text{(discrete } x\text{)}
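To make the role of Z concrete, the sketch below computes P_ϕ exactly for a tiny binary EBM by enumerating all 2^D states. The linear energy E_ϕ(x) = −ϕ·x is a hypothetical choice for illustration, not the neural energy used in [3]. Enumeration works only because D is tiny: the sum behind Z has 2^D terms, which is exactly what becomes intractable for high-dimensional data.

```python
import itertools
import math


def energy(x, phi):
    """Hypothetical linear energy E_phi(x) = -sum_j phi_j * x_j."""
    return -sum(p * xj for p, xj in zip(phi, x))


def exact_distribution(phi):
    """Return {x: P_phi(x)} by summing exp(-E) over all 2^D binary states to get Z."""
    states = itertools.product([0, 1], repeat=len(phi))
    unnormalized = {x: math.exp(-energy(x, phi)) for x in states}
    Z = sum(unnormalized.values())  # 2^D terms: infeasible for large D
    return {x: u / Z for x, u in unnormalized.items()}


probs = exact_distribution([1.0, -0.5, 2.0])
print(max(probs, key=probs.get))  # (1, 0, 1), the lowest-energy state
```

Lower energy means higher probability, and the probabilities only become meaningful after dividing by Z.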

E_ϕ is called the energy function; it is usually a neural network parameterized by ϕ. Because of the negative sign in front of E_ϕ, inputs x with low probability have high energy, and inputs x with high probability have low energy. Training aims to maximize the probability P_ϕ(x_train) of the training data. One difference from normalized models is that training not only increases the probability of x_train but also decreases the probability of x_sample, which is sampled from the current model. Figure 3 illustrates this objective: we pull up the probability of the data points in the training set while pushing down the probability of the sampled points. These two forces are in balance if and only if P_ϕ(x) = P_data(x). The negative samples x_sample must be taken into account because the model is not normalized: increasing the unnormalized probability exp(−E_ϕ(x_train)) by changing ϕ does not guarantee that x_train becomes relatively more likely than other points.

Figure 3: Training with energy function

As in ordinary training, the energy function is trained with maximum likelihood estimation (MLE). Writing ℒ(ϕ) for the average log-likelihood of the training data, the parameters ϕ are updated with the gradient of the negative log-likelihood, −∇_ϕ ℒ(ϕ), given by the following equation:

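-\nabla_\phi \mathcal{L}(\phi) = \frac{1}{n} \sum_{i=1}^{n} \nabla_\phi E_\phi(x_{\text{train},i}) - \mathbb{E}_{x_{\text{sample}} \sim P_\phi}\left[\nabla_\phi E_\phi(x_{\text{sample}})\right] \qquad [2]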

The second term is an expectation over negative samples x_sample drawn from the EBM distribution P_ϕ itself. Computing it exactly requires the normalization constant Z, which sums (or integrates) exp(−E_ϕ(x)) over every possible x; for high-dimensional data this is intractable. Instead of evaluating Z, we can approximate the expectation with Markov chain Monte Carlo (MCMC): MCMC draws negative samples from the EBM using only ratios of unnormalized probabilities, in which Z cancels. The gradient of the loss then becomes:

-\nabla_\phi \mathcal{L}(\phi) \approx \frac{1}{n} \sum_{i=1}^{n} \nabla_\phi E_\phi(x_{\text{train},i}) - \frac{1}{n} \sum_{i=1}^{n} \nabla_\phi E_\phi(x_{\text{sample},i})
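The estimator above can be sketched for a toy model: binary vectors with a hypothetical linear energy E_ϕ(x) = −ϕ·x, whose gradient with respect to ϕ is simply −x. Negative samples come from persistent Gibbs-sampling chains, a simple MCMC method; each Gibbs update uses only an energy difference, so Z is never computed. The toy data, learning rate and chain settings are assumptions for illustration, not the setup of [2] or [3].

```python
import math
import random


def energy(x, phi):
    """Hypothetical linear energy E_phi(x) = -phi . x over binary vectors."""
    return -sum(p * xj for p, xj in zip(phi, x))


def grad_energy(x):
    """Gradient of the linear energy with respect to phi, which is simply -x."""
    return [-float(xj) for xj in x]


def mean_vector(vectors):
    """Element-wise mean of a list of equal-length vectors."""
    return [sum(col) / len(vectors) for col in zip(*vectors)]


def gibbs_sweep(x, phi, rng):
    """Resample every bit once from its conditional; only energy differences are used."""
    x = list(x)
    for j in range(len(x)):
        x1, x0 = x[:], x[:]
        x1[j], x0[j] = 1, 0
        # P(x_j = 1 | rest) = 1 / (1 + exp(E(x1) - E(x0))); Z cancels out
        p1 = 1.0 / (1.0 + math.exp(energy(x1, phi) - energy(x0, phi)))
        x[j] = 1 if rng.random() < p1 else 0
    return x


def nll_gradient(train, phi, chains, rng):
    """-grad L(phi) ~ mean grad E(x_train) - mean grad E(x_sample)."""
    for c in range(len(chains)):  # advance the persistent MCMC chains
        chains[c] = gibbs_sweep(chains[c], phi, rng)
    positive = mean_vector([grad_energy(x) for x in train])
    negative = mean_vector([grad_energy(x) for x in chains])
    return [p - n for p, n in zip(positive, negative)]


rng = random.Random(0)
# Toy data: bit 0 is always on, bit 1 always off, bit 2 is a fair coin.
train = [[1, 0, rng.randint(0, 1)] for _ in range(200)]
phi = [0.0, 0.0, 0.0]
chains = [[rng.randint(0, 1) for _ in range(3)] for _ in range(100)]
for _ in range(200):
    g = nll_gradient(train, phi, chains, rng)
    phi = [p - 0.1 * gk for p, gk in zip(phi, g)]  # gradient descent on the NLL
print([round(p, 2) for p in phi])
```

Training lowers the energy of bit patterns seen in the data (ϕ_0 becomes large and positive, ϕ_1 large and negative), while ϕ_2 stays near zero because bit 2 is a fair coin.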

However, when MCMC is used to sample from a high-dimensional distribution, the samples often have low diversity and are highly correlated, because the chains mix poorly between modes [3]. This can be explained with the energy picture: high-probability regions are low-energy valleys, and moving from one valley to another requires passing through high-energy (low-probability) regions, which MCMC proposals rarely accept. The sampler therefore tends to stay around the peak it started near instead of exploring other areas of the distribution, so the samples concentrate around a few peaks, with low diversity and high correlation. The authors of [3] tackle this problem by using generative flow networks (GFlowNets) as an alternative sampler to MCMC.
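The mixing problem can be reproduced in one dimension. The sketch below runs random-walk Metropolis, a basic MCMC method, on a hypothetical double-well energy with modes at x = −1 and x = +1 separated by a high-energy barrier at x = 0. The energy is symmetric, so the true distribution puts half of its mass at x < 0, yet a chain started in the right-hand well essentially never crosses the barrier. The barrier height and step size are assumptions chosen to make the effect obvious; the high-dimensional case in [3] is harder to visualize but suffers from the same issue.

```python
import math
import random


def energy(x, barrier=50.0):
    """Double-well energy with minima at x = -1 and x = +1 and a barrier at x = 0."""
    return barrier * (x * x - 1.0) ** 2


def metropolis(x0, steps, step_size, rng):
    """Random-walk Metropolis sampling from P(x) proportional to exp(-E(x))."""
    x, samples = x0, []
    for _ in range(steps):
        proposal = x + rng.gauss(0.0, step_size)
        # Accept with probability min(1, exp(E(x) - E(proposal))); Z cancels out
        if rng.random() < math.exp(min(0.0, energy(x) - energy(proposal))):
            x = proposal
        samples.append(x)
    return samples


rng = random.Random(0)
samples = metropolis(x0=1.0, steps=10_000, step_size=0.1, rng=rng)
frac_left = sum(s < 0 for s in samples) / len(samples)
print(frac_left)  # the true distribution puts half of its mass at x < 0
```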

2.1.2 Methodology


GFlowNets learn a stochastic sampling policy in a reinforcement-learning-like fashion: the policy generates each sample through a sequence of actions, and the sequence of states that produces an object is called a trajectory. Unlike conventional RL, which seeks the x that maximizes the reward R(x), a GFlowNet learns to sample x with probability proportional to the reward, P(x) ∝ R(x), so it generates diverse, high-reward samples instead of only the single best one. Figure 4 below illustrates how a GFlowNet samples data in a 9-dimensional data space.

Figure 4: Sampling by GFlowNets in a 9-dimensional data space.

First, the reward function is set to R(x) = exp(−E_ϕ(x)), so that sampling proportionally to R(x) means sampling from the EBM. For D-dimensional binary data x ∈ {0, 1}^D, the initial state s_0 has all entries void. Each action chooses a void entry (a grey pixel) and paints it black (1) or white (0), and generation stops after D steps, i.e. when the vector is complete. The GFlowNet is trained with the trajectory balance (TB) objective:

\mathcal{L}(\theta) = \left( \log \frac{Z_\theta \prod_{i=0}^{D-1} P_F(s_{i+1} \mid s_i; \theta)}{R(s_D) \prod_{i=0}^{D-1} P_B(s_i \mid s_{i+1}; \theta)} \right)^2

Here the sampling policy is the forward transition model P_F(s_{i+1} | s_i; θ); the GFlowNet additionally learns an auxiliary backward model P_B(s_i | s_{i+1}; θ) and the normalizing constant Z in the numerator of the loss. The parameters θ are optimized on trajectories s_0 → s_1 → ... → s_D sampled from the policy P_F. If the TB loss is minimized to zero for all trajectories, the marginal probability of sampling x is proportional to the reward R(x), and the learned Z equals the total reward Σ_x R(x). Figure 4 presents the TB training theorem for GFlowNets.


Figure 4: Trajectory Balance training theorem
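The construction and the TB loss can be sketched together. The code below samples a trajectory for D = 9 binary entries with a hypothetical untrained policy: P_F picks one of the void entries uniformly and paints it 0 or 1 with equal probability, and P_B un-paints one of the filled entries uniformly. A real EB-GFN would output these probabilities from a neural network and learn log Z together with them. With a constant reward R(x) = 1 (another assumption), the uniform P_F already samples x ∝ R(x), and the TB loss vanishes exactly when Z equals the total reward 2^D, in line with the theorem above.

```python
import math
import random

VOID = None  # an unpainted entry (a grey pixel)


def sample_trajectory(D, rng):
    """Sample s_0 -> ... -> s_D under uniform P_F; also return log P_F and log P_B per step."""
    state = [VOID] * D  # s_0: every entry is void
    trajectory, log_pf, log_pb = [tuple(state)], [], []
    for i in range(D):
        void_positions = [j for j, v in enumerate(state) if v is VOID]
        j = rng.choice(void_positions)
        state[j] = rng.randint(0, 1)  # paint white (0) or black (1)
        trajectory.append(tuple(state))
        # P_F: choose 1 of the (D - i) void entries, then 1 of 2 colours
        log_pf.append(-math.log(2 * (D - i)))
        # P_B: s_{i+1} has i + 1 filled entries; un-paint one of them uniformly
        log_pb.append(-math.log(i + 1))
    return trajectory, log_pf, log_pb


def tb_loss(log_Z, log_pf, log_pb, log_reward):
    """Trajectory balance: (log Z + sum log P_F - log R(s_D) - sum log P_B)^2."""
    return (log_Z + sum(log_pf) - log_reward - sum(log_pb)) ** 2


D = 9
traj, log_pf, log_pb = sample_trajectory(D, random.Random(0))
x = traj[-1]  # complete sample in {0, 1}^9
print(tb_loss(D * math.log(2), log_pf, log_pb, log_reward=0.0))  # ~0: log Z = log 2^D
print(tb_loss(0.0, log_pf, log_pb, log_reward=0.0))  # > 0: Z is too small
```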

2.1.3 Result


The authors validate EB-GFN on the Ising model, an elementary example of a Markov random field, and on synthetic tasks in which seven different distributions are modeled over 32-dimensional binary data. Figure 4 below visualizes the learned energy functions: the top left shows the true data we want to model, the middle part shows the results of three different baselines, and the bottom part shows the result of using a GFlowNet as the sampling strategy.

                                       

Figure 4: Visualization result of 2D synthetic data with different methods

Table 1 reports the quantitative evaluation of all methods with two metrics: negative log-likelihood (NLL) and maximum mean discrepancy (MMD). Both the table and the visualization above show that EB-GFN outperforms the baselines on all datasets and metrics.

Table 1: Experiment result of 2D synthetic data with different methods
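Of the two metrics, MMD may be less familiar. It compares two sets of samples through a kernel k: MMD² = E[k(x, x')] + E[k(y, y')] − 2E[k(x, y)], which is zero when both sets come from the same distribution (for a characteristic kernel). The sketch below computes the simple biased estimate for binary vectors; the exponential Hamming kernel and the toy data are assumptions for illustration, not necessarily the exact evaluation setup of [3].

```python
import math


def hamming_kernel(x, y):
    """Assumed kernel: k(x, y) = exp(-Hamming(x, y) / D) for binary vectors of length D."""
    d = sum(a != b for a, b in zip(x, y))
    return math.exp(-d / len(x))


def mmd2(X, Y, k=hamming_kernel):
    """Biased estimate of the squared MMD between sample sets X and Y."""
    kxx = sum(k(a, b) for a in X for b in X) / len(X) ** 2
    kyy = sum(k(a, b) for a in Y for b in Y) / len(Y) ** 2
    kxy = sum(k(a, b) for a in X for b in Y) / (len(X) * len(Y))
    return kxx + kyy - 2 * kxy


real = [(0, 0, 1, 1), (0, 1, 1, 1), (0, 0, 1, 1)]
close = [(0, 0, 1, 1), (0, 1, 1, 1), (0, 0, 1, 0)]
far = [(1, 1, 0, 0), (1, 0, 0, 0), (1, 1, 0, 0)]
print(mmd2(real, close) < mmd2(real, far))  # True: lower MMD means closer samples
```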

In addition, the authors evaluate the models on some common image datasets, generating images in discrete high-dimensional spaces (Table 2). Figure 5 visualizes samples on Dynamic MNIST. Notice that EB-GFN models the details of certain modes better; compare, e.g., the bottom-right corners of the GWG and Gibbs samples.

Figure 5: Experiment result of discrete image modeling

Table 2: Experiment result of discrete image modeling. 

2.2 Flamingo VLM


We experience the world in a multimodal way: we see objects, hear sounds, feel textures, smell odors, and so on [11]. For AI to understand the world around us, it needs to be able to interpret such multimodal signals together. The ability to quickly learn new tasks from brief instructions is a crucial aspect of intelligence, but in computer vision the prevalent approach is to pre-train a model on a vast amount of data and then fine-tune it for the desired task. Fine-tuning requires a substantial amount of annotated data, typically thousands of data points, to be effective. To address this, DeepMind introduced Flamingo, a visual language model (VLM) for few-shot learning [4].

Flamingo is a visually conditioned autoregressive text generation model able to ingest a sequence of text tokens interleaved with images and/or videos, and produce text as output. Figure 6 displays one way we interact with Flamingo.

Figure 6: Examples of inputs and outputs obtained from Flamingo

Flamingo handles image and video understanding tasks through simple example prompts. Besides visual question answering, it also performs image captioning, visual dialogue and classification. As Figure 6 shows, we first input images with their corresponding text as prompts, and the model treats these first two inputs as examples to learn from. We then input a new image with a short text prompt, and Flamingo outputs the answer.

2.2.1 Problem Statement 


When processing a sequence of text interleaved with images and/or videos, there are three challenges for Flamingo to address.

1. Supporting both images and videos.

Images are high-dimensional 2D structures, and videos add a time dimension on top, so how can such visual data be added into a 1D token sequence? The proposed solution is a Perceiver Resampler module that produces a fixed number of visual tokens.

2. The interaction between image/video and text.

A language model (LM) trained only on text does not know how to incorporate inputs from other modalities. When such inputs are introduced into a pretrained LM, how can its language understanding and generation capabilities be kept fully intact? The paper tackles this by interleaving new cross-attention layers between the frozen, pretrained self-attention layers and controlling them with a tanh gating mechanism.

3. Obtaining a multimodal dataset that induces good generalist capabilities.

Data gathered by scraping a large number of web pages contains images and text that are only weakly matched. To compensate for this weak matching, the authors combine such web data with standard paired image/text and video/text datasets, in which the pairs are strongly related.

Figure 7 gives an overview of Flamingo's structure. Its good performance comes mainly from the innovative architectural components and effective pre-training techniques discussed in the following subsection.

Figure 7: Structure of the Flamingo model

2.2.2 Methodology


Challenge 1: Supporting both images and videos.

The Perceiver Resampler module is used to support both images and videos. It takes a variable number of image or video features from the vision encoder as input and outputs a fixed number of visual tokens. These visual tokens are learned embeddings that can encode properties such as the color, shape and position of objects.

Figure 8: The Perceiver Resampler module

As Figure 8 illustrates, for an input of three video frames, the model first uses a vision encoder (a Normalizer-Free ResNet, NFNet, in [4]) to obtain visual features, and adds a learned temporal position embedding to each spatial grid of features corresponding to a given frame. These visual features are then flattened into 1D and concatenated, denoted X_f. The Perceiver Resampler then compresses X_f into R tokens. Its core is the attention mechanism: the model learns a predefined number of latent input queries X, as many as there are output tokens. These latent queries are fed to a transformer stack and cross-attend to the flattened visual features X_f, where the keys and values are computed from the concatenation of X_f and the learned latent tokens X. In this way, the visual inputs are resampled to a small, fixed number of outputs, which significantly reduces the computational complexity of the vision-text cross-attention.
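The resampling step can be sketched as a single cross-attention layer in plain Python. The dimensions, the random weights and the use of one head and one layer without the feed-forward sublayer are simplifying assumptions; this is not DeepMind's implementation. It illustrates the point above: queries come from R learned latents, keys and values come from the concatenation of X_f and the latents, so the output always has R tokens no matter how many visual features go in.

```python
import math
import random


def matmul(A, B):
    """Multiply an (n x k) matrix A by a (k x m) matrix B, both lists of lists."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]


def softmax(scores):
    """Numerically stable softmax of a list of scores."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def resampler_layer(latents, visual_feats, Wq, Wk, Wv):
    """One single-head cross-attention step in the style of the Perceiver Resampler.

    Queries come from the R latents; keys and values come from the
    concatenation [visual_feats; latents].
    """
    kv_input = visual_feats + latents  # (M + R) x d
    Q = matmul(latents, Wq)            # R x d
    K = matmul(kv_input, Wk)           # (M + R) x d
    V = matmul(kv_input, Wv)           # (M + R) x d
    d = len(Q[0])
    out = []
    for q, lat in zip(Q, latents):
        weights = softmax([sum(a * b for a, b in zip(q, k)) / math.sqrt(d) for k in K])
        attended = [sum(w * v[c] for w, v in zip(weights, V)) for c in range(d)]
        out.append([x + a for x, a in zip(lat, attended)])  # residual connection
    return out  # R x d, whatever the number M of visual features


rng = random.Random(0)
d, R = 8, 4  # hypothetical feature width and number of latent queries


def rand_matrix(n, m):
    """Random (n x m) matrix standing in for learned weights or encoder features."""
    return [[rng.gauss(0.0, 0.3) for _ in range(m)] for _ in range(n)]


latents, Wq, Wk, Wv = rand_matrix(R, d), rand_matrix(d, d), rand_matrix(d, d), rand_matrix(d, d)
for M in (3 * 16, 8 * 16):  # e.g. 3 or 8 frames with 16 spatial features each
    out = resampler_layer(latents, rand_matrix(M, d), Wq, Wk, Wv)
    print(M, "->", (len(out), len(out[0])))
```

Because the number of queries is fixed, the cost of the later vision-text cross-attention no longer grows with the number of frames.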


Challenge 2: The interaction between image/video and text.

Once the Perceiver Resampler has produced the fixed set of visual tokens, a gated cross-attention mechanism fuses images and text, as shown in Figure 9.

Figure 9: Gated xattn-dense layers


The authors use a pretrained 70B-parameter Chinchilla model as the language model. These pretrained LM blocks are frozen during Flamingo training, which preserves the knowledge and text generation abilities of the text-only language model. To condition the LM on the visual inputs, gated cross-attention dense blocks are inserted between the frozen LM blocks. A tanh gating mechanism keeps the original language model behavior at initialization, so the inserted layers do not disturb the features learned by the LM: the output of each new layer is multiplied by tanh(α), where α is a learnable scalar initialized to 0 and then learned during training. In this way, the model transitions smoothly from a text-only model to a visual language model during training and gains the ability to handle multiple modalities.
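The effect of the tanh gate can be seen in a few lines. The sketch below applies a gated residual update element-wise to hypothetical vectors; in Flamingo, a gate of this form is applied after both the cross-attention and the dense (feed-forward) sublayer of each inserted block, each with its own α. At α = 0 the output equals the frozen LM's hidden state, so the pretrained model's behavior is untouched at initialization.

```python
import math


def tanh_gated_residual(hidden, xattn_out, alpha):
    """y = hidden + tanh(alpha) * xattn_out, element-wise.

    hidden    -- hidden state coming from the frozen language model
    xattn_out -- output of the new cross-attention block over visual tokens
    alpha     -- learnable scalar gate, initialized to 0
    """
    gate = math.tanh(alpha)
    return [h + gate * c for h, c in zip(hidden, xattn_out)]


hidden = [0.5, -1.2, 3.0]
visual_signal = [10.0, 10.0, 10.0]
print(tanh_gated_residual(hidden, visual_signal, alpha=0.0))  # equals hidden: LM unchanged
print(tanh_gated_residual(hidden, visual_signal, alpha=0.5))  # visual information flows in
```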


A remaining challenge is supporting interleaved visual and text data. Flamingo handles it with per-image/video attention masking, as illustrated in Figure 10.

Figure 10: Per-image/video attention masking for interleaved visual data and text

Given text interleaved with images/videos, the input is first processed by inserting <image> tags at the locations of the visual data in the text, together with special tokens (<BOS> for "beginning of sentence" and <EOC> for "end of chunk"). The images are processed independently by the vision encoder and the Perceiver Resampler to extract visual tokens. Each text token cross-attends only to the visual tokens of the last preceding image. Based on the positions of the visual data, a function ϕ: [1, L] ↦ [0, N] gives, for each text position, the index of the image it should cross-attend to (0 if no visual data appears before that position). Support for multiple visual inputs is implemented in the gated xattn-dense layers with this masking over tokens; in the figure, dark blue entries are not masked and light blue entries are masked. By default, each text token can attend directly to only the one image immediately before it (this restriction improved performance). Its dependency on all previous images is nevertheless kept through the self-attention of the language model, so during the final prediction each token can still use all previous text and images.
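The function ϕ and the resulting mask are easy to compute. In this sketch, tokens are plain strings, and the <image> tag itself is assigned to the image it marks; both are assumptions about details the text leaves open. mask[l][n] says whether position l may cross-attend to image n + 1.

```python
IMAGE = "<image>"


def last_image_index(tokens):
    """phi: for each position, the 1-based index of the last preceding image (0 if none)."""
    phi, current = [], 0
    for tok in tokens:
        if tok == IMAGE:
            current += 1  # a new image starts here
        phi.append(current)
    return phi


def xattn_mask(tokens, num_images):
    """mask[l][n] is True if position l may cross-attend to image n + 1."""
    return [[p == n + 1 for n in range(num_images)] for p in last_image_index(tokens)]


tokens = ["<BOS>", "Cute", IMAGE, "A", "cat", "<EOC>", IMAGE, "A", "dog", "<EOC>"]
print(last_image_index(tokens))  # [0, 0, 1, 1, 1, 1, 2, 2, 2, 2]
```

Tokens before the first image have an all-False row, i.e. they attend to no visual tokens at all.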


Challenge 3: Obtaining a multimodal dataset that induces good generalist capabilities.

The dataset scraped directly from web pages, called M3W, suffers from the weak matching problem. To induce good generalist capabilities, M3W is combined with standard, strongly related datasets of image-text pairs and video-text pairs (Figure 10).
Concretely, the model is trained on the following datasets: M3W (185M images + 182 GB of text), ALIGN (1.8B images with alt-text), LTIP (312M image/text pairs) and VTP (27M short videos with text). To keep the model general, none of this data is annotated for specific tasks.

Figure 10: Training dataset

2.2.3 Result


The model is evaluated in three ways, summarized in Figure 11.

Figure 11: Overview of the results for the Flamingo models


In the left panel, the largest Flamingo model with 32 shots and no fine-tuning outperforms state-of-the-art fine-tuned models on six of the 16 tasks. On all 16 tasks, Flamingo also outperforms the published few-shot results by a large margin. The middle and right panels show that larger models and more few-shot examples lead to better performance.

2.3 ChatGPT


ChatGPT is a large language model chatbot released by OpenAI in November 2022. It is a sibling model to InstructGPT [6], which is trained to follow an instruction in a prompt and provide a detailed response [5]. It can assist with a variety of tasks, such as answering questions and providing information on a wide range of topics. Users interact with it in the form of a conversational dialogue.

2.3.1 Short introduction


ChatGPT is fine-tuned from a model in the GPT-3.5 series [7], which was trained on massive amounts of data from the internet. The fine-tuning uses reinforcement learning from human feedback (RLHF). Learning from human feedback trains large language models to do what a given set of humans wants them to do, which increases their positive impact [6], and it has the potential to make language models more helpful, truthful and harmless [5]. The main hope is that ChatGPT aligns with users' intent.

2.3.2 Methodology


Here I roughly summarize the learning phases of ChatGPT in four steps, as illustrated in Figure 12:

Figure 12: Overview of the learning phases of ChatGPT

Step 1: First, a GPT model is trained to predict the next word given the preceding text. This gives the model the general language ability on which question answering builds.

Step 2: Human feedback is used to refine the model's outputs: human trainers provide demonstrations of the desired answers, and the model is fine-tuned on them so that it gives the expected answers rather than unexpected ones.

Step 3: To steer GPT towards meeting human expectations while reducing manual effort, humans give feedback by ranking several generated answers; these rankings are used to train a reward model.

Step 4: The model is optimized against the reward model with a reinforcement learning algorithm (Proximal Policy Optimization, PPO, in [5]).

Generally, the goal of ChatGPT is to generate text that aligns with user intent while being helpful, truthful, and harmless.
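Steps 3 and 4 rest on two small formulas, sketched below following the formulation described for InstructGPT [6]. The reward model is trained so that the answer humans ranked higher gets the higher score, using the pairwise loss −log σ(r_w − r_l). In the RL step, the reward-model score is combined with a penalty on the log-ratio between the fine-tuned policy and the supervised model from Step 2, which keeps the policy from drifting too far. The scores, log-probabilities and the coefficient beta here are hypothetical numbers.

```python
import math


def pairwise_ranking_loss(r_preferred, r_rejected):
    """Reward-model loss for one comparison: -log(sigmoid(r_preferred - r_rejected))."""
    diff = r_preferred - r_rejected
    # Numerically stable form of log(1 + exp(-diff))
    return math.log1p(math.exp(-abs(diff))) + max(-diff, 0.0)


def rl_reward(rm_score, logp_policy, logp_sft, beta):
    """Reward for the RL step: reward-model score minus beta * log(pi_RL / pi_SFT)."""
    return rm_score - beta * (logp_policy - logp_sft)


print(pairwise_ranking_loss(2.0, 0.0))  # small loss: the ranking is already respected
print(pairwise_ranking_loss(0.0, 2.0))  # large loss: the ranking is violated
print(rl_reward(1.5, logp_policy=-3.0, logp_sft=-4.0, beta=0.1))  # 1.5 - 0.1 * 1.0
```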


2.3.3 Limitations


While ChatGPT is designed to produce helpful and appropriate responses, it is not immune to manipulation. Figure 13 shows instances in which a user employs various tricks to prompt ChatGPT into generating illegal or inappropriate content. ChatGPT also has other limitations [5]. The quality of its answers depends on the quality of the input, so expert guidance leads to better outputs. ChatGPT does not always provide correct answers and may generate plausible-sounding but incorrect or nonsensical responses. Fixing this issue is challenging because: (1) during RL training, there is currently no source of truth; (2) training the model to be more cautious causes it to decline questions that it can answer correctly; and (3) supervised training misleads the model, because the ideal answer depends on what the model knows rather than on what the human demonstrator knows [5].

Figure 13: Shortcomings of ChatGPT [8]


3. Comparison

3.1 Comparison


Table 3 compares the three models discussed in this blog.

Model | Training approach | Based on | Algorithm | Dataset
Flamingo VLM | Unsupervised pretraining, no task-specific fine-tuning | Pretrained 70B Chinchilla LM | Multimodal learning | M3W + ALIGN + LTIP + VTP
ChatGPT | Unsupervised pretraining + supervised fine-tuning | Fine-tuned from GPT-3.5 | Reinforcement learning | Massive data from the internet
EBM-GFlowNets | Unsupervised learning | Sampling strategy: GFlowNets | Reinforcement learning | High-dimensional discrete data

Table 3: Comparison table with three models

All three models show increasingly human-like performance; however, there is still much room for improvement in understanding human language. In addition, clean, high-quality and massive datasets are crucial for optimal model performance, along with ample computational resources. A well-designed representation can also greatly improve the final outcome.

3.2 Medical Application


Representation learning aims to find useful representations of data, and such representations have various applications in the medical field, particularly in the analysis of medical images such as CT or MRI scans. Representation learning can learn features from these images that support tasks like classification, segmentation [10] and detection. Beyond medical images, representation learning can also be applied to biological sequences [9], such as DNA and protein sequences; with useful representations of these sequences, it becomes possible to perform tasks like classification, prediction and the analysis of molecular interactions. ChatGPT's language understanding and communication abilities make it a potential medical assistant, for example for patient communication and mental health support, extracting relevant information from large amounts of data, summarizing medical articles, analyzing medical data and drafting text for medical reports. However, medicine demands rigor and discretion, and ChatGPT cannot be relied on to perform medical tasks on its own, so it should be used with caution.

The illustration in Figure 14 presents a brief overview of the potential applications in the medical field. 

Figure 14: Short discussion of medical application

4. Summary

In this blog, we delve into the concepts of representation learning and generative models. We explain what energy-based models are and why GFlowNets are used as their sampling strategy. We introduce two highly regarded models, the visual language model Flamingo and the large language model ChatGPT, and provide a concise overview of the methodologies behind all three models. Finally, we compare these models and explore their potential applications in the medical field.

Below is my personal, subjective assessment of the three models:

EBM-GFlowNets:

Pros: It enables generalization across different modes of the distribution by exploiting the compositional structure revealed by the GFlowNet. It can model complex distributions.

Cons: In addition to the energy function, the sampler must also be trained. Training can be computationally expensive, especially when the input data is high-dimensional.

Flamingo VLM:

Pros: It can easily process sequences of interleaved visual and textual data. Without requiring fine-tuning, Flamingo models can quickly adapt to diverse image and video understanding tasks. It is also flexible beyond traditional vision and language benchmarks.

Cons: Flamingo models may struggle to answer correctly on comprehension tasks involving complex scenes in images or videos. Further improvement is necessary before Flamingo models can be applied effectively in open-ended visual dialogue settings.

Chat GPT:

Pros: ChatGPT has been trained on a large corpus of text data, which makes it capable of generating human-like responses to questions and conversations. Like Flamingo, it is highly flexible: it can be fine-tuned for specific applications such as customer service, language translation and question answering. Finally, it can handle a large number of queries simultaneously, which makes it suitable for large-scale deployment.

Cons: It lacks context awareness; biases it has learned from its training data can result in incorrect or biased responses; and it does not have the ability to think creatively or generate new ideas.

References

[1] Yoshua Bengio, Aaron Courville, Pascal Vincent. Representation Learning: A Review and New Perspectives. IEEE Transactions on Pattern Analysis and Machine Intelligence, 2013.

[2] Stefano Ermon, Yang Song. Energy-Based Models. CS236, Lecture 11.

[3] Dinghuai Zhang, Nikolay Malkin, Zhen Liu, Alexandra Volokhova, Aaron Courville, Yoshua Bengio. Generative Flow Networks for Discrete Probabilistic Modeling. ICML, 2022.

[4] Jean-Baptiste Alayrac, Jeff Donahue, Pauline Luc, Antoine Miech, Iain Barr, et al. Flamingo: a Visual Language Model for Few-Shot Learning. DeepMind, 2022.

[5] John Schulman, Barret Zoph, Christina Kim, Jacob Hilton, Nick Ryder, et al. Optimizing Language Models for Dialogue. OpenAI, 2022.

[6] Long Ouyang, Jeff Wu, Xu Jiang, Diogo Almeida, Ryan Lowe, et al. Training language models to follow instructions with human feedback. 2022.

[7] Model index for researchers (website).

[8] Zvi. Jailbreaking ChatGPT on Release Day. Blog post, 2022.

[9] Moksh Jain, Emmanuel Bengio, Alex Hernandez-Garcia, Jarrid Rector-Brooks, Yoshua Bengio, et al. Biological Sequence Design with GFlowNets. ICML, 2022.

[10] Chao Huang, Hu Han, Qingsong Yao, Shankuan Zhu, S. Kevin Zhou. 3D U2-Net: A 3D universal U-net for multi-domain medical image segmentation. MICCAI, 2019.

[11] Carnegie Mellon University. Multimodal Machine Learning (website).


