Epistemic status: I am not an interpretability researcher, but I have followed the scene closely.
So, it makes sense to me that Probes should outperform SAEs: probes are trained directly to maximize an interpretable metric, while SAEs are trained to minimize reconstruction loss and only interpreted afterwards. But training SAEs is nice because it is an unsupervised problem, meaning that you don't need to create a dataset for each concept to find its direction, like you do with probes.
How can we get the best of both worlds? Well, just train SAEs on an objective which directly maximizes the human interpretability of the features!
Illustration of the training objectives of SAEs and Probes. The third design is the proposal of this blog post: train an SAE on an objective which, like probes, includes an interpretable metric.
How do SAEs work? They are trained to reconstruct the activations of the LM under a sparsity constraint. Then, one hopes that using the right SAE architecture and a good enough sparsity ratio will give interpretable features. In practice, it does! The features are pretty good, but we want something better.
$$\mathcal{L}(\mathrm{SAE})_\lambda = \mathbb{E}_x\!\left[\lVert x - \mathrm{SAE}(x)\rVert^2\right] + \lambda\,\lVert \mathrm{SAE}\rVert_2$$
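To make this concrete, here is a minimal PyTorch sketch of a vanilla SAE and its loss. This is my own illustration, not code from the post: the architecture (ReLU encoder, linear decoder), the L1 penalty on the latent code as the concrete form of the sparsity constraint, and the hyperparameters are all assumptions.

```python
import torch
import torch.nn as nn

class SparseAutoencoder(nn.Module):
    """Minimal SAE: encode LM activations into a wide sparse code, then reconstruct them."""
    def __init__(self, d_model: int, d_hidden: int):
        super().__init__()
        self.encoder = nn.Linear(d_model, d_hidden)
        self.decoder = nn.Linear(d_hidden, d_model)

    def forward(self, x: torch.Tensor):
        z = torch.relu(self.encoder(x))   # sparse feature activations
        x_hat = self.decoder(z)           # reconstruction of the LM activation
        return x_hat, z

def sae_loss(x: torch.Tensor, x_hat: torch.Tensor, z: torch.Tensor, lam: float = 1e-3):
    """Reconstruction error plus a sparsity penalty (an L1 norm on the code is a
    standard way to implement the sparsity constraint written loosely above)."""
    reconstruction = ((x - x_hat) ** 2).sum(dim=-1).mean()
    sparsity = z.abs().sum(dim=-1).mean()
    return reconstruction + lam * sparsity
```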
How do Probes work? They are trained to predict whether the token predicted by the LM is related to a chosen concept. If one wants the feature representing a concept A, one needs to construct and label a dataset where concept A is sometimes present and sometimes not.
$$\mathcal{L}(\mathrm{Probe}_w) = \mathbb{E}_{x,y}\!\left[\log\!\left(1 + e^{-\langle x\,|\,w\rangle\, y}\right)\right]$$
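A correspondingly minimal sketch of a linear probe trained with the logistic loss above; the labeled (activation, label) pairs are exactly the per-concept dataset the post wants to avoid building. Variable names and hyperparameters are placeholders.

```python
import torch

def probe_loss(x: torch.Tensor, y: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    """Logistic loss of a linear probe.
    x: (batch, d_model) LM activations,
    y: (batch,) labels in {-1, +1} saying whether concept A is present,
    w: (d_model,) probe direction."""
    logits = x @ w                               # <x | w>
    return torch.log1p(torch.exp(-logits * y)).mean()

# Ordinary supervised training on the labeled concept dataset (acts, labels):
# w = torch.zeros(d_model, requires_grad=True)
# opt = torch.optim.Adam([w], lr=1e-2)
# for _ in range(n_steps):
#     opt.zero_grad(); probe_loss(acts, labels, w).backward(); opt.step()
```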
Proposal. On top of the reconstruction and sparsity losses, train SAEs with a third loss obtained by doing RLAIF on how interpretable, simple, and causal the features are. By "doing RLAIF", the procedure I have in mind is:
- The SAE constructs vectors that achieve a low reconstruction loss while being sparse.
- These SAE vectors are then graded by an LM on how interpretable, simple, and causal the features are (see the sketch after this list).
- For Interpretability, one can use Anthropic's metric of correlation between the activation of the direction and the presence of the concept in the text.
- For Causality, the same metric can be used when injecting the direction into a sentence and seeing whether the related concept appears.
- Finally, use a method akin to RLAIF to compute gradients from these qualitative evaluations, and backpropagate to change the directions.[1]
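As a sketch of how the grading could work (my guess at an implementation, not something spelled out in the post): an LM labels whether the concept is present in each snippet of a validation corpus, and the grade for a feature is the correlation between those labels and the feature's activation; the causality grade reuses the same correlation after steering with the direction.

```python
import torch

def interpretability_grade(feature_acts: torch.Tensor, concept_present: torch.Tensor) -> float:
    """Correlation between a feature's activation on each text snippet and an LM's
    0/1 judgment of whether the concept is present in that snippet.
    feature_acts: (n_snippets,) activations of one SAE feature,
    concept_present: (n_snippets,) 0/1 labels produced by the grading LM."""
    stacked = torch.stack([feature_acts.float(), concept_present.float()])
    return torch.corrcoef(stacked)[0, 1].item()

# The causality grade would reuse the same correlation, but with activations and labels
# collected after injecting (steering with) the feature direction, checking whether the
# related concept now appears in the model's continuation.
```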
The full objective then becomes

$$\mathcal{L}(\mathrm{SAE})_{\lambda,\gamma} = \mathbb{E}_x\!\left[\lVert x - \mathrm{SAE}(x)\rVert^2\right] + \lambda\,\lVert \mathrm{SAE}\rVert_2 + \gamma\,\mathbb{E}\!\left[\mathrm{RLAIF}(\mathrm{SAE}(x))\right]$$

Prediction. RLAIF-SAE should find directions which are interpretable like Probes, but with the unsupervised strength of the SAE. I predict that asking for simple features should solve problems like feature splitting, the existence of meta-SAEs, and the lack of feature atomicity.
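Footnote [1] is the crux: the LM grade is not differentiable with respect to the SAE's parameters. One possible reading of "a method akin to RLAIF" (my assumption, and exactly the step the footnote defers) is a black-box, evolution-strategies-style estimate of the grade's gradient with respect to the decoder directions, added on top of the ordinary SAE gradients. The sketch below assumes the SparseAutoencoder from the earlier snippet and stubs out the grader.

```python
import torch

def lm_grade_all_features(decoder_weight: torch.Tensor) -> float:
    """Stub for the expensive part: run the grading pipeline (interpretability,
    simplicity, causality) on every feature direction, i.e. every column of the
    decoder, and return the mean grade. Returns 0.0 so the sketch stays runnable."""
    return 0.0

def grade_gradient(decoder_weight: torch.Tensor, sigma: float = 0.01, n_samples: int = 8) -> torch.Tensor:
    """Evolution-strategies-style estimate of d(mean grade)/d(decoder directions):
    perturb the directions, observe how the grade moves, and average."""
    grad = torch.zeros_like(decoder_weight)
    for _ in range(n_samples):
        eps = torch.randn_like(decoder_weight)
        reward = lm_grade_all_features(decoder_weight + sigma * eps)
        grad += reward * eps
    return grad / (n_samples * sigma)

def rlaif_sae_step(sae, x, opt, lam: float = 1e-3, gamma: float = 1e-2):
    """One step of the combined objective: differentiable reconstruction and sparsity
    terms, plus the black-box gradient of the LM grade on the decoder directions."""
    x_hat, z = sae(x)
    loss = ((x - x_hat) ** 2).sum(-1).mean() + lam * z.abs().sum(-1).mean()
    opt.zero_grad()
    loss.backward()
    with torch.no_grad():
        # Subtract because the optimizer descends on .grad and we want to ascend the grade.
        sae.decoder.weight.grad -= gamma * grade_gradient(sae.decoder.weight.detach())
    opt.step()
```

Each gradient estimate requires grading every feature several times with an LM, which makes the scaling concern below very concrete.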
Problems. If this were implemented, the first problem I would expect is scaling. Computing the RLAIF part of the loss will be costly, as one needs a validation corpus and LM calls for grading. I don't know how well this process could be optimized, but one solution could be to first train an SAE normally, and then fine-tune it with RLAIF to align its concepts with human concepts, just like we do with LMs.
Working on this project: I am not a Mech-Interp researcher and don't have the skills to execute this project on my own. I would be happy to collaborate with people, so consider reaching out to me (Léo Dana; I will be at NeurIPS during the Mech Interp workshop). If you just want to try this out on your own, feel free, but I would appreciate being notified so I know whether the project works or not.
[1] This is the heart of the problem: how to backpropagate "human interpretability" to learn the directions. I sweep this under the carpet because I believe it is solved by the existence of RLHF and RLAIF for fine-tuning LMs. If doing this is not as easy as it seems to me, I will be happy to discuss it in the comment section and debate what the core difference is.