Based on the paper "Are Disentangled Representations Helpful for Abstract Visual Reasoning?" by Sjoerd van Steenkiste, Francesco Locatello, Jürgen Schmidhuber, and Olivier Bachem

Introduction


Learning good representations of high-dimensional sensory data is of fundamental importance to Artificial Intelligence. In the supervised case, the quality of a representation is often expressed through the ability to solve the corresponding task. In practice, however, unlabeled data far outnumbers labeled data, so we need approaches that apply to more general real-world settings.

Following the successes in learning distributed representations that efficiently encode the content of high-dimensional sensory data, recent work has focused on learning representations that are disentangled.

A disentangled representation encodes information about the salient factors of variation in the data independently, isolating information about each specific factor in only a few dimensions. It is often argued that a disentangled representation is desirable when learning to solve challenging real-world tasks. In a disentangled representation, information about an individual factor value can be readily accessed and is robust to changes in the input that do not affect this factor. Hence, learning to solve a down-stream task from a disentangled representation is expected to require fewer samples and to be easier in general.

Disentangled representations have been found to be more sample-efficient, less sensitive to nuisance variables, and better in terms of (systematic) generalization.

This is the focus of the paper: a systematic evaluation, on a complex down-stream task, of a wide variety of disentangled representations obtained by training different models with different hyperparameters and data sets.

Despite an increasing interest in learning disentangled representations, a precise definition is still a topic of debate.

In recent work, Eastwood et al. [1] and Ridgeway et al. [2] put forth three criteria for disentangled representations:

  • Modularity: implies that each code in a learned representation is associated with only one factor of variation in the environment
  • Compactness: ensures that information regarding a single factor is represented using only one or few codes
  • Explicitness: the mapping between factors and learned codes can be implemented with a simple (e.g. linear) model.

Combined, modularity and compactness suggest that a disentangled representation implements a one-to-one mapping between salient factors of variation in the environment and the learned codes.

Methodology


Evaluation of the usefulness of disentangled representations on abstract visual reasoning tasks that challenge the current capabilities of state-of-the-art deep neural networks.

Step 1 (Disentangling)

Methods to learn disentangled representations:

  1. β-VAE [3]
  2. FactorVAE [4]
  3. β-TCVAE [5]
  4. DIP-VAE [6]

We can view all of these models as Auto-Encoders that are trained with the regularized variational objective of the form:

\begin{equation} \mathbb{E}_{p(x)}[\mathbb{E}_{q_{\phi}(z|x)}[-\log p_{\theta}(x|z)]] + \lambda_1\mathbb{E}_{p(x)}[R_1(q_\phi(z|x))] + {\lambda}_2R_2(q_\phi(z)) \end{equation}

The output of the encoder that parametrizes $q_\phi(z|x)$ yields the representation, while regularization serves to control the information flow through the bottleneck induced by the encoder.
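As an illustrative sketch of this objective family (not the authors' code): assuming a Gaussian encoder and a squared-error stand-in for $-\log p_\theta(x|z)$, β-VAE corresponds to $R_1 = KL(q_\phi(z|x)\,\|\,p(z))$ with $\lambda_1 = \beta$ and $\lambda_2 = 0$. The function and argument names below are illustrative:

```python
import numpy as np

def beta_vae_objective(x, x_recon, mu, log_var, beta=4.0):
    """Sketch of the regularized objective for a Gaussian encoder q_phi(z|x).

    beta-VAE sets R_1 = KL(q_phi(z|x) || p(z)) with lambda_1 = beta and
    lambda_2 = 0; the other models in the family differ only in R_1 / R_2.
    """
    # Reconstruction term: squared error stands in for -log p_theta(x|z).
    recon = np.sum((x - x_recon) ** 2)
    # KL divergence between N(mu, sigma^2) and the standard normal prior,
    # summed over latent dimensions (closed form for diagonal Gaussians).
    kl = 0.5 * np.sum(np.exp(log_var) + mu ** 2 - 1.0 - log_var)
    return recon + beta * kl
```

The other regularizers (total correlation in FactorVAE/β-TCVAE, covariance penalties in DIP-VAE) would slot into the same skeleton as additional terms.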

Step 2 (WReNs)

They used the Wild Relation Network (WReN) to solve the abstract visual reasoning tasks.


The WReN is evaluated for each answer panel $a \in A = \{a_1, \ldots, a_6\}$ in relation to all the context panels $C = \{c_1, \ldots, c_8\}$ as follows:

\begin{equation} \mathrm{WReN}(a, C) = f_{\phi}\Big(\sum_{e_1, e_2 \in E} g_{\theta}(e_1, e_2)\Big), \quad E = \{\mathrm{CNN}(c_1), \ldots, \mathrm{CNN}(c_8)\} \cup \{\mathrm{CNN}(a)\} \end{equation}

Reasoning process:

  1. $g_\theta$ is applied to all pairs of panel embeddings to consider relations between the answer panel and each of the context panels, as well as relations among the context panels.
  2. $f_\phi$ produces a score for the given answer panel in relation to the context panels by globally considering the different relations between the panels as a whole.
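A minimal sketch of this scoring rule, with plain functions `g` and `f` standing in for the learned MLPs $g_\theta$ and $f_\phi$ (all names here are illustrative, not the authors' code):

```python
import numpy as np
from itertools import product

def wren_score(answer_emb, context_embs, g, f):
    """Score one candidate answer: the embeddings of the 8 context panels
    plus the candidate form the set E; g is applied to every ordered pair
    (e1, e2) in E x E, the results are summed, and f maps the sum to a
    scalar score, matching the equation above."""
    E = list(context_embs) + [answer_emb]
    pair_sum = sum(g(e1, e2) for e1, e2 in product(E, repeat=2))
    return f(pair_sum)

def predict(answer_embs, context_embs, g, f):
    # The candidate answer panel with the highest score is chosen.
    scores = [wren_score(a, context_embs, g, f) for a in answer_embs]
    return int(np.argmax(scores))
```

In the paper the embeddings come from a CNN (or a frozen pre-trained representation); here scalar placeholders suffice to show the wiring.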

Experimental Settings


Abstract reasoning tasks require a learner to infer abstract relationships between multiple entities (e.g. objects in images) and to re-apply this knowledge in newly encountered settings.

They constructed two new abstract RPM-like visual reasoning datasets based on two existing datasets for disentangled representation learning.

To construct the abstract reasoning tasks, they used the ground-truth generative model of the dSprites [3] and 3dshapes [4] data sets with the following changes:

  • dSprites: ignore the orientation feature for the abstract reasoning tasks. To compensate, add background color and object color as two new factors of variation.


  • 3dshapes: retain all of the original factors but only consider four different values for scale and azimuth (out of 8 and 16).


The key idea is that one is given a 3×3 matrix of context image panels with the bottom-right panel missing, as well as a set of six potential answer panels. One then has to infer which of the answers fits in the missing panel, based on the relations between the image panels in the rows of the 3×3 matrix.

They generate instances of the abstract reasoning tasks in the following way:

  1. Uniformly sample whether 1, 2, or 3 ground-truth factors are fixed across rows in the instance to be generated.
  2. Uniformly sample, without replacement, which factors of the underlying generative model should be kept constant.
  3. For each of the three rows and each of the fixed factors, uniformly sample a factor value from the ground-truth model.
  4. For all other ground-truth factors, sample 3×3 matrices of factor values from the ground-truth model, with the single constraint that the factor values are not allowed to be constant across the first two rows.
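The steps above can be sketched as follows. This is a simplified illustration: `sample_instance` and `factor_values` are hypothetical names, and the step-4 constraint on the first two rows is omitted for brevity:

```python
import random

def sample_instance(factor_values):
    """Sketch of the generation procedure; factor_values maps each
    ground-truth factor name to its list of possible values."""
    names = list(factor_values)
    # Steps 1 + 2: sample how many and which factors stay constant per row.
    n_fixed = random.choice([1, 2, 3])
    fixed = random.sample(names, n_fixed)
    matrix = []
    for _ in range(3):
        # Step 3: one value per fixed factor, shared by the whole row.
        row_vals = {f: random.choice(factor_values[f]) for f in fixed}
        row = []
        for _ in range(3):
            cell = dict(row_vals)
            # Step 4: remaining factors vary freely within the row
            # (the not-constant-across-the-first-two-rows constraint
            # is omitted in this sketch).
            for f in names:
                if f not in fixed:
                    cell[f] = random.choice(factor_values[f])
            row.append(cell)
        matrix.append(row)
    return matrix, fixed
```

Each cell of the returned 3×3 matrix is a full factor assignment that the ground-truth renderer would turn into an image panel.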

In recent work, Locatello et al. [9] studied several disentanglement metrics:

BetaVAE score

  • focus primarily on modularity.

FactorVAE score

  • focus primarily on modularity.

Mutual Information Gap (MIG)

  • mostly focused on compactness.
  • is based on the matrix of pairwise mutual information between factors of variations and dimensions of the representation.

DCI Disentanglement score

  • focus primarily on modularity.
  • uses a random forest regressor to determine the relative importance of each feature.

Separated Attribute Predictability (SAP) score

  • mostly focused on compactness.
  • considers, for each factor, the gap in prediction accuracy between the two most predictive features in the representation, where a support vector machine is trained on each individual feature.
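As an illustration of how MIG can be computed from the pairwise mutual-information matrix mentioned above (a sketch with illustrative names; the MI matrix and factor entropies are assumed to be given):

```python
import numpy as np

def mig(mi_matrix, factor_entropies):
    """Sketch of the Mutual Information Gap.

    mi_matrix[i, j] holds the mutual information between latent code i and
    ground-truth factor j. For each factor, take the gap between the two
    codes with the highest MI, normalize by the factor's entropy, and
    average the gaps over factors."""
    # Sort each column (factor) in descending order of mutual information.
    sorted_mi = np.sort(mi_matrix, axis=0)[::-1]
    gaps = (sorted_mi[0] - sorted_mi[1]) / factor_entropies
    return float(np.mean(gaps))
```

A large MIG means each factor is captured by essentially one code (compactness); an entangled representation spreads MI over several codes and gets a small gap.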

Training

Training of Disentanglement Models

  • Train β-VAE, FactorVAE, β-TCVAE, and DIP-VAE on the panels from the modified dSprites and 3dshapes data sets
  • 10-dimensional latent space
  • Batch size: 64
  • Iterations: 300K

Training of WReNs

Train different WReN models where two factors are controlled:

  1. The representation produced by a specific model used to embed the input images.
  2. The hyperparameters of the WReN model.

  • Batch size: 32
  • Iterations: 100K
  • Every 1000 iterations, the accuracy is evaluated on 100 mini-batches

Both are trained with the Adam optimizer:

  • β1 = 0.9
  • β2 = 0.999
  • learning rate = 0.0001

Initial study

To assess the overall complexity of the abstract reasoning task.

When training the model, they considered 3 types of representations:

  • CNN representations learned from scratch (with the same architecture as in the disentanglement models), yielding the standard WReN
  • Pre-trained frozen representations based on a random selection of the pre-trained disentanglement models
  • Directly using the ground-truth factors of variation (both one-hot encoded and integer encoded).

They train a full set of WReN models for each of the 360 representations from the disentanglement models.

Results


Reconstruction

Odd columns show real samples and even columns their reconstructions. 3dshapes appears to be much easier than dSprites, where disentangling the shape appears hard.

Initial study

Average accuracy of baselines, and models using pre-trained representations

We observe that the standard WReN model struggles to obtain good results: training from scratch is hard, and runs may get stuck in local minima where each of the answers is predicted with equal probability.

Given the pre-training and the exposure to additional unsupervised samples, it is not surprising that the learned representations from the disentanglement models perform better.

The WReN models that are given the true factors also perform well, already after only a few steps of training.

Step 3 (Evaluation)

Rank correlation between the various disentanglement metrics and the down-stream accuracy of the abstract visual reasoning models throughout training (i.e., for different numbers of training samples):

In the few-sample regime (up to 20K steps) and across both data sets:

  • The BetaVAE score, and the FactorVAE score are highly correlated with accuracy.
  • The DCI Disentanglement score is correlated slightly less.
  • The MIG and SAP score exhibit a relatively weak correlation.
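The rank correlations reported here can be computed as Spearman correlations, i.e. the Pearson correlation of the ranks of the two variables. Assuming no ties, a minimal sketch is:

```python
import numpy as np

def spearman(x, y):
    """Spearman rank correlation between a disentanglement metric and
    down-stream accuracy across models (no tie handling)."""
    # argsort of argsort turns values into their ranks 0..n-1.
    rx = np.argsort(np.argsort(x)).astype(float)
    ry = np.argsort(np.argsort(y)).astype(float)
    rx -= rx.mean()
    ry -= ry.mean()
    # Pearson correlation of the centered ranks.
    return float(np.sum(rx * ry) / np.sqrt(np.sum(rx ** 2) * np.sum(ry ** 2)))
```

In practice a library routine with proper tie handling (e.g. `scipy.stats.spearmanr`) would be used; the sketch only shows what the statistic measures.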

We observe strong evidence that disentangled representations yield better accuracy using relatively few samples, and we therefore conclude that they are indeed more sample efficient compared to entangled representations in this regard.

Conclusion


  1. These results provide concrete motivation for why one might want to pursue disentanglement as a property of learned representations in the unsupervised case.

  2. Differences between disentanglement metrics were observed, which should motivate further work on understanding which properties they capture.

  3. It might be useful to extend the methodology of this study to other complex tasks, or to investigate other purported benefits of disentangled representations.

References


[1] Cian Eastwood and Christopher K. I. Williams. A framework for the quantitative evaluation of disentangled representations. In International Conference on Learning Representations, 2018.

[2] Karl Ridgeway and Michael C Mozer. Learning deep disentangled embeddings with the f-statistic loss. In Advances in Neural Information Processing Systems, pages 185–194, 2018.

[3] Irina Higgins, Loic Matthey, Arka Pal, Christopher Burgess, Xavier Glorot, Matthew Botvinick, Shakir Mohamed, and Alexander Lerchner. beta-vae: Learning basic visual concepts with a constrained variational framework. In International Conference on Learning Representations, 2017.

[4] Hyunjik Kim and Andriy Mnih. Disentangling by factorising. In International Conference on Machine Learning, 2018.

[5] Tian Qi Chen, Xuechen Li, Roger B Grosse, and David K Duvenaud. Isolating sources of disentanglement in vaes. In Advances in Neural Information Processing Systems, 2018.

[6] Abhishek Kumar, Prasanna Sattigeri, and Avinash Balakrishnan. Variational inference of disentangled latent concepts from unlabeled observations. In International Conference on Learning Representations, 2018.

[7] https://towardsdatascience.com/whats-new-in-deep-learning-research-an-iq-test-proves-that-neural-networks-are-capable-of-1afde3695c26

[8] https://www.assessment-training.com/raven-s-progressive-matrices-test

[9] Francesco Locatello, Stefan Bauer, Mario Lucic, Sylvain Gelly, Bernhard Schölkopf, and Olivier Bachem. Challenging common assumptions in the unsupervised learning of disentangled representations. In Proceedings of the 36th International Conference on Machine Learning, volume 97, 2019.


