Crossposted from the AI Alignment Forum. May contain more technical jargon than usual.

This post is written as part of the SERI MATS program.

This post explores how model interpretability has changed over time, and whether the existing approaches to interpretability are likely to remain viable as models increase in size and capability. I examine several approaches to interpretability:

  • simulatability;
  • feature engineering;
  • model-based interpretability;
  • post-hoc interpretability.

Simulatability

Before the rise of machine learning, models were interpretable because they were simulatable: a human could reproduce the model’s decision-making process in their head and predict its behavior on new inputs. This is possible when the model is simple and the input is low-dimensional, as with a small decision tree, a linear regression, or a set of rules. Sometimes a simple algorithm turns out to solve the problem just as well as a more complicated one.

For example, in the 2018 Explainable Machine Learning Challenge, participants were asked to train a model that predicts credit risk scores and to explain it. Instead of training a black-box model and producing a post-hoc explanation, the winning team trained an interpretable two-layer additive model. The model was simple enough to understand, yet its accuracy was competitive.
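To make "two-layer additive model" concrete, here is a minimal sketch. It is not the winning team's model: the feature names, bin edges, point values and weights are all hypothetical. It assumes the first layer turns binned features into points that add up to one score per subscale, and the second layer combines those subscale scores with a logistic link. The point is that every step is addition a person can do by hand.

```python
import math

# Hypothetical first layer: each subscale adds up points for binned features.
# Feature names, bin edges and point values are invented for illustration.
SUBSCALES = {
    "delinquency": {
        "months_since_delinquency": [(0, 12, 2.0), (12, 36, 1.0), (36, math.inf, 0.0)],
        "num_late_payments": [(0, 1, 0.0), (1, 3, 1.0), (3, math.inf, 2.5)],
    },
    "utilization": {
        "revolving_utilization_pct": [(0, 30, 0.0), (30, 70, 1.0), (70, math.inf, 2.0)],
    },
}

# Hypothetical second layer: one weight per subscale plus an intercept.
SUBSCALE_WEIGHTS = {"delinquency": 0.8, "utilization": 0.6}
INTERCEPT = -2.0


def bin_points(value, bins):
    """Return the points of the bin [low, high) that contains value."""
    for low, high, points in bins:
        if low <= value < high:
            return points
    raise ValueError(f"value {value} falls outside all bins")


def subscale_scores(applicant):
    """First layer: an additive score for each subscale."""
    return {
        name: sum(bin_points(applicant[feat], bins) for feat, bins in features.items())
        for name, features in SUBSCALES.items()
    }


def default_probability(applicant):
    """Second layer: weighted sum of subscale scores passed through a logistic link."""
    scores = subscale_scores(applicant)
    logit = INTERCEPT + sum(SUBSCALE_WEIGHTS[name] * s for name, s in scores.items())
    return 1.0 / (1.0 + math.exp(-logit)), scores


applicant = {"months_since_delinquency": 6, "num_late_payments": 2, "revolving_utilization_pct": 80}
prob, scores = default_probability(applicant)
print(scores, round(prob, 3))
```

The example applicant lands in the 2-point and 1-point delinquency bins and the 2-point utilization bin, so anyone can check the subscale scores and the final logit without a computer.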

However, we cannot expect simple models to achieve high accuracy when the underlying relationship is complicated, and tracking a model's computations in one's head becomes impossible with high-dimensional inputs such as images. When we need more complicated models, simulatability stops working.

Feature engineering

The relationship between the target variable and the raw input data may be complicated. But expert domain knowledge can sometimes be used to construct useful and meaningful features and greatly simplify the relationship the model has to learn, making it easier to interpret.

For example, in [1] the authors build a model for cloud detection in Arctic satellite imagery. Using climate science knowledge and exploratory data analysis, they design three simple features that are sufficient to classify clouds with high accuracy using quadratic discriminant analysis (QDA).
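Once good features exist, the classifier itself can be very simple. QDA fits one Gaussian (mean and covariance) per class and assigns each point to the class with the highest posterior. The NumPy sketch below implements QDA from scratch; the synthetic three-feature data are hypothetical stand-ins, not the satellite features from [1].

```python
import numpy as np


def fit_qda(X, y):
    """Fit QDA: a mean vector, covariance matrix and prior for each class."""
    params = {}
    for label in np.unique(y):
        Xc = X[y == label]
        params[label] = (Xc.mean(axis=0), np.cov(Xc, rowvar=False), len(Xc) / len(X))
    return params


def predict_qda(params, X):
    """Pick the class with the largest Gaussian log-posterior (up to a shared constant)."""
    labels = list(params)
    scores = []
    for label in labels:
        mean, cov, prior = params[label]
        diff = X - mean
        mahalanobis = np.einsum("ij,jk,ik->i", diff, np.linalg.inv(cov), diff)
        _, logdet = np.linalg.slogdet(cov)
        scores.append(-0.5 * mahalanobis - 0.5 * logdet + np.log(prior))
    return np.array(labels)[np.argmax(np.stack(scores, axis=1), axis=1)]


# Hypothetical synthetic data standing in for three engineered features per pixel.
rng = np.random.default_rng(0)
cloudy = rng.normal(loc=[1.0, 1.0, 1.0], scale=0.3, size=(200, 3))
clear = rng.normal(loc=[-1.0, -1.0, -1.0], scale=0.5, size=(200, 3))
X = np.vstack([cloudy, clear])
y = np.array([1] * 200 + [0] * 200)

params = fit_qda(X, y)
accuracy = np.mean(predict_qda(params, X) == y)
print(f"training accuracy: {accuracy:.3f}")
```

With only a handful of parameters per class, the fitted model can be inspected directly: each class is summarized by a mean and a covariance over the three features.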

Feature engineering works equally well for any model size. However, it relies on human expert knowledge or engineering effort to construct features for every domain. The trend in machine learning [2] has been that general methods that leverage computation eventually outperform domain-specific feature engineering. As we approach AGI and expect our models to generalize to new domains and outperform human experts, this method will become less and less applicable, unless, of course, the task of finding understandable and useful features can be outsourced to the AI itself.

Model-based interpretability

This approach means training a model that is understandable by design.

For example, Chaofan Chen et al. (2018) [3] develop an image classification neural network that reasons about images in a way similar to how a person might. Their ProtoPNet model dissects training images to find prototypical parts of each class and classifies new images by comparing their parts with these prototypical parts. The model achieves 80% accuracy on the CUB-200-2011 dataset, while the state-of-the-art accuracy was 89% in 2018 (92% now).

Is there a tradeoff between interpretability and accuracy? Sometimes interpretability helps researchers correct mistakes and thereby improve accuracy. For example, in a 2015 study [4], researchers aimed to predict the mortality risk of 14,199 pneumonia patients from data about their demographics, heart rate, blood pressure and laboratory tests. They used a generalized additive model with pairwise interactions (GA2M), in which a link function g of the expected outcome is a sum of single-feature and two-feature terms:

$$g(\mathbb{E}[y]) = \beta_0 + \sum_j f_j(x_j) + \sum_{i \neq j} f_{ij}(x_i, x_j)$$

With a model like this, the dependence of the prediction on each individual feature or pair of features can easily be visualized as a graph or a heatmap, respectively. Studying these visualizations helped the researchers discover a number of counterintuitive properties of the model. For instance, the model predicted that having asthma lowers the risk of dying from pneumonia. In reality, asthma increases that risk. Because this was known, hospitals gave all asthma patients aggressive care, which improved their survival relative to the general population, and the model learned this lower observed mortality as if it were a property of asthma itself.

The model was meant to be used for determining hospitalization priority. Deprioritizing asthma patients would be a mistake. Using an interpretable model allowed the researchers to discover and fix errors like this one.
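The sketch below shows why such a model is easy to audit. It makes GA2M-style predictions and returns every additive term alongside the risk. The intercept, shape functions and feature names are hypothetical (real shape functions are learned from data); the asthma term is made negative only to mirror the kind of counterintuitive effect described above, so it stands out when the terms are listed.

```python
import math

# Hypothetical shape functions, in log-odds. Real GA2M terms are learned from
# data; these numbers are invented only to show how the pieces fit together.
INTERCEPT = -3.0
MAIN_EFFECTS = {
    "age": lambda age: 0.03 * (age - 50),
    "has_asthma": lambda asthma: -0.4 if asthma else 0.0,  # the counterintuitive term
    "heart_rate": lambda hr: 0.5 if hr > 120 else 0.0,
}
PAIR_EFFECTS = {
    ("age", "heart_rate"): lambda age, hr: 0.3 if age > 75 and hr > 120 else 0.0,
}


def explain(patient):
    """Return predicted mortality risk together with every additive term."""
    terms = {name: f(patient[name]) for name, f in MAIN_EFFECTS.items()}
    for (a, b), f in PAIR_EFFECTS.items():
        terms[f"{a} x {b}"] = f(patient[a], patient[b])
    logit = INTERCEPT + sum(terms.values())
    return 1.0 / (1.0 + math.exp(-logit)), terms


risk, terms = explain({"age": 80, "has_asthma": True, "heart_rate": 130})
for name, value in terms.items():
    print(f"{name:>16}: {value:+.2f}")
print(f"risk: {risk:.3f}")
```

Because the prediction is just the intercept plus these terms, a reviewer can spot a negative asthma contribution immediately, without any separate explanation method.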

However, the top models on ML leaderboards are usually not inherently interpretable, suggesting there is a tradeoff. This tradeoff could exist for two reasons:

  • more interpretable models with comparable accuracy don’t exist;
  • more interpretable models exist, but we don’t know how to find them.

There is a technical argument, the Rashomon set argument, for why more interpretable models with comparable accuracy should exist [5]. Given a prediction problem, the Rashomon set is the set of predictive models whose accuracy is close to optimal. If this set is large, it is likely to contain some simple models.
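For a loss L, a model class F and a best model f*, the Rashomon set can be written as R(ε) = {f ∈ F : L(f) ≤ L(f*) + ε}. The toy sketch below makes this concrete with a finite class of decision stumps on a hypothetical six-point dataset: it computes each stump's error and returns every stump within ε of the best. It is only an illustration of the definition, not the analysis in [5].

```python
from itertools import product

# Hypothetical toy dataset: two features, binary label.
X = [(0.1, 0.9), (0.3, 0.8), (0.4, 0.2), (0.6, 0.7), (0.7, 0.1), (0.9, 0.3)]
y = [0, 0, 0, 1, 1, 1]


def stump(feature, threshold, sign):
    """Predict 1 when sign * (x[feature] - threshold) > 0, else 0."""
    return lambda x: int(sign * (x[feature] - threshold) > 0)


def error(model):
    """Empirical 0-1 loss on the toy dataset."""
    return sum(model(x) != t for x, t in zip(X, y)) / len(y)


thresholds = [i / 10 for i in range(11)]
losses = {
    (f, t, s): error(stump(f, t, s))
    for f, t, s in product([0, 1], thresholds, [1, -1])
}
best = min(losses.values())


def rashomon_set(epsilon):
    """All stumps whose loss is within epsilon of the best stump in the class."""
    return [key for key, loss in losses.items() if loss <= best + epsilon]


print(len(losses), "stumps; best error", best)
print("epsilon=0:", rashomon_set(0))
print("epsilon=1/6:", len(rashomon_set(1 / 6)), "stumps")
```

Even in this tiny class, several distinct models are exactly optimal, and loosening ε admits more of them; the argument in the text is that in realistic model classes this set can be large enough to contain simple models.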

What reason is there to believe this set is large? D’Amour et al. [6] conducted a large-scale study of this effect, which they call underspecification, in production-scale machine learning systems. They examined several image classifiers, natural language processing systems, and clinical risk prediction systems. They developed several “stress tests” to measure desirable properties of model behavior that the evaluation metric did not capture, such as fairness to different subpopulations, out-of-distribution generalization, and stability of the model’s predictions under irrelevant perturbations of the input. They then trained several versions of each model that differed only in the random weight initialization at the start of training. As expected, these models performed nearly identically on the test set. But their performance on the stress tests varied widely, sometimes by an order of magnitude more than their test-set performance did. The experiments showed that performance on a secondary objective is not determined by the training set and model architecture, which suggests it can be improved by optimizing for it without sacrificing accuracy on the primary objective.
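A toy version of this experiment shows how underspecification can arise. In the hypothetical setup below, the second input is an exact copy of the first during training, so gradient descent updates both weights identically and leaves the difference w1 - w2 wherever the random initialization put it. Every seed reaches perfect accuracy on an i.i.d. test set, but on a stress test where the copied feature becomes anti-correlated, accuracy depends entirely on the seed. This is a deliberately minimal illustration, not the setup used in [6].

```python
import math
import random

# Hypothetical toy data: label is 1 when x > 0. Feature 2 is an exact copy of
# feature 1 during training, so the data cannot say how to split weight between them.
xs = [-1 + 2 * i / 100 for i in range(101)]
labels = [1 if x > 0 else 0 for x in xs]


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def train(seed, steps=300, lr=1.0):
    """Logistic regression on (x, x) by full-batch gradient descent; only the init depends on seed."""
    rng = random.Random(seed)
    w1, w2 = rng.gauss(0, 1), rng.gauss(0, 1)
    for _ in range(steps):
        # Both gradients are identical because both features are identical.
        g = sum((sigmoid(w1 * x + w2 * x) - t) * x for x, t in zip(xs, labels)) / len(xs)
        w1, w2 = w1 - lr * g, w2 - lr * g
    return w1, w2


def accuracy(w1, w2, second_feature):
    """Accuracy when feature 2 is computed as second_feature(x)."""
    preds = [int(sigmoid(w1 * x + w2 * second_feature(x)) > 0.5) for x in xs]
    return sum(p == t for p, t in zip(preds, labels)) / len(xs)


for seed in range(5):
    w1, w2 = train(seed)
    iid = accuracy(w1, w2, lambda x: x)      # test data match the training distribution
    stress = accuracy(w1, w2, lambda x: -x)  # stress test: the copy now anti-correlates
    print(f"seed={seed} w1-w2={w1 - w2:+.2f} iid={iid:.2f} stress={stress:.2f}")
```

The stress-test accuracy is near 1 when the initialization happened to favor w1 and near 0 otherwise, even though nothing in the training data or the i.i.d. test set distinguishes the two cases.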

Of course, this argument only shows that we can somewhat improve an interpretability measure without losing accuracy. It could be that actually making the model understandable requires improving that measure much further, in which case there is an accuracy-interpretability tradeoff after all.

Zijun Zhang et al. [7] ran a neural architecture search that optimized simultaneously for accuracy and for knowledge similarity between the learned model weights and expert prior knowledge in genomic sequence analysis. They found that with an appropriate knowledge weight in the loss function (between 0.001 and 1), the trained model retains its predictive accuracy while becoming dramatically more interpretable. With smaller weights, the models are less interpretable; with weights above 1, accuracy drops.
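The knowledge-similarity term in [7] is specific to genomic sequence models, so the sketch below swaps in a much simpler, hypothetical analogue: a linear model whose loss is mean squared error plus knowledge_weight times the squared distance between the learned weights and an expert prior. It is not Zhang et al.'s method, but it shows the same qualitative knob: a larger knowledge weight pulls the model toward the prior at the cost of fitting the data.

```python
import numpy as np


def fit_with_prior(X, y, w_prior, knowledge_weight):
    """Minimize mean squared error + knowledge_weight * ||w - w_prior||^2 in closed form."""
    n, d = X.shape
    A = X.T @ X / n + knowledge_weight * np.eye(d)
    b = X.T @ y / n + knowledge_weight * w_prior
    return np.linalg.solve(A, b)


rng = np.random.default_rng(0)
X = rng.normal(size=(200, 5))
w_true = np.array([2.0, -1.0, 0.0, 0.5, 0.0])
y = X @ w_true + rng.normal(scale=0.5, size=200)
# Hypothetical expert prior: right about which inputs matter, off about magnitudes.
w_prior = np.array([1.0, -1.0, 0.0, 1.0, 0.0])

results = []
for weight in [0.0, 0.001, 0.1, 1.0, 100.0]:
    w = fit_with_prior(X, y, w_prior, weight)
    mse = float(np.mean((X @ w - y) ** 2))
    distance = float(np.linalg.norm(w - w_prior))
    results.append((weight, mse, distance))
    print(f"knowledge_weight={weight:g}: train MSE={mse:.3f}, distance to prior={distance:.3f}")
```

For this kind of penalized objective, raising the weight can only move the solution closer to the prior and can only worsen the data fit, so the interesting question, as in [7], is where the range of "free" agreement with the prior ends.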

Modularity

One way to make an architecture more interpretable is to make it modular. The generalized additive model discussed above is an example of a modular architecture. Existing neural networks seem to already have some modular structure. D. Filan et al. [8] find that trained (and pruned) neural networks are more clusterable, that is, more modular, than randomly initialized ones. R. Csordás et al. [9] find that it is possible to identify modules in neural networks, and that networks use different modules to perform different functions. However, networks often fail to reuse a module when the same function is needed again, and may end up learning the same function multiple times.
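As I understand [8], clusterability is measured by treating neurons as nodes of a graph weighted by weight magnitudes, partitioning it with spectral clustering, and scoring the partition with its normalized cut (n-cut), where lower means more modular. The sketch below computes only the n-cut of a given partition, on two hypothetical four-neuron graphs; converting a real network into a graph and the spectral clustering step are omitted.

```python
import numpy as np


def normalized_cut(adjacency, labels):
    """Sum over clusters of cut(cluster, rest) / vol(cluster).

    adjacency: symmetric matrix of non-negative edge weights (e.g. |weights| between neurons).
    labels: cluster index for each node. Lower values mean a cleaner split.
    """
    adjacency = np.asarray(adjacency, dtype=float)
    labels = np.asarray(labels)
    degrees = adjacency.sum(axis=1)
    total = 0.0
    for cluster in np.unique(labels):
        inside = labels == cluster
        cut = adjacency[np.ix_(inside, ~inside)].sum()
        total += cut / degrees[inside].sum()
    return total


labels = [0, 0, 1, 1]
# Hypothetical four-neuron graphs: strong within-pair edges versus uniform edges.
modular = np.array([[0, 1, 0.1, 0.1], [1, 0, 0.1, 0.1], [0.1, 0.1, 0, 1], [0.1, 0.1, 1, 0]])
uniform = np.ones((4, 4)) - np.eye(4)
print(normalized_cut(modular, labels), normalized_cut(uniform, labels))
```

The modular graph scores 1/3 and the uniform one 4/3 under the same partition, which is the kind of gap that "more clusterable than random" refers to.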

Studying convolutional neural network filters, the OpenAI team discovered that image classification networks exhibit branch specialization. They found that networks consistently learn two groups of features, black-and-white Gabor filters and low-frequency color detectors, which are strongly connected within each group and weakly connected between groups.

When they trained AlexNet, a network whose architecture has two branches by design, the Gabor filters ended up in one branch and the color detectors in the other. I think this suggests that neural networks already have some modularity in them, and that it should be possible to enforce modularity without losing much accuracy.

How should we expect these methods to fare in the future? As model sizes grow, models become more underdetermined, so we should have more degrees of freedom to find interpretable architectures. On the other hand, the search space becomes bigger, so it is not clear whether the search results will improve. Domain expertise may be needed to define the interpretability loss function in each domain: for example, sparsity often improves the understandability of models working with tabular data, but an image classification model will not become more understandable if it only looks at a few pixels. I expect that further research will improve our ability to find interpretable architectures.

A recent study by Vincent Margot and George Luta [10] can be seen as a test of this expectation. They measure and compare the interpretability of several rule-based algorithms, combining measures of predictivity, stability and simplicity into a single interpretability score. They compare Regression Tree (RT, 1984), RuleFit (2008), NodeHarvest (2010), Covering Algorithm (CA, 2021) and SIRUS (2021) on regression tasks, and RIPPER (1995), PART (1998) and Classification Tree (CART, 1984) on classification tasks. All models have the same accuracy on the test sets.

For regression tasks, the recent SIRUS algorithm comes out as the winner.

For classification tasks, the most recent PART algorithm is dominated by CART and RIPPER.

So, while there is no clear trend of interpretability rising or dropping over time, some recently developed models, like SIRUS, are also more interpretable. 
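To show what "combining three measures into one score" can look like, here is a hypothetical aggregation: a weighted average of predictivity, stability and simplicity, assuming each has already been scaled to [0, 1]. The exact normalization and weights in [10] may differ, and the two algorithms and their numbers below are made up.

```python
def interpretability_score(predictivity, stability, simplicity, weights=(1 / 3, 1 / 3, 1 / 3)):
    """Hypothetical aggregate: a weighted average of three measures already scaled to [0, 1]."""
    measures = (predictivity, stability, simplicity)
    if not all(0.0 <= m <= 1.0 for m in measures):
        raise ValueError("each measure must lie in [0, 1]")
    if abs(sum(weights) - 1.0) > 1e-9:
        raise ValueError("weights must sum to 1")
    return sum(w * m for w, m in zip(weights, measures))


# Hypothetical measurements for two made-up algorithms with equal predictivity.
candidates = {"algorithm_a": (0.80, 0.90, 0.70), "algorithm_b": (0.80, 0.60, 0.40)}
ranking = sorted(candidates, key=lambda name: interpretability_score(*candidates[name]), reverse=True)
print(ranking)
```

When predictivity is equal across algorithms, as in the comparison above, the ranking is decided entirely by stability and simplicity.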

Post-hoc interpretability

Finally, there is the approach of taking any trained neural network and trying to understand and reverse-engineer its workings. Some post-hoc approaches (e.g. LIME, PDP, ICE, ALE) are model-agnostic: they treat the model as a black box and try to approximate it with a simpler model or find patterns in its predictions. I don’t believe these are promising for solving the mesa-optimization problem, because we cannot judge whether a model is an optimizer from its behavior alone: any behavior can be produced by a giant lookup table. We have to look inside the model.
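To make "treat the model as a black box" concrete, here is a minimal sketch of two of the methods named above. An ICE curve varies one feature of a single input over a grid while holding the others fixed; a partial dependence plot (PDP) averages those curves. The model is assumed to be any Python callable on a list of feature values, and black_box is a hypothetical stand-in. Note that nothing here looks at the model's internals.

```python
def ice_curves(model, rows, feature, grid):
    """Individual conditional expectation: vary one feature per row, keep the rest fixed."""
    curves = []
    for row in rows:
        curve = []
        for value in grid:
            modified = list(row)
            modified[feature] = value
            curve.append(model(modified))
        curves.append(curve)
    return curves


def partial_dependence(model, rows, feature, grid):
    """PDP: the average of the ICE curves at each grid value."""
    curves = ice_curves(model, rows, feature, grid)
    return [sum(curve[i] for curve in curves) / len(curves) for i in range(len(grid))]


def black_box(x):
    """Hypothetical model; the methods above only ever call it."""
    return 2 * x[0] + x[1] ** 2


rows = [(0.0, 1.0), (1.0, 2.0), (2.0, 3.0)]
grid = [0.0, 1.0, 2.0]
print(ice_curves(black_box, rows, 0, grid))
print(partial_dependence(black_box, rows, 0, grid))
```

For this additive black box the PDP of feature 0 is 2v plus the average of the second term over the data, which is exactly the kind of behavioral summary that cannot reveal how the model computes it.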

The OpenAI team succeeded in looking inside CNN models and identifying the functions of several kinds of filters: Gabor filters, high-low frequency detectors, curve detectors, and color change detectors. How could we quantify this kind of interpretability?

David Bau et al. [11] use Broden, a dataset in which each object in a picture is segmented and labeled, to measure the correspondence between individual units (filters) and human concepts. They find that interpretability measured this way ranks ResNet > VGG > GoogLeNet > AlexNet, suggesting that deeper architectures allow greater interpretability.
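Following my reading of [11], the core comparison works like this: threshold a unit's activations at a dataset-wide top quantile (0.5% by default below), then compute the intersection-over-union (IoU) between the resulting region and a concept's segmentation masks. A unit counts as a detector for a concept when this IoU exceeds a small threshold. The sketch assumes activations are already upsampled to mask resolution; the synthetic data are hypothetical.

```python
import numpy as np


def dissect_unit(activations, concept_masks, top_quantile=0.005):
    """IoU between a unit's top-quantile activation region and a concept's masks.

    activations: (num_images, H, W) unit activations, upsampled to mask resolution.
    concept_masks: boolean array of the same shape marking pixels with the concept.
    """
    threshold = np.quantile(activations, 1.0 - top_quantile)  # one threshold for the whole dataset
    active = activations > threshold
    intersection = np.logical_and(active, concept_masks).sum()
    union = np.logical_or(active, concept_masks).sum()
    return intersection / union if union else 0.0


# Hypothetical unit that fires strongly on exactly the pixels labeled with a concept.
rng = np.random.default_rng(0)
acts = rng.uniform(0.0, 0.1, size=(10, 10, 10))
mask = np.zeros_like(acts, dtype=bool)
mask[0, 0, :5] = True
acts[mask] += 1.0

other = np.zeros_like(mask)
other[5, 5, :5] = True
print(dissect_unit(acts, mask), dissect_unit(acts, other))
```

Counting how many units pass the IoU threshold for some concept is what allows whole architectures to be ranked against each other.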

Iro Laina et al. [12] measure the mutual information between human concepts and learned representations in self-supervised image models. Again, there appears to be a correlation between normalized mutual information (NMI) and top-1 accuracy.
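The measure itself is simple once representations have been quantized into clusters (for example with k-means) and each sample carries a human concept label. The sketch below computes NMI between two labelings in plain Python. It assumes the arithmetic-mean normalization, 2 I(C; Y) / (H(C) + H(Y)), which may differ from the paper's choice; the clustering step and concept labels are omitted.

```python
import math
from collections import Counter


def entropy(labels):
    n = len(labels)
    return -sum(c / n * math.log(c / n) for c in Counter(labels).values())


def mutual_information(a, b):
    n = len(a)
    joint = Counter(zip(a, b))
    count_a, count_b = Counter(a), Counter(b)
    return sum(c / n * math.log(c * n / (count_a[x] * count_b[y])) for (x, y), c in joint.items())


def normalized_mutual_information(clusters, concepts):
    """NMI = 2 I(C; Y) / (H(C) + H(Y)); defined as 1.0 when both labelings are constant."""
    denom = entropy(clusters) + entropy(concepts)
    return 2 * mutual_information(clusters, concepts) / denom if denom else 1.0


print(normalized_mutual_information([0, 0, 1, 1], ["bird", "bird", "dog", "dog"]))
print(normalized_mutual_information([0, 0, 1, 1], ["bird", "dog", "bird", "dog"]))
```

Clusters that line up with concepts score 1 and clusters independent of them score 0, so a higher NMI means the representation organizes the data along human-recognizable lines.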

So the research I could find supports the hypothesis that more accurate models tend to be more interpretable by these post-hoc measures, not less.

[1] Shi, Tao, et al. "Daytime arctic cloud detection based on multi-angle satellite data with case studies." Journal of the American Statistical Association 103.482 (2008): 584-593.

[2] Sutton, Richard. "The Bitter Lesson." (2019). http://www.incompleteideas.net/IncIdeas/BitterLesson.html

[3] Chen, Chaofan, et al. "This looks like that: deep learning for interpretable image recognition." Advances in neural information processing systems 32 (2019).

[4] Caruana R., et al., “Intelligible models for healthcare: Predicting pneumonia risk and hospital 30-day readmission” in Proceedings of the 21th ACM SIGKDD International Conference on Knowledge Discovery and Data Mining, Cao L., Zhang C., Eds. (ACM, New York, NY, 2015), pp. 1721–1730.

[5] Rudin, C. Stop explaining black box machine learning models for high stakes decisions and use interpretable models instead. Nat Mach Intell 1, 206–215 (2019). https://doi.org/10.1038/s42256-019-0048-x

[6] D'Amour, Alexander, et al. "Underspecification presents challenges for credibility in modern machine learning." arXiv preprint arXiv:2011.03395 (2020).

[7] Zhang, Z. (2019). Neural Architecture Search for Biological Sequences. UCLA. ProQuest ID: Zhang_ucla_0031N_18137. Merritt ID: ark:/13030/m5mw7mfv.

[8] Filan, Daniel, et al. "Pruned neural networks are surprisingly modular." arXiv preprint arXiv:2003.04881 (2020).

[9] Csordás, Róbert, Sjoerd van Steenkiste, and Jürgen Schmidhuber. "Are neural nets modular? inspecting functional modularity through differentiable weight masks." arXiv preprint arXiv:2010.02066 (2020).

[10] Margot, Vincent, and George Luta. "A new method to compare the interpretability of rule-based algorithms." AI 2.4 (2021): 621-635.

[11] Bau, David, Bolei Zhou, Aditya Khosla, Aude Oliva, and Antonio Torralba. "Network Dissection: Quantifying Interpretability of Deep Visual Representations." CVPR 2017: 3319-3327.

[12] Laina, Iro, Yuki M. Asano, and Andrea Vedaldi. "Measuring the interpretability of unsupervised representations via quantized reversed probing." (2022).


 

 


 
