This is week 3 of Quintin's Alignment Papers Roundup. This week, I'm focusing on papers that use interpretability to guide a neural network's training process. A lot of alignment proposals seem to involve a step like this.
Explanatory interactive learning (XIL) is an active research area that uses various methods to generate some form of explanation for a model's decisions (often a heatmap of the most important portions of the input). XIL then optimizes the model either to make its explanations match human explanations or to make them satisfy generic priors about what good explanations should look like (e.g., that they should be sparse).
Alignment proposals that use interpretability to steer model training usually imagine they have access to mechanistic interpretability methods that track a model's internal computations. Current XIL methods rely on easier, ad-hoc explanation methods, such as input saliency maps. Such methods can scale to supervise a full training process, unlike current mechanistic interpretability.
I expect many alignment researchers to think that input saliency methods are insufficient to properly supervise an AI's training process. Even if that's true, I think studying current XIL methods is valuable, simply because we can actually do empirical experiments with them. Even if there are no directly transferable insights we can gain from current XIL methods (unlikely, IMO), we can still learn about the "logistics" of doing general XIL research, such as the best ways to quantify how our interventions changed the trained models, what sort of protocols help humans to scalably use oversight tools, etc.
Additionally, current ad-hoc explanation methods are (very) imperfect, even for the more limited form of explanations they aim to provide. I expect that any future mechanistic interpretability methods that do scale to steering training processes will also be imperfect. Current XIL methods offer an empirical testbed to learn to wield imperfect and exploitable interpretability methods to shape a model's learning process.
Neural networks are among the most accurate supervised learning methods in use today, but their opacity makes them difficult to trust in critical applications, especially when conditions in training differ from those in test. Recent work on explanations for black-box models has produced tools (e.g. LIME) to show the implicit rules behind predictions, which can help us identify when models are right for the wrong reasons. However, these methods do not scale to explaining entire datasets and cannot correct the problems they reveal. We introduce a method for efficiently explaining and regularizing differentiable models by examining and selectively penalizing their input gradients, which provide a normal to the decision boundary. We apply these penalties both based on expert annotation and in an unsupervised fashion that encourages diverse models with qualitatively different decision boundaries for the same classification problem. On multiple datasets, we show our approach generates faithful explanations and models that generalize much better when conditions differ between training and test.
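To make the input-gradient penalty concrete, here's a minimal NumPy sketch for binary logistic regression, where the input gradient has a closed form. It follows the general shape of the penalty the abstract describes (cross-entropy plus squared input gradients on features a human annotated as irrelevant). However, `rrr_loss`, the mask convention, and the `lam` value are my own illustrative choices, not the paper's code.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def rrr_loss(w, b, X, y, A, lam=10.0, eps=1e-12):
    """Cross-entropy plus an input-gradient penalty for logistic regression.

    X: (n, d) inputs; y: (n,) labels in {0, 1}.
    A: (n, d) binary annotation mask, 1 where a human says a feature
       should NOT matter for the prediction (hypothetical convention).
    lam: penalty weight (illustrative value; it would need tuning).
    """
    p = sigmoid(X @ w + b)
    ce = -np.mean(y * np.log(p + eps) + (1 - y) * np.log(1 - p + eps))
    # Input gradient of the summed log-probabilities log p(1|x) + log p(0|x).
    # With z = w.x + b: d/dz [log s(z) + log(1 - s(z))] = 1 - 2 s(z),
    # so the gradient with respect to x is (1 - 2p) * w for each example.
    input_grads = (1 - 2 * p)[:, None] * w[None, :]
    # Penalize gradient mass that falls on annotated ("irrelevant") features.
    penalty = np.sum((A * input_grads) ** 2)
    return ce + lam * penalty
```

The penalty is zero whenever the model puts no weight on annotated features, which is the "right reasons" condition. One quirk of the binary case: the gradient of the summed log-probabilities is exactly zero at p = 0.5, so points on the decision boundary contribute nothing to the penalty.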
This is the first paper I know of that uses human saliency annotations to improve model training. I'm disappointed that they use LIME to validate the faithfulness of their saliency method, rather than approaches I think are more robust, such as the deletion-based measure used here.
It's also interesting that their approach for finding diverse models is so similar to the approach, independently discovered here and here, of minimizing the similarity between the input gradients of multiple models.
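A generic version of the gradient-similarity idea penalizes the squared cosine similarity between two models' per-example input gradients, so that training the models jointly pushes them toward different decision boundaries. The sketch below assumes logistic regression models; the function names are hypothetical, and this isn't the exact loss from any of the linked papers.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def input_grad_logprob(w, b, X):
    # Gradient of log p(y=1|x) with respect to x for logistic regression: (1 - p) * w.
    p = sigmoid(X @ w + b)
    return (1 - p)[:, None] * w[None, :]


def gradient_similarity_penalty(w1, b1, w2, b2, X, eps=1e-12):
    """Mean squared cosine similarity between two models' input gradients."""
    g1 = input_grad_logprob(w1, b1, X)
    g2 = input_grad_logprob(w2, b2, X)
    cos = np.sum(g1 * g2, axis=1) / (
        np.linalg.norm(g1, axis=1) * np.linalg.norm(g2, axis=1) + eps)
    return np.mean(cos ** 2)
```

For linear models every per-example gradient is a positive multiple of `w`, so this penalty collapses to the squared cosine between `w1` and `w2`. It only becomes interesting for nonlinear models, whose gradients vary across inputs.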
Existing Visual Question Answering (VQA) methods tend to exploit dataset biases and spurious statistical correlations, instead of producing right answers for the right reasons. To address this issue, recent bias mitigation methods for VQA propose to incorporate visual cues (e.g., human attention maps) to better ground the VQA models, showcasing impressive gains. However, we show that the performance improvements are not a result of improved visual grounding, but a regularization effect which prevents over-fitting to linguistic priors. For instance, we find that it is not actually necessary to provide proper, human-based cues; random, insensible cues also result in similar improvements. Based on this observation, we propose a simpler regularization scheme that does not require any external annotations and yet achieves near state-of-the-art performance on VQA-CPv2.
This is the obligatory "empirically discovered improvements to neural net training may not work for the reason you initially assumed" paper.
My guess is that it's beneficial to encourage neural nets to have sparse dependencies on their inputs, even without priors about which particular dependencies are best.
Edit: my current best guess as to why random saliency labels work is that they regularize the norm of the gradient of the model's outputs with respect to its inputs, leading to smoother model behavior.
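One way to see why random labels could act as a gradient-norm regularizer: suppose (my assumption, not how the VQA grounding methods above are actually implemented) that random labels enter training through an input-gradient penalty like the one from the first paper in this roundup, with each mask entry drawn i.i.d. as $A_d \sim \mathrm{Bernoulli}(q)$ and $g = \nabla_x \ell(x)$ the loss's input gradient. Because $A_d \in \{0, 1\}$, we have $A_d^2 = A_d$, so in expectation the penalty is a scaled squared gradient norm:

```latex
\mathbb{E}_{A}\left[\sum_{d=1}^{D} \left(A_d\, g_d\right)^2\right]
  = \sum_{d=1}^{D} \mathbb{E}\left[A_d^2\right] g_d^2
  = \sum_{d=1}^{D} q\, g_d^2
  = q \left\lVert \nabla_x \ell(x) \right\rVert_2^2
```

Under these assumptions, a random mask penalizes every input dimension equally on average, which is exactly an input-gradient-norm penalty scaled by $q$.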
Many past works aim to improve visual reasoning in models by supervising feature importance (estimated by model explanation techniques) with human annotations such as highlights of important image regions. However, recent work has shown that performance gains from feature importance (FI) supervision for Visual Question Answering (VQA) tasks persist even with random supervision, suggesting that these methods do not meaningfully align model FI with human FI. In this paper, we show that model FI supervision can meaningfully improve VQA model accuracy as well as performance on several Right-for-the-Right-Reason (RRR) metrics by optimizing for four key model objectives: (1) accurate predictions given limited but sufficient information (Sufficiency); (2) max-entropy predictions given no important information (Uncertainty); (3) invariance of predictions to changes in unimportant features (Invariance); and (4) alignment between model FI explanations and human FI explanations (Plausibility). Our best performing method, Visual Feature Importance Supervision (VisFIS), outperforms strong baselines on benchmark VQA datasets in terms of both in-distribution and out-of-distribution accuracy. While past work suggests that the mechanism for improved accuracy is through improved explanation plausibility, we show that this relationship depends crucially on explanation faithfulness (whether explanations truly represent the model's internal reasoning). Predictions are more accurate when explanations are plausible and faithful, and not when they are plausible but not faithful. Lastly, we show that, surprisingly, RRR metrics are not predictive of out-of-distribution model accuracy when controlling for a model's in-distribution accuracy, which calls into question the value of these metrics for evaluating model reasoning. All supporting code is publicly available.
This paper shows that human saliency annotations can help trained models more than random saliency annotations do, provided the saliency method actually reflects the model's decision-making process. This underlines the importance of faithful saliency methods, which is something current saliency methods are pretty hit-or-miss at.
This paper is also interesting in that they actively optimize the model so its decisions better conform to its saliency maps, a practice I've found to be common among the best-performing XIL methods.
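Here's a rough NumPy sketch of the four objectives (sufficiency, uncertainty, invariance, plausibility) for a single example and a generic `predict` function that returns class probabilities. The loss forms are my own simplified stand-ins: "removing" features is approximated by zeroing them, invariance uses Gaussian noise on the unimportant features, and plausibility is a cosine distance between importance vectors. The paper's actual implementation may differ on all of these, and the function names are hypothetical.

```python
import numpy as np


def softmax(z):
    z = z - np.max(z, axis=-1, keepdims=True)
    e = np.exp(z)
    return e / np.sum(e, axis=-1, keepdims=True)


def visfis_style_losses(predict, x, y, keep, rng, noise_scale=1.0, eps=1e-12):
    """Sufficiency, uncertainty, and invariance terms for one example.

    predict: maps a (d,) input to (k,) class probabilities.
    y: integer class label.
    keep: (d,) binary mask, 1 on features annotated as important.
    Feature removal is approximated by zeroing (an assumption).
    """
    # Sufficiency: the important features alone should yield the right answer.
    sufficiency = -np.log(predict(x * keep)[y] + eps)
    # Uncertainty: with the important features removed, the prediction should
    # be near uniform. Cross-entropy to the uniform distribution reaches its
    # minimum (log k) exactly when the prediction is uniform.
    p_removed = predict(x * (1 - keep))
    uncertainty = -np.mean(np.log(p_removed + eps))
    # Invariance: noising only the unimportant features shouldn't change the answer.
    noise = rng.normal(scale=noise_scale, size=x.shape)
    invariance = -np.log(predict(x + noise * (1 - keep))[y] + eps)
    return sufficiency, uncertainty, invariance


def plausibility_loss(model_fi, human_fi, eps=1e-12):
    """One minus the cosine similarity between model and human importance vectors."""
    cos = model_fi @ human_fi / (
        np.linalg.norm(model_fi) * np.linalg.norm(human_fi) + eps)
    return 1.0 - cos
```

In training, these terms would be added to the ordinary task loss with tuned weights. Note that only the plausibility term looks at the model's own explanation; the other three constrain the model's behavior directly, which is one way of making a model conform to a saliency map.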
Deep reinforcement learning policies, despite their outstanding efficiency in simulated visual control tasks, have shown disappointing ability to generalize across disturbances in the input training images. Changes in image statistics or distracting background elements are pitfalls that prevent generalization and real-world applicability of such control policies. We elaborate on the intuition that a good visual policy should be able to identify which pixels are important for its decision, and preserve this identification of important sources of information across images. This implies that training of a policy with small generalization gap should focus on such important pixels and ignore the others. This leads to the introduction of saliency-guided Q-networks (SGQN), a generic method for visual reinforcement learning, that is compatible with any value function learning method. SGQN vastly improves the generalization capability of Soft Actor-Critic agents and outperforms existing state-of-the-art methods on the Deepmind Control Generalization benchmark, setting a new reference in terms of training efficiency, generalization gap, and policy interpretability.
This paper reports shockingly large gains in generalization and in robustness to out-of-distribution perturbations.
It doesn't rely on human labels to identify important features. It trains the value function estimator to depend only on high-salience pixels and trains the network to predict its own saliency scores. These two regularizers apparently lead to much sparser saliency maps that match human priors for what's important in the task and vastly improved generality / robustness.
The improvements reported by this paper are so strong that I suspect some sort of confounder is at play. If not, this paper probably represents a significant advance in the state of the art for robust RL.
Both regularizers seem like they should increase the coupling between the saliency maps and the model's behaviors. This leads to sparser, more human-like saliency maps, despite not explicitly using human supervision of the saliency. The authors think this means the resulting models are more interpretable. Hopefully, the models actually depend on their saliency maps, such that we can supervise the training process by intervening on the saliency maps.
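The two SGQN-style regularizers can be sketched for a toy linear "Q-function" over a flattened observation. Assumptions (mine, not necessarily the paper's): saliency is |gradient × input|, the mask keeps pixels at or above a quantile threshold, consistency is a squared error between masked and unmasked Q-values, and an auxiliary head's logits are trained with binary cross-entropy to predict the mask. `sgqn_style_losses` is a hypothetical name.

```python
import numpy as np


def saliency_mask(saliency, quantile=0.9):
    # Binary mask keeping pixels whose saliency is at or above the given quantile.
    threshold = np.quantile(saliency, quantile)
    return (saliency >= threshold).astype(float)


def sgqn_style_losses(w, obs, predicted_mask_logits, quantile=0.9, eps=1e-12):
    """Consistency and self-prediction terms for a toy linear Q-function."""
    q = w @ obs  # toy "Q-value" of a single action
    saliency = np.abs(w * obs)  # gradient-times-input attribution
    mask = saliency_mask(saliency, quantile)
    # Consistency: Q on the masked observation should match Q on the full one.
    consistency = (w @ (obs * mask) - q) ** 2
    # Self-prediction: an auxiliary head is trained to predict the binary mask.
    p = 1.0 / (1.0 + np.exp(-predicted_mask_logits))
    bce = -np.mean(mask * np.log(p + eps) + (1 - mask) * np.log(1 - p + eps))
    return consistency, bce, mask
```

In a real agent, the unmasked Q-value would be treated as a fixed target (stop-gradient), and both terms would be added to the usual temporal-difference loss.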
Saliency Guided Adversarial Training for Learning Generalizable Features with Applications to Medical Imaging Classification System
This work tackles a central machine learning problem of performance degradation on out-of-distribution (OOD) test sets. The problem is particularly salient in medical imaging based diagnosis system that appears to be accurate but fails when tested in new hospitals/datasets. Recent studies indicate the system might learn shortcut and non-relevant features instead of generalizable features, so-called good features. We hypothesize that adversarial training can eliminate shortcut features whereas saliency guided training can filter out non-relevant features; both are nuisance features accounting for the performance degradation on OOD test sets. With that, we formulate a novel model training scheme for the deep neural network to learn good features for classification and/or detection tasks ensuring a consistent generalization performance on OOD test sets. The experimental results qualitatively and quantitatively demonstrate the superior performance of our method using the benchmark CXR image data sets on classification tasks.
This paper combines adversarial training and XIL for medical image classification, a difficult domain where confounders are common. It seems like a good testbed for alignment approaches hoping to combine these methods.
This paper masks out low-saliency features of the input images, then adversarially perturbs the partially masked image. It then minimizes the KL divergence between the model's predictions on the adversarial, partially masked images and on the clean images. I thus count this paper as another example of training a model to match its saliency method.
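A minimal sketch of that pipeline for a binary logistic regression classifier: keep the top fraction of features by |w × x|, take a single FGSM step on the masked input (the paper may well use a stronger attack such as PGD, and may confine the perturbation differently), then measure the KL divergence between predictions on the clean input and on the adversarial masked input. All names and hyperparameters here are hypothetical.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def bernoulli_kl(p, q, eps=1e-12):
    # KL(Bernoulli(p) || Bernoulli(q)).
    return (p * np.log((p + eps) / (q + eps))
            + (1 - p) * np.log((1 - p + eps) / (1 - q + eps)))


def masked_adversarial_kl(w, b, x, y, keep_frac=0.5, step=0.1):
    # Keep the top-k features by |w * x| and zero out the rest.
    saliency = np.abs(w * x)
    k = max(1, int(round(keep_frac * x.size)))
    keep = np.zeros_like(x)
    keep[np.argsort(saliency)[-k:]] = 1.0
    x_masked = x * keep
    # One FGSM step: the cross-entropy gradient with respect to x is (p - y) * w.
    p_masked = sigmoid(w @ x_masked + b)
    x_adv = x_masked + step * np.sign((p_masked - y) * w)
    p_clean = sigmoid(w @ x + b)
    p_adv = sigmoid(w @ x_adv + b)
    return bernoulli_kl(p_clean, p_adv)
```

Minimizing this term during training pushes the model to give the same answer from its high-saliency features alone, even under attack, which is what couples the model to its saliency method.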
Neural language models' (NLMs') reasoning processes are notoriously hard to explain. Recently, there has been much progress in automatically generating machine rationales of NLM behavior, but less in utilizing the rationales to improve NLM behavior. For the latter, explanation regularization (ER) aims to improve NLM generalization by pushing the machine rationales to align with human rationales. Whereas prior works primarily evaluate such ER models via in-distribution (ID) generalization, ER's impact on out-of-distribution (OOD) is largely underexplored. Plus, little is understood about how ER model performance is affected by the choice of ER criteria or by the number/choice of training instances with human rationales. In light of this, we propose ER-TEST, a protocol for evaluating ER models' OOD generalization along three dimensions: (1) unseen datasets, (2) contrast set tests, and (3) functional tests. Using ER-TEST, we study three key questions: (A) Which ER criteria are most effective for the given OOD setting? (B) How is ER affected by the number/choice of training instances with human rationales? (C) Is ER effective with distantly supervised human rationales? ER-TEST enables comprehensive analysis of these questions by considering a diverse range of tasks and datasets. Through ER-TEST, we show that ER has little impact on ID performance, but can yield large gains on OOD performance w.r.t. (1)-(3). Also, we find that the best ER criterion is task-dependent, while ER can improve OOD performance even with limited and distantly-supervised human rationales.
This paper's evaluation criteria seem more impressive than its actual results, which seem kind of hit-or-miss to me. Still, I'm always glad to see papers that drill down into exactly what their method changes about a model's different capabilities and how the model generalizes to out-of-distribution test data.
The paper does provide reasonably strong evidence that XIL's primary gains appear on out-of-distribution tests, not in-distribution tests.
Many interpretability tools allow practitioners and researchers to explain Natural Language Processing systems. However, each tool requires different configurations and provides explanations in different forms, hindering the possibility of assessing and comparing them. A principled, unified evaluation benchmark will guide the users through the central question: which explanation method is more reliable for my use case? We introduce ferret, an easy-to-use, extensible Python library to explain Transformer-based models integrated with the Hugging Face Hub. It offers a unified benchmarking suite to test and compare a wide range of state-of-the-art explainers on any text or interpretability corpora. In addition, ferret provides convenient programming abstractions to foster the introduction of new explanation methods, datasets, or evaluation metrics.
"Will You Find These Shortcuts?" A Protocol for Evaluating the Faithfulness of Input Salience Methods for Text Classification
Feature attribution a.k.a. input salience methods which assign an importance score to a feature are abundant but may produce surprisingly different results for the same model on the same input. While differences are expected if disparate definitions of importance are assumed, most methods claim to provide faithful attributions and point at the features most relevant for a model's prediction. Existing work on faithfulness evaluation is not conclusive and does not provide a clear answer as to how different methods are to be compared. Focusing on text classification and the model debugging scenario, our main contribution is a protocol for faithfulness evaluation that makes use of partially synthetic data to obtain ground truth for feature importance ranking. Following the protocol, we do an in-depth analysis of four standard salience method classes on a range of datasets and shortcuts for BERT and LSTM models and demonstrate that some of the most popular method configurations provide poor results even for simplest shortcuts. We recommend following the protocol for each new task and model combination to find the best method for identifying shortcuts.
I list these two papers together because unfaithful saliency maps seem like the main bottleneck in current XIL practice, and both papers offer good tools for evaluating the faithfulness of saliency methods.
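The core measurement in the shortcut protocol can be sketched in a few lines: insert a synthetic shortcut token that (by construction) drives the label, then check how often a salience method ranks that token first. The top-1 recovery rate below is my simplification; the paper's exact metrics, and its check that the model really relies on the shortcut, aren't reproduced. `shortcut_recovery_rate` is a hypothetical helper.

```python
def shortcut_rank(scores, shortcut_idx):
    # 1-based rank of the shortcut token when tokens are sorted by descending |score|.
    order = sorted(range(len(scores)), key=lambda i: -abs(scores[i]))
    return order.index(shortcut_idx) + 1


def shortcut_recovery_rate(examples, salience_fn):
    """Fraction of shortcut-bearing examples whose shortcut token is ranked first.

    examples: list of (tokens, shortcut_idx) pairs, where the shortcut token
      was inserted synthetically and is known to drive the model's prediction.
    salience_fn: maps a token list to one importance score per token.
    """
    hits = sum(shortcut_rank(salience_fn(tokens), idx) == 1 for tokens, idx in examples)
    return hits / len(examples)


# Toy bag-of-words model in which the synthetic token "zeroa" is the shortcut.
weights = {"zeroa": 5.0, "great": 1.0, "film": 0.1}
grad_x_input = lambda tokens: [weights.get(t, 0.0) for t in tokens]
by_position = lambda tokens: [float(i) for i in range(len(tokens))]
examples = [(["great", "zeroa", "film"], 1), (["zeroa", "film"], 0)]
```

Here gradient × input (weight times count, for a linear bag-of-words model) recovers the shortcut every time, while a salience method that simply favors later positions never does.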
Explainable Artificial Intelligence (XAI) is an emerging research field bringing transparency to highly complex and opaque machine learning (ML) models. Despite the development of a multitude of methods to explain the decisions of black-box classifiers in recent years, these tools are seldomly used beyond visualization purposes. Only recently, researchers have started to employ explanations in practice to actually improve models. This paper offers a comprehensive overview over techniques that apply XAI practically for improving various properties of ML models, and systematically categorizes these approaches, comparing their respective strengths and weaknesses. We provide a theoretical perspective on these methods, and show empirically through experiments on toy and realistic settings how explanations can help improve properties such as model generalization ability or reasoning, among others. We further discuss potential caveats and drawbacks of these methods. We conclude that while model improvement based on XAI can have significant beneficial effects even on complex and not easily quantifiable model properties, these methods need to be applied carefully, since their success can vary depending on a multitude of factors, such as the model and dataset used, or the employed explanation method.
A very recent and pretty extensive review paper of different XIL methods, for those interested in a broader perspective on the field. XIL is a surprisingly large field given how rarely I hear it mentioned in alignment circles. This roundup focused on methods that supervise a model's input saliency maps (an approach the review calls "augmenting the loss"), but there are other XIL approaches as well.
My main update after looking through the XIL literature is that it's probably a good idea to actively optimize your model to make it better conform to your saliency method. My original thinking had been that you'd want to use a saliency method that was correct regardless of the model's current parameters. Most XIL papers do not regularize the model to match the saliency method, and they usually have relatively small performance improvements on realistic tasks.
In retrospect, it makes sense that you can optimize the model to be better explained by your interpretability method. I expect it's possible to optimize a model to make it harder to interpret, and it's unlikely that random initialization + vanilla training makes a model maximally interpretable.
Good saliency heatmaps should satisfy certain properties. E.g., disrupting low-salience portions of the input should have less of an effect on the model's behavior compared to disrupting high-salience portions of the input. Optimizing the model to ensure these properties are satisfied doesn't ensure your saliency method is always correct, but it can rule out many definitely incorrect behaviors from your saliency method.
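That property can be checked, and turned into a training penalty, with a simple deletion test: replace the top-k and the bottom-k features by salience with a baseline value and compare how much the output moves. The hinge penalty below is one possible way to optimize a model toward satisfying the property; it's my own illustrative construction, and the baseline value, deletion fraction, and function names are assumptions.

```python
import numpy as np


def deletion_effects(predict, x, saliency, frac=0.25, baseline=0.0):
    """Output change from deleting the most vs. least salient features."""
    k = max(1, int(frac * x.size))
    order = np.argsort(saliency)  # ascending salience
    low, high = order[:k], order[-k:]
    x_low, x_high = x.copy(), x.copy()
    x_low[low] = baseline
    x_high[high] = baseline
    ref = predict(x)
    return abs(predict(x_high) - ref), abs(predict(x_low) - ref)


def saliency_consistency_penalty(predict, x, saliency, frac=0.25, margin=0.0):
    # Hinge: penalize when deleting low-salience features moves the output
    # more than deleting high-salience features does.
    d_high, d_low = deletion_effects(predict, x, saliency, frac)
    return max(0.0, d_low - d_high + margin)
```

On a linear model, a faithful saliency map incurs zero penalty, while a reversed map is penalized by the gap between the two deletion effects. As the text notes, passing this check doesn't make a saliency map correct; it only rules out some clearly wrong ones.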
Of course, we should also look for saliency methods that are more robust, even without having to adapt the model to the method. It seems like the primary bottleneck in this respect is evaluating the actual robustness of a saliency method.
One idea I had for such an evaluation is to use the candidate saliency method for "negative XIL". Suppose we have a classification dataset and annotations that cover all possible features that the model could use to perform the classifications, such that there's no way to get good performance without paying attention to the labeled features. Then, we use XIL to penalize the model for attending to the labeled features, scaling the penalty term as needed until the saliency method assigns very low importance to the labeled features.
This training process directly incentivizes SGD to smuggle information past the saliency method, because that's the only way to get higher classification accuracy. We can thus rank the robustness of the saliency methods by the degree to which they decrease the model's performance. It essentially measures the degree to which the saliency method can constrain SGD away from lower loss solutions.
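Here's a toy sketch of the scoring loop for negative XIL. Assumptions: the model is logistic regression trained by full-batch gradient descent (standing in for SGD), the saliency method is the input gradient of the logit (which for a linear model is just `w`), and the penalty weight is raised over a fixed schedule until saliency on the labeled features falls below a threshold. A linear model has no way to hide information from its own gradient, so this only illustrates the measurement; comparing real saliency methods would need a nonlinear model, with each candidate method plugged in where `w[labeled]` is read. Function names are hypothetical.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def train_with_negative_xil(X, y, labeled, lam, steps=500, lr=0.5):
    """Logistic regression trained while penalizing saliency on `labeled` features.

    The saliency penalty is lam * sum(w[labeled] ** 2). It is applied with an
    exact proximal step so the update stays stable even for very large lam.
    """
    n, d = X.shape
    w, b = np.zeros(d), 0.0
    for _ in range(steps):
        p = sigmoid(X @ w + b)
        w = w - lr * (X.T @ (p - y)) / n
        b = b - lr * np.mean(p - y)
        # Proximal update for the L2 penalty on the labeled weights.
        w[labeled] /= 1.0 + 2.0 * lr * lam
    return w, b


def accuracy(w, b, X, y):
    return np.mean((sigmoid(X @ w + b) > 0.5) == y)


def negative_xil_score(X, y, labeled, lams, max_saliency=1e-2):
    """Accuracy lost once saliency on the labeled features is driven near zero.

    A larger drop suggests the saliency method leaves SGD fewer ways to smuggle
    information past it. Returns None if no penalty weight in `lams` pushes
    the labeled saliency below `max_saliency`.
    """
    w0, b0 = train_with_negative_xil(X, y, labeled, lam=0.0)
    baseline = accuracy(w0, b0, X, y)
    for lam in sorted(lams):  # scale the penalty up until saliency is low enough
        w, b = train_with_negative_xil(X, y, labeled, lam)
        if np.max(np.abs(w[labeled])) < max_saliency:
            return baseline - accuracy(w, b, X, y)
    return None
```

When the only informative feature is labeled, accuracy falls to roughly the majority-class rate once its saliency is suppressed, giving a large score, as expected for a saliency method that SGD can't evade.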
Anyway, I hope readers find these papers useful for their own research. Please feel free to discuss the listed papers in the comments or to recommend additional papers to me.
For next week's roundup, I'm thinking the focus will be on SGD inductive biases.
My other candidate focuses are:
- Shape versus texture bias in neural nets / humans
- Input saliency methods
- Diffusion models
- Controllable text generation
- Techniques for chain-of-thought language models
- Structure and content of language model internal representations
Let me know if there are any topics you're particularly interested in.