Interpreting Models by Ablation. Image generated by DALL-E 3.

Introduction

Interpretability in machine learning, especially in language models, is an area with a large number of contributions. While this can be quite useful for improving our understanding of models, one issue is the lack of robust benchmarks to evaluate the efficacy of different interpretability techniques. This makes it difficult to compare techniques and determine their true effectiveness in real-world scenarios.

Interestingly, there exists a parallel in the realm of non-language models under the research umbrella of Machine Unlearning. In this field, the objective is twofold: firstly, to deliberately diminish the model's performance on specified "unlearned" tasks, and secondly, to ensure that the model's proficiency is maintained or even enhanced on certain "retained" tasks. The inherent challenge is achieving a balance between these seemingly opposing goals, and the field has developed a range of metrics for measuring how well techniques strike that balance.

Drawing inspiration from Machine Unlearning, I believe that the metrics developed in this space could potentially serve as a litmus test for interpretability techniques in language models. By applying interpretability techniques as unlearning strategies, we can better test the effectiveness of interpretability methods, essentially setting benchmarks for how well these techniques can steer language models in desired directions.

If we aspire to have truly interpretable models, we must not only develop sophisticated techniques, but also robust benchmarks against which these techniques can be validated. Machine Unlearning might just offer the rigorous testing ground we need.

The rest of this post will: 1) Give a brief overview of Machine Unlearning, 2) Give a brief list of Machine Unlearning metrics, and how they may be applicable, 3) Give a deeper dive on each of the metrics, 4) Discuss how these fit in with existing metrics in Interpretability.

 

Machine Unlearning

Many papers in the subfield of Machine Unlearning are motivated by privacy preservation, and pose the question: "If we trained on someone's information that is now retracted, how can we remove that information without needing to retrain the whole model?"

There are multiple ways you might achieve unlearning. The "ideal/standard" is often to train the model again, but without the data you don't want it to learn. Two of the main ideals for an unlearned model are:

  1. You want the unlearned model to act exactly like this re-trained model
  2. You want the model to behave like a randomly initialised model on the unlearned task, and like the original model on the retained tasks.

Typically for Machine Unlearning, people want the first ideal. It may seem non-obvious that we should care about this distinction, but people do care, as you don't want to "Goodhart" the unlearning process. If the model behaves in the second way, and this differs from the first, you may instead be adding a circuit that identifies your unlearned training set and just adds randomness.

For interpretability, it might be less concerning to differentiate between these ideals unless gradient-based techniques that explicitly optimize for machine unlearning are employed. One main thing to keep in mind is that if you train on A and not B, then the model might still learn some things that are useful for making predictions about B.

 

It may be the case that in some neural network architectures, unlearning is more or less difficult and knowledge is more or less entangled. Unlearning one piece of information might inadvertently affect the retention of other unrelated information. It would be ideal if we could measure the degree to which this is the case, and avoid building systems where different pieces of knowledge cannot be disentangled.

Overview of Terminology:

Here is some terminology often used in the machine unlearning literature. (note that there can be some minor differences in use):

  • Forgotten/Unlearned task: task or knowledge you want the model to forget.
  • Retained task: task or knowledge you want the model to stay good at (i.e., the entire dataset except for the unlearned task).
  • Original model: the base model that you start off with.
  • Unlearned model: the model after the machine unlearning technique is applied. This model should be worse at some "unlearned" task, but should still be good at the "retained" task.
  • Relearned model: train the unlearned model to do the unlearned task again. 
  • Retrained model: train a randomly initialised model from scratch on the whole dataset, excluding the task you don't want it to do (ie: only on retained tasks). Can be very expensive for large models.
  • Streisand effect: parameter changes are so severe that the unlearning itself may be detected. (Related to Goodhart-ing the unlearning metrics).

 

Overview of Evaluation Metrics

Some of the main metrics used for evaluation are described in this Survey of Machine Unlearning. In brackets, I have added my evaluation of how useful each seems in practice for interpretability and related techniques on language models.

  • Change in Accuracy
    • Compared to:
      • original model  (good)
      • retrained model (too expensive)
    • On metric:
      • % Top1
      • % Top10
      • Perplexity
      • Loss
      • Other 
    • Summarised by:
      • R.O.C. curve
      • Maximal Difference
      • Other?
  • Change in Behaviour
  • Time Cost of the Method
    • Unlearning Time vs (Re)training Time (*cheap, worth including)
  • Degree of Removal
    • Relearn Time (seems OK. somewhat expensive)
    • Anamnesis Index (AIN) (too expensive)
    • Completeness, compared to retrained model (too expensive)
  • Other Effects on the Model
    • Layer-wise Distance (not super useful, but cheap?)
    • Activation Distance (*possibly good)
    • Activation JS-Divergence (*possibly good)
    • Epistemic Uncertainty (seems too expensive? unsure)
    • Zero Retrain Forgetting (ZRF) Score (*seems ok?)
  • Data Privacy Related:
    • Membership Inference (unsure, seems use-case dependent)
    • Model Inversion Attack (*not really a benchmark, but can be useful)

 

Detailed View on Each Metric

We note that many of the techniques here involve re-training a model exclusively on the retained tasks. In most cases, this will be too expensive for most people when it comes to large language models.

Change in Accuracy

How good is the model at making predictions? It should stay equal on the "retained" dataset, but get worse at the "unlearned" and "test" datasets. Note that this section could likely be expanded on much further.

  • Compared to:
    • original model  (good)
    • retrained model (too expensive)
  • On metric:
    • % Top1
    • % Top10
    • Perplexity
    • Loss
    • Other 
  • Summarised by:
    • R.O.C. curve
    • Maximal Difference

There are a lot of other "accuracy" metrics one could use, as well as more task-specific metrics.

For an example of metrics for assessing drops in accuracy, one can look at this paper I have written. The details depend somewhat on the specific metric, but in particular we use the following (a rough code sketch follows the list):

  • Draw the curve at different levels of pruning, comparing % drop in topk accuracy for Retained and Unlearned tasks.
  • Draw the curve for perplexity at different levels of pruning, showing perplexity as a multiple of initial perplexity for Retained and Unlearned tasks.
  • Get the maximal difference between drop in % top1 in retained task and unlearned task
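To make the top-k accuracy and perplexity comparisons concrete, here is a minimal sketch (my own illustration, not code from the paper) for a HuggingFace-style causal language model. The `model` and `input_ids` token tensors are assumed to already exist.

```python
# Minimal sketch: top-k next-token accuracy and perplexity for a causal LM,
# to be computed separately on retained and unlearned datasets.
import torch
import torch.nn.functional as F

def topk_accuracy(model, input_ids, k=1):
    """Fraction of next-token predictions whose true token is in the top-k logits."""
    with torch.no_grad():
        logits = model(input_ids).logits[:, :-1]       # predictions for positions 1..n-1
    targets = input_ids[:, 1:]                         # the true next tokens
    topk = logits.topk(k, dim=-1).indices              # shape (batch, seq-1, k)
    hits = (topk == targets.unsqueeze(-1)).any(dim=-1)
    return hits.float().mean().item()

def perplexity(model, input_ids):
    """exp(mean next-token cross-entropy loss)."""
    with torch.no_grad():
        logits = model(input_ids).logits[:, :-1]
    targets = input_ids[:, 1:]
    loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), targets.reshape(-1))
    return torch.exp(loss).item()

# Example summary statistic: sweep over pruning levels and take the maximal
# difference between the drop in top-1 accuracy on the retained task and on
# the unlearned task.
```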

There are, however, many metrics one could use, which makes it difficult to coordinate on which metrics to evaluate your technique on. In addition, some accuracy benchmarks are more qualitative than direct next-token prediction (eg: "write an answer").

Change in Behaviour

One should also consider that there are other ways to measure behaviour which may not be accurately described by the word "accuracy". This could include things such as "toxicity" and "bias", or "refusing harmful requests" and "conforming to instructions". While some papers do try to look at these, there is a wide variety of ways of modelling model behaviour and performance that is not particularly well described in most Machine Unlearning literature, and that would likely be useful to understand for a broader search into interpretability metrics.

Time Cost

Evaluation: You should probably be including this anyway

How long does your technique take? How does this compare to training the original model? You should be collecting this information anyway, so you should probably include it in your report.

  • Unlearning Time: How long does it take for your technique to unlearn?
  • Original Training Time: How long does training the original model take?
  • Retraining Time: How long does/would it take to retrain the model? (I would likely not include this, as retraining takes a long time). 

Degree of Removal

How well do you remove the unlearned task from the model? Does the model still possess most of the machinery required to do the task, and you just removed a tiny piece that is inconsequential in the grand scheme of things? Here are a couple of metrics that try to measure this:

Relearn Time

Evaluation: Seems OK. Can be expensive

How long does it take to relearn the unlearned skill? Depending on what you are doing (eg: removing a very small amount of knowledge for a specific fact, or removing a large variety of general capabilities), this may or may not be feasible.

If you are making relatively small changes to your language model, I suspect relearning should be relatively inexpensive by doing a Quantised Low-Rank Adapter (QLoRA) finetuning of your model. If so, it would be valuable to see how long this takes. Otherwise, if this is not possible, or you cannot afford to do such experiments, then that seems OK.

Ideally, you would be able to compare this to a model that has been retrained, though retraining a model without the unlearned task is usually prohibitively expensive.
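As a rough illustration, here is one way relearn time could be operationalised: finetune the unlearned model on the forgotten data and count the steps until its loss returns to (near) the original model's level. This is a sketch under my own assumptions; `get_batch`, the tolerance, and the learning rate are placeholders, and in practice you might wrap the model with QLoRA adapters rather than finetuning all parameters.

```python
# Sketch: "relearn time" = number of finetuning steps until the unlearned model's
# loss on the forgotten task is within `tolerance` of the original model's loss.
# `get_batch` is a hypothetical helper returning token batches from the forgotten task.
import torch

def relearn_time(unlearned_model, original_loss, get_batch, tolerance=0.05,
                 lr=1e-5, max_steps=10_000):
    optimizer = torch.optim.AdamW(unlearned_model.parameters(), lr=lr)
    for step in range(1, max_steps + 1):
        batch = get_batch()
        loss = unlearned_model(batch, labels=batch).loss   # HuggingFace-style LM loss
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        if loss.item() <= original_loss + tolerance:
            return step          # steps needed to recover the forgotten capability
    return None                  # did not relearn within the step budget
```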

Anamnesis Index (AIN)

Evaluation: too expensive (requires retraining)

Compare the "relearn time"  on the forgotten task, for the unlearned model , and the retrained model ), to be within  performance of the original model .

Ideally AIN should be close to 1. If relearning takes longer on the unlearned model, then you likely have Goodhart-ed the unlearning task.

This metric doesn't seem particularly useful for interpretability, and is also quite expensive to run.

Completeness (compared to retrained model)

Evaluation: too expensive (involves retrained model)

Check if the model fully forgets the removed data: is the model after unlearning like a new model trained without the forgotten data?

Calculate the overlap (using Jaccard distance) between the outputs of the unlearned and retrained models. This checks that no traces of the forgotten data impact the model's predictions.
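As an illustration (my own sketch, not from the survey), one could compare the top-k next-token predictions of two models on the same prompt via Jaccard similarity; the Jaccard distance is one minus this. The same function works for any pair of models, though completeness as defined above compares against the (expensive) retrained model.

```python
# Sketch: Jaccard similarity between the top-k next-token predictions of two models.
import torch

def topk_jaccard(model_a, model_b, input_ids, k=10):
    with torch.no_grad():
        top_a = model_a(input_ids).logits[0, -1].topk(k).indices.tolist()
        top_b = model_b(input_ids).logits[0, -1].topk(k).indices.tolist()
    a, b = set(top_a), set(top_b)
    return len(a & b) / len(a | b)   # 1.0 = identical top-k sets; Jaccard distance = 1 - this
```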

Other Effects on the Model

How much does the unlearning affect different parts of the model? How affected is the model on retained tasks? On the unlearned tasks? Here are some metrics that people sometimes try to use:

Layer-wise Distance

Evaluation: seems not super useful, but cheap, so maybe worth including?

This is a relatively simple metric: How different are the weights of the original model compared to the unlearned model? the retrained model? a randomly initialised model?

I somewhat doubt the practical value of this for interpretability, and don't really understand the point of this metric. I guess if the difference between the original model and the unlearned model is larger than the difference between the original model and the retrained model, I would be somewhat suspicious of the unlearning method.
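That said, the metric itself is cheap. A minimal sketch, assuming two models with identical architectures:

```python
# Sketch: per-parameter-tensor L2 distance between two models of the same architecture.
import torch

def layerwise_distance(model_a, model_b):
    params_b = dict(model_b.named_parameters())
    return {
        name: torch.dist(p_a.detach(), params_b[name].detach()).item()
        for name, p_a in model_a.named_parameters()
    }

# e.g. compare (original, unlearned), (original, retrained), and (original, random init)
```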

Activation Distance

Evaluation: Seems possibly good.

Originally for this metric, you would get the average L2-distance between the unlearned model and retrained model’s predicted probabilities on the forget set to try to evaluate "indistinguishability". In this case, using a retrained model is too expensive.

However, I think one could build a variation of this metric that compares:

  • original model vs. unlearned model vs. randomly initialised model
  • retained tasks vs unlearned tasks vs random inputs 

Then one could try to see how much difference there is between these different activations. See also section on ZRF score.
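A minimal sketch of the variation I have in mind, under my own assumptions (it compares next-token probability distributions rather than full internal activations):

```python
# Sketch: average L2 distance between two models' next-token probability
# distributions over a list of prompts (token tensors).
import torch
import torch.nn.functional as F

def activation_distance(model_a, model_b, batches):
    total, n = 0.0, 0
    with torch.no_grad():
        for input_ids in batches:
            p_a = F.softmax(model_a(input_ids).logits[:, -1], dim=-1)
            p_b = F.softmax(model_b(input_ids).logits[:, -1], dim=-1)
            total += torch.norm(p_a - p_b, dim=-1).mean().item()
            n += 1
    return total / n

# Run for (original, unlearned) and (original, random init), on retained,
# unlearned, and random inputs, to get a grid of baselines to compare against.
```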

Activation JS-Divergence

Evaluation: seems good? unsure

Similar to Activation distance, but instead of L2-Distance, you get the Jensen-Shannon Divergence. Same arguments as above.
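For example, one could swap the L2 distance in the sketch above for the JS divergence via scipy (note that scipy returns the Jensen-Shannon distance, the square root of the divergence):

```python
# Sketch: JS divergence between two next-token probability vectors (numpy arrays).
from scipy.spatial.distance import jensenshannon

def js_divergence(p, q):
    return jensenshannon(p, q, base=2) ** 2   # base 2 bounds the divergence in [0, 1]
```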

Epistemic Uncertainty

Evaluation: seems too expensive? unsure

Measures how much information about a dataset the model has learned. Expensive to compute. My understanding of the method for computation:

Step 1: Compute the Fisher Information Matrix (FIM):

$$\mathcal{I}(w; D) = \mathbb{E}\left[\nabla_w \log p(y \mid x; w)\,\nabla_w \log p(y \mid x; w)^\top\right]$$

where:

  • w = model weights
  • D = dataset, consisting of inputs x and outputs y
  • p(y|x; w) = probability of observing output y given input x for a model with parameters w.

Step 2: Compute the Influence Function (the trace of the FIM):

$$i(w; D) = \operatorname{tr}\left(\mathcal{I}(w; D)\right)$$

Step 3: Compute the Efficacy:

  1. If $i(w; D) > 0$, then $\text{efficacy}(w; D) = \frac{1}{i(w; D)}$
    1. The more the model parameters are influenced by the dataset, the more there is left to learn, and so, the lower the efficacy score.
  2. If $i(w; D) = 0$, then $\text{efficacy}(w; D) = \infty$
    1. An infinite efficacy score implies no influence of the dataset on the model parameters, or essentially, the model wouldn't learn anything new.

My understanding is that the efficacy measures how much the model has already learned about the data. If you were to measure it for base model vs unlearned model on retained vs unlearned tasks, then you could have a baseline for comparison.

If one has to follow the above method, it seems prohibitively expensive for large models, though there may be ways to get approximately the same information with a less expensive method.
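One rough approximation (my own sketch, not the exact computation from the paper) is to use the empirical diagonal Fisher, so that the trace reduces to the sum of squared gradients of the log-likelihood, averaged over batches:

```python
# Sketch: approximate tr(FIM) with the empirical diagonal Fisher (sum of squared
# gradients of the log-likelihood), then convert it to the efficacy score.
import torch

def efficacy(model, batches):
    trace, n = 0.0, 0
    for input_ids in batches:
        model.zero_grad()
        out = model(input_ids, labels=input_ids)   # HuggingFace-style causal LM
        (-out.loss).backward()                     # gradient of the mean log-likelihood
        trace += sum((p.grad ** 2).sum().item()
                     for p in model.parameters() if p.grad is not None)
        n += 1
    influence = trace / n                          # ~ tr(FIM), averaged over batches
    return float("inf") if influence == 0 else 1.0 / influence
```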

Zero Retrain Forgetting (ZRF) Score

Evaluation: seems good?

If we use a gradient-based machine unlearning method, we don't want to explicitly train the model to give the opposite answer, or to give a strangely uniform output prediction. This metric kinda checks for this. We get outputs from the unlearned model and from a randomly initialised model, calculate the Jensen-Shannon divergence between the two on each forget sample, and compute:

$$ZRF = 1 - \frac{1}{n_f}\sum_{i=1}^{n_f} \mathcal{JS}\left(M_u(x_i),\, T_d(x_i)\right)$$

Where:

  • $x_i$ = unlearned/forgetting sample
  • $n_f$ = number of forgetting samples
  • $M_u$ = Unlearned Model
  • $T_d$ = Randomly Initialised Model ("Incompetent Teacher")
  • $\mathcal{JS}$ = Jensen-Shannon Divergence

Then we can evaluate:

  • ZRF ≈ 1: The model behaves like a randomly initialised model on forget samples.
  • ZRF ≈ 0: The model exhibits some pattern on forget samples.

If the ZRF score is close to 1, that is good. One caveat is that in some cases (i.e: when you explicitly train to mimic a randomly initialised model), being too close to 1 could be a sign of Goodhart-ing the unlearning criteria (since models trained on task A, but not on task B, might still have better-than-random performance on task B). Overall, it seems like a useful metric for understanding how much information has been lost compared to the original activations.
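A sketch of how one could compute this for a language model, under my own assumptions (using next-token distributions on forget-set prompts as the model outputs):

```python
# Sketch: ZRF score between an unlearned model and a randomly initialised model,
# computed over next-token distributions on forget-set prompts.
import torch
import torch.nn.functional as F
from scipy.spatial.distance import jensenshannon

def zrf_score(unlearned_model, random_model, forget_batches):
    divergences = []
    with torch.no_grad():
        for input_ids in forget_batches:
            p = F.softmax(unlearned_model(input_ids).logits[0, -1], dim=-1).numpy()
            q = F.softmax(random_model(input_ids).logits[0, -1], dim=-1).numpy()
            divergences.append(jensenshannon(p, q, base=2) ** 2)  # JS divergence in [0, 1]
    return 1.0 - sum(divergences) / len(divergences)   # ~1: random-like on the forget set
```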

Data Privacy Related

Note that these metrics seem use-case dependent and not super useful in general, as they are particularly interested in the question of data privacy.

Membership Inference Attack

Evaluation: unsure, seems use-case dependent.

In general, Membership Inference Attacks ask: “Was this data point part of the training data?” There are too many methods to list here, and they often work under different assumptions.  This might be useful for trying to understand tampering in a model, and may be useful for interpretability, but I am unsure how easily this could be converted into a benchmark.

One example given in the context of Machine Unlearning and privacy preservation is: “Given the Original Model and the Unlearned Model, can you infer what was unlearned?”. While interesting, I am unsure how applicable this specific example is for machine unlearning.

Possible use in interpretability: if one was ablating a part responsible for a task, then membership inference techniques could be useful to understand how completely the ablation removes that capability on that task.
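As one concrete illustration of the flavour of these attacks, here is a sketch of the simple loss-threshold attack (in the spirit of Yeom et al.), not of the more sophisticated methods; the threshold would need to be calibrated on data known not to be in training.

```python
# Sketch: loss-threshold membership inference. Guess "member" if the model's loss
# on a sample is below a threshold calibrated on known non-member data.
import torch

def is_probable_member(model, input_ids, threshold):
    with torch.no_grad():
        loss = model(input_ids, labels=input_ids).loss.item()
    return loss < threshold   # True -> likely trained on / not fully unlearned
```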

Some things to keep in mind:

  • Many (but not all) membership inference attack methods require multiple models to compare against, though some seem to work well when you only have a single model.
  • Publicly available models may have had additional training to defend against such attacks, and this may not always be explicitly stated. This can make interpretability more difficult.

Model Inversion Attack (i.e: “Feature Visualisation”)

Evaluation: not really a benchmark, but can be useful

I think the main idea here is to try to reconstruct the input given the output, using the unlearned model. The approach is basically the same as “Feature Visualisation”, and is already often used to better understand models. This could be useful for getting qualitative feedback on an approach. The main drawbacks are that it doesn’t apply as well to text-only language models, and that it is not really a quantitative benchmark.

 

Discussion

Existing Evaluations in Interpretability

There are many ways of trying to do interpretability, and many ways of assessing how good your interpretation is. I have listed a couple of the main ones here. While each of these can be a good initial metric, I think there is a lot of potential for better evaluating interpretability techniques. Often the metrics can be quite task-specific.

While I think the Machine Unlearning metrics can provide a rich source of information, how applicable they are is highly dependent on the exact technique you are looking at. I would expect these metrics to be more applicable to something like Sparse AutoEncoder research, and less applicable to something like ActAdd. However, I think having a better explicit list of metrics/benchmarks for Interpretability, and implementations for running these benchmarks, would be quite valuable.

Viewing Features

One method used in various cases is to directly try to have features that look interpretable, and to see how strongly they activate on some input. Some examples include earlier work in “Feature Visualisation”, later work in “Taking Representations out of Superposition using Sparse Auto-Encoders” (Original, Scaled-Up), and linear-probe based techniques such as “Language Models Represent Space and Time” or “Discovering Latent Knowledge”.

However, in some of these cases it is unclear to what extent the component is solely responsible for the behaviour, as it may also be responsible for other tasks, or there may be other components that fulfil the same function. This is where Machine Unlearning evaluations seem the most useful. By intervening on these components, and using a variety of the metrics above, one could better understand the effect of ablating them.

Direct Effect on Logits 

One of the most common metrics is directly looking at logits for a specific immediate next-token prediction. This can be done directly by running the model to the end and looking at the logits, or by inferring the direct effect on logits based on changes in a mid-layer (i.e: Logit Lens, or more recently, Tuned Lens). This can be useful, and provide tight feedback loops, but I think that having a larger range of metrics on the effect on accuracy and activations would be useful.
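For reference, a minimal Logit Lens sketch for a GPT-2-style HuggingFace model (the module names `transformer.ln_f` and `lm_head` are specific to GPT-2; other architectures differ):

```python
# Sketch: Logit Lens. Project an intermediate hidden state through the final layer
# norm and the unembedding matrix to see what the model "predicts so far".
import torch

def logit_lens_top_tokens(model, input_ids, layer, k=5):
    with torch.no_grad():
        hidden_states = model(input_ids, output_hidden_states=True).hidden_states
        h = hidden_states[layer][:, -1]                     # residual stream after `layer`
        logits = model.lm_head(model.transformer.ln_f(h))   # GPT-2 module names
    return logits.topk(k, dim=-1).indices                   # top-k token ids at this depth
```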

Looking at Text Generations

Another method that is not-quite-interpretability-related is looking at text generations. This can be seen in, for example, the ActAdd paper, where they make generations, and measure word frequencies. I think having more text generation metrics would be quite interesting, and is something I am actively looking into more.
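A sketch of what such a metric might look like (my own illustration, loosely in the spirit of the ActAdd evaluation; the sampling settings are placeholders):

```python
# Sketch: sample continuations from a model and count word frequencies, to compare
# generations before and after an intervention (e.g. steering or ablation).
from collections import Counter

def generation_word_counts(model, tokenizer, prompt, n_samples=20, max_new_tokens=50):
    counts = Counter()
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids
    for _ in range(n_samples):
        out = model.generate(input_ids, do_sample=True, max_new_tokens=max_new_tokens)
        text = tokenizer.decode(out[0], skip_special_tokens=True)
        counts.update(text.lower().split())
    return counts   # compare counts of topic-related words across the two models
```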

 

Conclusion

I think there is a lot of room for better metrics in interpretability and model control. Some of these Machine Unlearning metrics seem potentially useful (while others remain too expensive or not particularly relevant).

One metric that I think is somewhat lacking is how changes might affect what longer-term generations look like. I am working on a possible metric relevant to this here: [Post Coming Soon™], but I think there is potential for other work to be done as well.

Machine unlearning seems to be a possible direct way of evaluating interpretability methods. I am interested in working on an implementation that makes it easier to run all of these different metrics, and would be excited for more work to be done in the direction of evaluating interpretability methods.

 

Note: If you think there are important metrics I left out, please comment below. I may update the post to include them.

 

References

"Survey of Machine Unlearning" / "Awesome Machine Unlearning"

"Dissecting Language Models: Machine Unlearning via Selective Pruning"

"Can Bad Teaching Induce Forgetting? Unlearning in Deep Networks using an Incompetent Teacher"

"Feature Visualization"

"Sparse Autoencoders Find Highly Interpretable Directions in Language Models"

"Towards Monosemanticity: Decomposing Language Models With Dictionary Learning"

"Language Models Represent Space and Time"

"Discovering Latent Knowledge in Language Models Without Supervision"

"Interpreting GPT: the logit lens"

"Eliciting Latent Predictions from Transformers with the Tuned Lens"

"ActAdd: Steering Language Models without Optimization"

Comments

I agree that this is an important frontier (and am doing a big project on this).

I believe these evaluations in unlearning miss a critical aspect: they benchmark on deleting i.i.d. samples or a specific class, instead of adversarially manipulated/chosen distributions. This might fool us into believing unlearning methods work, as shown in our paper both theoretically (Theorem 1) and empirically. The same failure mode holds for interpretability, which is a similar argument to the motivation for studying the whole distribution in the recent Copy Suppression paper.