Chen, C., Li, O., Tao, C., Barnett, A.J., Su, J. and Rudin, C., NeurIPS 2019

Written and presented by Christoph Berger, tutored by Dr. Seong Tae Kim



Motivation

A clay-coloured sparrow

How would you explain to someone who knows nothing about ornithology that the bird above is a clay-coloured sparrow? You might point out prototypical features of the bird, such as the colour of the feathers, the white breast, the shape of the head, the small beak and the legs. Humans reason this way when they have to explain their decisions in difficult classification tasks, such as in medicine or, as in this case, bird species classification.


The method proposed by Chen et al. integrates this kind of reasoning, based on prototypical parts of images, into a classification network. Because the prototypical parts are used directly in the classification decision, the network offers a previously unseen level of interpretability, which could be valuable in many fields; its potential in medical imaging in particular deserves further exploration.

Related Work

In general, there are two types of interpretability: post-hoc interpretability, which tries to understand the decisions of a trained model without modifying it, simply by looking at combinations of inputs and outputs, and built-in interpretability, which designs models specifically so that humans can understand them.

Post-hoc interpretability includes methods such as class-specific activation maximisation [1-7] and input-specific visualisations, for example deconvolution [8] or gradient-based saliency visualisation methods [9-12].

The presented work is more closely related to other methods for built-in interpretability, such as attention models. However, class activation maps [13] and models with part attention [14-16] can only tell us which parts of the input the model is looking at, not how this attention influences the decision. In these approaches, the actual classification decision is still based on hidden characteristics that the observer has to figure out on her own.

The most closely related works are prototype classification techniques, which often use the scale-invariant feature transform (SIFT) instead of a neural network for feature extraction, and separate pipelines for feature extraction and subsequent classification [17]. As a result, they cannot be trained end-to-end. Another very similar work is the case-based reasoning model of Li et al. [18], which uses whole images as prototypes and needs a decoder to visualise the learned prototypes. This results in poor visual quality of the prototype visualisations.

Main Contributions

The proposed network, called ProtoPNet, addresses the shortcomings of these approaches: it is end-to-end trainable and learns a fixed number of prototypes per class. These prototypes are used to decide which class a test image belongs to. Because each prototype can be visualised as a patch of a training image, interpretability comes without any loss in visual quality. And since only the prototypes are used for the classification decision, the network remains highly understandable for humans.


Architecture



The model consists of three parts:

  1. a convolutional neural network f that extracts latent features (patches) from the image,
  2. a prototype layer g_p that compares the learned prototypes with these latent patches, and
  3. a fully connected layer h that takes the similarity scores of the prototypes and computes a final classification score.

The first part consists of the convolutional layers of well-established image classification networks, such as VGG [24], DenseNet [25] and ResNet [26], all of which are pretrained on ImageNet [23]. Two additional convolutional layers are appended to them, so that the output has dimensions 7 x 7 x D (with D being 128, 256 or 512). Each of the 7 x 7 spatial positions of this output is a D-dimensional latent patch, and the prototypes are learned in this same latent space. D is a hyperparameter chosen by cross-validation.
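To make the notion of "latent patches" concrete, here is a minimal Python sketch (standard library only) that turns a 7 x 7 x D output into a list of 49 D-dimensional patches. The nested-list layout `[row][column][channel]` and the function name `feature_map_to_patches` are hypothetical; a real implementation would operate on framework tensors. The sketch also assumes that each prototype is a single 1 x 1 x D patch.

```python
from typing import List

# Hypothetical layout: feature_map[row][column] is a D-dimensional vector.
FeatureMap = List[List[List[float]]]


def feature_map_to_patches(feature_map: FeatureMap) -> List[List[float]]:
    """Flatten an H x W x D convolutional output into H * W latent patches.

    Each patch is the D-dimensional vector at one spatial position, i.e. the
    unit that a 1 x 1 x D prototype is compared against.
    """
    return [list(vector) for row in feature_map for vector in row]


# Example: a 7 x 7 x 4 feature map whose entries encode their position.
D = 4
feature_map = [[[float(r * 7 + c)] * D for c in range(7)] for r in range(7)]
patches = feature_map_to_patches(feature_map)
print(len(patches), len(patches[0]))  # 49 4
```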

In the prototype layer, each learned prototype is compared with all latent patches of the test image. The distance to every patch is converted into a similarity, and global max pooling reduces these similarities to a single score per prototype. If this score is high, it simply means that there is some patch in the image which is similar to the prototype; the score by itself does not say where exactly this patch is.
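The similarity computation can be sketched as follows. To my knowledge, the paper converts the squared L2 distance d between a latent patch and a prototype into the similarity log((d + 1) / (d + ε)), which is large when d is small, and then applies global max pooling over all patches. The function names and the value ε = 1e-4 are assumptions for illustration.

```python
import math
from typing import List, Sequence

Vector = Sequence[float]


def squared_l2(a: Vector, b: Vector) -> float:
    """Squared Euclidean distance between two D-dimensional vectors."""
    return sum((x - y) ** 2 for x, y in zip(a, b))


def prototype_similarity(patches: List[Vector], prototype: Vector,
                         eps: float = 1e-4) -> float:
    """Similarity score of one prototype for one image.

    `patches` are the latent patches of the image (e.g. the 49 vectors of a
    7 x 7 x D output). eps is an assumed small constant that avoids division
    by zero when a patch matches the prototype exactly.
    """
    similarities = []
    for z in patches:
        d = squared_l2(z, prototype)
        # Decreasing in d: a small distance gives a large similarity.
        similarities.append(math.log((d + 1.0) / (d + eps)))
    # Global max pooling: only the best-matching patch determines the score,
    # so the score alone does not tell us where that patch is.
    return max(similarities)


patches = [[0.0, 0.0], [1.0, 1.0], [3.0, 4.0]]
near = prototype_similarity(patches, [1.0, 1.0])    # exact match: about 9.21
far = prototype_similarity(patches, [10.0, 10.0])   # far away: close to 0
print(near, far)
```

Keeping the index of the maximising patch (instead of only the maximum) would recover the location, which is what the paper's visualisations rely on.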

Finally, a fully connected layer computes, for each class, a weighted sum of the prototype similarity scores; the class with the highest score is the most likely class.
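A minimal sketch of this last layer: `w[k][j]` connects prototype j to class k, and each class score is a weighted sum of the similarity scores. The initial values 1 (own-class prototypes) and -0.5 (other-class prototypes) follow the initialisation I recall from the paper; treat them, and the hypothetical function names, as assumptions.

```python
from typing import List, Sequence


def init_last_layer(num_classes: int, prototype_class: Sequence[int],
                    own: float = 1.0, other: float = -0.5) -> List[List[float]]:
    """Weights w[k][j] connecting prototype j to class k.

    prototype_class[j] is the class k for which p_j is in P_k.
    """
    return [[own if prototype_class[j] == k else other
             for j in range(len(prototype_class))]
            for k in range(num_classes)]


def class_scores(similarities: Sequence[float],
                 weights: List[List[float]]) -> List[float]:
    """Fully connected layer h: one weighted sum of similarities per class."""
    return [sum(w * s for w, s in zip(row, similarities)) for row in weights]


# Two classes with two prototypes each (the paper uses 10 per class).
prototype_class = [0, 0, 1, 1]
weights = init_last_layer(2, prototype_class)
scores = class_scores([5.0, 1.0, 0.5, 0.2], weights)
predicted = max(range(len(scores)), key=scores.__getitem__)
print(scores, predicted)  # approximately [5.65, -2.3] 0
```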

Training

Training is done in three stages: stochastic gradient descent for the layers before the last layer (the convolutional layers and the prototypes), projection of the prototypes, and convex optimisation of the last layer.

Stochastic Gradient Descent

During this stage of training, a meaningful latent space is learned in which the prototypes are represented. The objective to be minimised is the following:

\min_{P,\, w_{conv}} \frac{1}{n}\sum_{i=1}^{n} \mathrm{CrossEntropy}(h \circ g_p \circ f(x_i), y_i) + \lambda_1 \mathrm{Clst}(P, X, Y) + \lambda_2 \mathrm{Sep}(P, X, Y)

The standard cross-entropy term encourages accuracy and penalises misclassifications. The cluster term encourages every training image to have at least one latent patch close to at least one prototype of its own (correct) class, while the separation term pushes the latent patches of each training image away from the prototypes of other classes. See the image below for a visual example of the result of this loss function.
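To my knowledge, the paper defines the two regularisation terms as follows, with n training images, labels y_i and P_k the prototypes of class k:

\mathrm{Clst} = \frac{1}{n}\sum_{i=1}^{n} \min_{j : p_j \in P_{y_i}} \min_{z \in \mathrm{patches}(f(x_i))} \lVert z - p_j \rVert_2^2

\mathrm{Sep} = -\frac{1}{n}\sum_{i=1}^{n} \min_{j : p_j \notin P_{y_i}} \min_{z \in \mathrm{patches}(f(x_i))} \lVert z - p_j \rVert_2^2

The sketch below computes both terms for a small batch. The data layout (a list of latent patches per image, one class label per prototype) and all names are hypothetical; the cross-entropy term is left out.

```python
from typing import List, Sequence, Tuple

Vector = Sequence[float]


def squared_l2(a: Vector, b: Vector) -> float:
    return sum((x - y) ** 2 for x, y in zip(a, b))


def min_patch_distance(patches: List[Vector], prototype: Vector) -> float:
    """min over all patches z of ||z - p_j||^2 for one image."""
    return min(squared_l2(z, prototype) for z in patches)


def cluster_and_separation(batch_patches: List[List[Vector]],
                           labels: Sequence[int],
                           prototypes: List[Vector],
                           prototype_class: Sequence[int]) -> Tuple[float, float]:
    """Return (Clst, Sep) averaged over the n images of the batch."""
    n = len(labels)
    clst, sep = 0.0, 0.0
    for patches, y in zip(batch_patches, labels):
        dists = [min_patch_distance(patches, p) for p in prototypes]
        # Closest prototype of the image's own class ...
        clst += min(d for d, k in zip(dists, prototype_class) if k == y)
        # ... and closest prototype of any other class, negated so that
        # minimising Sep pushes those prototypes further away.
        sep -= min(d for d, k in zip(dists, prototype_class) if k != y)
    return clst / n, sep / n


# One image of class 0 with two latent patches; one prototype per class.
clst, sep = cluster_and_separation(
    batch_patches=[[[0.0, 0.0], [2.0, 2.0]]],
    labels=[0],
    prototypes=[[0.0, 0.0], [1.0, 1.0]],
    prototype_class=[0, 1],
)
print(clst, sep)  # 0.0 -2.0
```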


Projection of the Prototypes

This step is done after the initial training of the convolutional layers and mainly serves to make the prototypes visualisable. Each prototype is pushed onto the closest latent patch among all training images of its own class, which means that the prototype is set equal to the latent representation of that training patch. This enables direct visualisation using image patches from the training data, while the accuracy stays nearly the same as with the unprojected latent prototypes. Formally, each prototype p_j \in P_k is updated as:

p_j \leftarrow \arg\min_{z \in Z_j} \lVert z - p_j \rVert_2, \quad \text{where } Z_j = \{\tilde{z} : \tilde{z} \in \mathrm{patches}(f(x_i)) \text{ for all } i \text{ with } y_i = k\}
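The projection step can be sketched in a few lines: for each prototype, collect the set Z_j of latent patches from training images of the prototype's class and replace the prototype with the nearest one. The function name and the returned (image index, patch index) pairs are hypothetical additions; the pairs indicate which training patch would be shown as the prototype's visualisation.

```python
from typing import List, Sequence, Tuple

Vector = Sequence[float]


def squared_l2(a: Vector, b: Vector) -> float:
    return sum((x - y) ** 2 for x, y in zip(a, b))


def project_prototypes(prototypes: List[Vector],
                       prototype_class: Sequence[int],
                       train_patches: List[List[Vector]],
                       train_labels: Sequence[int]
                       ) -> Tuple[List[List[float]], List[Tuple[int, int]]]:
    """Push each prototype onto its nearest same-class training patch."""
    projected, sources = [], []
    for p, k in zip(prototypes, prototype_class):
        # Z_j: every latent patch of every training image labelled k.
        candidates = [(i, m, z)
                      for i, (patches, y) in enumerate(zip(train_patches, train_labels))
                      if y == k
                      for m, z in enumerate(patches)]
        i, m, z = min(candidates, key=lambda c: squared_l2(c[2], p))
        projected.append(list(z))
        sources.append((i, m))
    return projected, sources


# The exact match in image 1 is ignored because image 1 belongs to class 1.
projected, sources = project_prototypes(
    prototypes=[[0.9, 0.1]],
    prototype_class=[0],
    train_patches=[[[1.0, 0.0], [0.0, 1.0]], [[0.9, 0.1]]],
    train_labels=[0, 1],
)
print(projected, sources)  # [[1.0, 0.0]] [(0, 0)]
```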

Convex Optimisation of the Last Layer

The connections of the last layer are optimised for the final classification using cross-entropy and a sparsity term. The sparsity term penalises the weights that connect each class to prototypes of other classes, pushing them towards zero. This encourages a positive reasoning process, in which a decision is reached by reasoning "this image looks like the prototypes of class X, this is why it is class X" rather than "this image does not look like the prototypes of class Y, that's why it has to be class X". Since all parameters of the earlier layers are fixed in this step, the optimisation problem is convex, so a global minimum can be found.

\min_{w_h} \frac{1}{n}\sum_{i=1}^{n} \mathrm{CrossEntropy}(h \circ g_p \circ f(x_i), y_i) + \lambda \sum_{k=1}^{K} \sum_{j : p_j \notin P_k} \lvert w_h^{(k,j)} \rvert
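The summary does not say which solver the authors use for this step. As one standard way to minimise such an L1-regularised convex objective, the sketch below uses proximal gradient descent: a gradient step on the mean cross-entropy, followed by soft-thresholding of the other-class weights. The solver choice, step size, number of steps and all names are assumptions, and the similarity scores are treated as fixed inputs because the earlier layers are frozen.

```python
import math
from typing import List, Sequence

Matrix = List[List[float]]


def softmax(logits: Sequence[float]) -> List[float]:
    m = max(logits)  # subtract the maximum for numerical stability
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def scores(w: Matrix, s: Sequence[float]) -> List[float]:
    return [sum(wkj * sj for wkj, sj in zip(row, s)) for row in w]


def objective(w: Matrix, sims: List[Sequence[float]], labels: Sequence[int],
              prototype_class: Sequence[int], lam: float) -> float:
    """Mean cross-entropy + lam * sum of |w[k][j]| over p_j not in P_k."""
    ce = -sum(math.log(softmax(scores(w, s))[y]) for s, y in zip(sims, labels))
    l1 = sum(abs(w[k][j]) for k in range(len(w))
             for j, c in enumerate(prototype_class) if c != k)
    return ce / len(labels) + lam * l1


def fit_last_layer(w: Matrix, sims: List[Sequence[float]], labels: Sequence[int],
                   prototype_class: Sequence[int], lam: float,
                   lr: float = 0.1, steps: int = 200) -> Matrix:
    """Proximal gradient descent on the convex last-layer objective."""
    w = [row[:] for row in w]  # do not modify the caller's weights
    n, num_classes, num_protos = len(labels), len(w), len(prototype_class)
    for _ in range(steps):
        grad = [[0.0] * num_protos for _ in range(num_classes)]
        for s, y in zip(sims, labels):
            q = softmax(scores(w, s))
            for k in range(num_classes):
                coef = (q[k] - (1.0 if k == y else 0.0)) / n
                for j in range(num_protos):
                    grad[k][j] += coef * s[j]
        for k in range(num_classes):
            for j in range(num_protos):
                v = w[k][j] - lr * grad[k][j]
                if prototype_class[j] != k:
                    # Proximal step for the L1 term: soft-thresholding
                    # shrinks other-class connections towards exactly zero.
                    v = math.copysign(max(abs(v) - lr * lam, 0.0), v)
                w[k][j] = v
    return w


# Hypothetical similarity scores of four images for two prototypes.
sims = [[2.0, 0.1], [0.2, 1.5], [1.8, 0.3], [0.1, 1.9]]
labels = [0, 1, 0, 1]
prototype_class = [0, 1]
w0 = [[1.0, -0.5], [-0.5, 1.0]]
w_fit = fit_last_layer(w0, sims, labels, prototype_class, lam=0.1)
print(objective(w0, sims, labels, prototype_class, 0.1),
      objective(w_fit, sims, labels, prototype_class, 0.1))
```

Because the objective is convex and the step size is small relative to the curvature of the cross-entropy here, each iteration does not increase the objective.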


Results

The following results were obtained on the CUB-200-2011 bird species dataset, which contains 200 classes with about 30 training images each. The image size is 336x336 pixels and the number of prototypes has been fixed to 10 per class. Standard offline data augmentation with rotation, skew, shear, distortion and flips was used to bring the number of training images per class up to about 1200.


The graph above shows a selection of the methods the proposed method has been compared with. The main takeaway is that the proposed method achieves accuracy comparable to most state-of-the-art methods while providing more interpretability.

Note that, for the reported scores, the authors use an ensemble of networks with different convolutional backbones to obtain more, and more meaningful, prototypes. The combination of backbones varies between experiments to obtain the best result in each setting. The backbones used are:

  • VGG16 + DenseNet121 + DenseNet161 (for full images)
  • VGG16 + ResNet34 + DenseNet121 (for bounding boxes)

This ensemble approach does not reduce interpretability; it simply leads to more prototypes per class and longer training times.

Below are three selected figures from the paper, which show clearly which prototype had the highest activation for a given test image and how it influenced the final decision.

The example below shows a misclassification whose cause is clearly visible: the highly similar feather pattern was weighted more heavily than the other features, which led to the wrong classification.

The Medical Case Study

In the first version of the paper, the authors also included a medical case study to show that their model can be applied to real-world problems. This is especially relevant because interpretability is merely nice to have in non-critical applications such as bird classification, but essential for decisions in high-stakes areas such as medical imaging. Wrong decisions can cause high costs for the healthcare system, affect patients' lives through the additional follow-up examinations needed to rule out misdiagnoses, and even lead to adverse medical effects. Since most AI systems for medical applications will, in the near to mid term, still require a human in the loop and serve as an aid to the physician, understandability is crucial for such tasks.

The example below shows the application of the proposed method to the region of interest (ROI) of a mammography image, where the task is to classify the ROI as either benign or malignant. The images are taken from the well-known CBIS-DDSM dataset [30-32] and come pre-cropped to the ROI. The paper reports a classification accuracy of 82.6%, which is nearly on par with non-interpretable models (e.g. VGG-16 reaches 84% accuracy on this dataset), while the network is still able to explain its decisions. However, the example below suggests that the model compares the entire ROI to the test image, which would mean that the prototypes cover the whole ROI. In my opinion, this weakens explainability, because the model then looks at entire regions rather than characteristic parts of the image. That said, I cannot speak for the authors: they describe this experiment only superficially in their paper, so the reader lacks the information needed to understand exactly how the approach works on this dataset.

This might also be one of the reasons why this particular dataset was not included in the accepted version of the paper at NeurIPS 2019 and was replaced by a car model classification dataset. 

Conclusion

The presented method enables case-based reasoning for classification that takes exclusively interpretable features into account, which previous methods did not offer. By visualising the learned prototypes as training image patches, it can inherently explain its decisions. Additionally, the separation term introduced for learning the prototypes leads to a more clearly separated latent space.

More research on medical applications is needed to validate the approach in those areas, as the mammography case study in the preprint of this paper is not up to the same standard as the rest of the paper. The NeurIPS 2019 paper contains an additional study on a car model classification dataset; however, its results are very similar and are therefore not included in this summary post.

The authors have also made their code publicly available.

Presentation

Download the presentation here: https://drive.google.com/open?id=129rYAKJqGQqe2bLMFkfLF65M_8dk1_Fa

References

[1] D. Erhan, Y. Bengio, A. Courville, and P. Vincent. Visualizing Higher-Layer Features of a Deep Network. Technical Report 1341, the University of Montreal, June 2009. Also presented at the Workshop on Learning Feature Hierarchies at the 26th International Conference on Machine Learning (ICML 2009), Montreal, Canada.

[2] G. E. Hinton. A Practical Guide to Training Restricted Boltzmann Machines. In Neural Networks: Tricks of the Trade, pages 599–619. Springer, 2012.

[3] H. Lee, R. Grosse, R. Ranganath, and A. Y. Ng. Convolutional Deep Belief Networks for Scalable Unsupervised Learning of Hierarchical Representations. In Proceedings of the 26th International Conference on Machine Learning (ICML), pages 609–616, 2009.

[4] J. R. Uijlings, K. E. Van De Sande, T. Gevers, and A. W. Smeulders. Selective Search for Object Recognition. International Journal of Computer Vision, 104(2):154–171, 2013.

[5] A. Nguyen, A. Dosovitskiy, J. Yosinski, T. Brox, and J. Clune. Synthesizing the preferred inputs for neurons in neural networks via deep generator networks. In Advances in Neural Information Processing Systems 29 (NIPS), pages 3387–3395, 2016.

[6] K. Simonyan, A. Vedaldi, and A. Zisserman. Deep Inside Convolutional Networks: Visualising Image Classification Models and Saliency Maps. In Workshop at the 2nd International Conference on Learning Representations (ICLR Workshop), 2014.

[7] J. Yosinski, J. Clune, T. Fuchs, and H. Lipson. Understanding Neural Networks through Deep Visualization. In Deep Learning Workshop at the 32nd International Conference on Machine Learning (ICML), 2015.

[8] M. D. Zeiler and R. Fergus. Visualizing and Understanding Convolutional Networks. In Proceedings of the European Conference on Computer Vision (ECCV), pages 818–833, 2014.

[9] K. Simonyan, A. Vedaldi, and A. Zisserman. Deep Inside Convolutional Networks: Visualising Image Classification Models and Saliency Maps. In Workshop at the 2nd International Conference on Learning Representations (ICLR Workshop), 2014.

[10] M. Sundararajan, A. Taly, and Q. Yan. Axiomatic Attribution for Deep Networks. In Proceedings of the 34th International Conference on Machine Learning (ICML), volume 70 of Proceedings of Machine Learning Research, pages 3319–3328. PMLR, 2017.

[11] D. Smilkov, N. Thorat, B. Kim, F. Viégas, and M. Wattenberg. SmoothGrad: removing noise by adding noise. arXiv preprint arXiv:1706.03825, 2017.

[12] R. R. Selvaraju, M. Cogswell, A. Das, R. Vedantam, D. Parikh, and D. Batra. Grad-CAM: Visual Explanations from Deep Networks via Gradient-Based Localization. In Proceedings of the IEEE International Conference on Computer Vision (ICCV), Oct 2017.

[13] B. Zhou, A. Khosla, A. Lapedriza, A. Oliva, and A. Torralba. Learning Deep Features for Discriminative Localization. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition (CVPR), pages 2921–2929. IEEE, 2016.

[14] H. Zheng, J. Fu, T. Mei, and J. Luo. Learning Multi-Attention Convolutional Neural Network for Fine-Grained Image Recognition. In Proceedings of the IEEE International Conference on Computer Vision (ICCV), pages 5209–5217, 2017.

[15] N. Zhang, J. Donahue, R. Girshick, and T. Darrell. Part-based R-CNNs for Fine-grained Category Detection. In Proceedings of the European Conference on Computer Vision (ECCV), pages 834–849. Springer, 2014.

[16] S. Huang, Z. Xu, D. Tao, and Y. Zhang. Part-Stacked CNN for Fine-Grained Visual Categorization. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition (CVPR), pages 1173–1182, 2016.

[17] B. Kim, C. Rudin, and J. Shah. The Bayesian Case Model: A Generative Approach for Case-Based Reasoning and Prototype Classification. In Advances in Neural Information Processing Systems 27 (NIPS), pages 1952–1960, 2014.

[18] O. Li, H. Liu, C. Chen, and C. Rudin. Deep Learning for Case-Based Reasoning through Prototypes: A Neural Network that Explains Its Predictions. In Proceedings of the Thirty-Second AAAI Conference on Artificial Intelligence (AAAI), 2018.

[19] Y. Ming, P. Xu, H. Qu, and L. Ren. Interpretable and Steerable Sequence Learning via Prototypes. In Proceedings of the 25th ACM SIGKDD International Conference on Knowledge Discovery & Data Mining (KDD’19), pages 903–913. ACM, 2019.

[20] T.-Y. Lin, A. RoyChowdhury, and S. Maji. Bilinear CNN Models for Fine-grained Visual Recognition. In Proceedings of the IEEE International Conference on Computer Vision (ICCV), pages 1449–1457, 2015.

[21] J. Fu, H. Zheng, and T. Mei. Look Closer to See Better: Recurrent Attention Convolutional Neural Network for Fine-grained Image Recognition. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition (CVPR), pages 4438–4446, 2017.

[22] H. Zheng, J. Fu, T. Mei, and J. Luo. Learning Multi-Attention Convolutional Neural Network for Fine-Grained Image Recognition. In Proceedings of the IEEE International Conference on Computer Vision (ICCV), pages 5209–5217, 2017.

[23] J. Deng, W. Dong, R. Socher, L.-J. Li, K. Li, and L. Fei-Fei. ImageNet: A Large-Scale Hierarchical Image Database. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition (CVPR), pages 248–255. IEEE, 2009.

[24] K. Simonyan and A. Zisserman. Very Deep Convolutional Networks for Large-Scale Image Recognition. In Proceedings of the 3rd International Conference on Learning Representations (ICLR), 2015.

[25] G. Huang, Z. Liu, L. Van Der Maaten, and K. Q. Weinberger. Densely Connected Convolutional Networks. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition (CVPR), pages 4700–4708, 2017.

[26] K. He, X. Zhang, S. Ren, and J. Sun. Deep Residual Learning for Image Recognition. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition (CVPR), pages 770–778, 2016.

[27] N. Zhang, J. Donahue, R. Girshick, and T. Darrell. Part-based R-CNNs for Fine-grained Category Detection. In Proceedings of the European Conference on Computer Vision (ECCV), pages 834–849. Springer, 2014.

[28] B. Zhou, A. Khosla, A. Lapedriza, A. Oliva, and A. Torralba. Learning Deep Features for Discriminative Localization. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition (CVPR), pages 2921–2929. IEEE, 2016.

[29] D. Wang, Z. Shen, J. Shao, W. Zhang, X. Xue, and Z. Zhang. Multiple Granularity Descriptors for Fine-grained Categorization. In Proceedings of the IEEE International Conference on Computer Vision (ICCV), pages 2399–2406, 2015.

[30] K. Clark, B. Vendt, K. Smith, J. Freymann, J. Kirby, P. Koppel, S. Moore, S. Phillips, D. Maffitt, M. Pringle, et al. The Cancer Imaging Archive (TCIA): Maintaining and Operating a Public Information Repository. Journal of Digital Imaging, 26(6):1045–1057, 2013.

[31] R. S. Lee, F. Gimenez, A. Hoogi, and D. Rubin. Curated Breast Imaging Subset of DDSM. The Cancer Imaging Archive, 2016.

[32] R. S. Lee, F. Gimenez, A. Hoogi, K. K. Miyake, M. Gorovoy, and D. L. Rubin. A Curated Mammography Data Set for Use in Computer-Aided Detection and Diagnosis Research. Scientific Data, 4:170177, 2017.





