Epistemic status: not an interpretability researcher, but I have followed the scene closely.
So, it makes sense to me that probes should outperform SAEs: probes are trained directly to optimize an interpretable metric, while SAEs are trained to minimize reconstruction loss and only interpreted afterwards. But training SAEs is attractive because it is an unsupervised problem, meaning you don't need to build a labeled dataset for each concept whose direction you want, as you do with probes.
How can we get the best of both worlds? Just train SAEs with an objective that directly maximizes the human interpretability of the features!
How do SAEs work? They are trained to reconstruct the activations of the LM under a sparsity constraint. One then hopes that the right SAE architecture and a good enough sparsity level will yield interpretable features. In practice, it does: the features are pretty good, but we want something better.
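To make the objective concrete, here is a minimal numpy sketch of the standard SAE loss (toy sizes, random weights; real SAEs are trained with much wider dictionaries and an optimizer, and the exact sparsity penalty varies by architecture):

```python
import numpy as np

rng = np.random.default_rng(0)

d_model, d_sae = 16, 64                         # toy sizes; real SAEs are much wider
W_enc = rng.normal(0, 0.1, (d_sae, d_model))
b_enc = np.zeros(d_sae)
W_dec = rng.normal(0, 0.1, (d_model, d_sae))
b_dec = np.zeros(d_model)

def sae_loss(x, l1_coeff=1e-3):
    """Reconstruction loss plus an L1 sparsity penalty on the feature activations."""
    f = np.maximum(W_enc @ x + b_enc, 0.0)      # encoder with ReLU: feature activations
    x_hat = W_dec @ f + b_dec                   # decoder: reconstruction
    recon = np.sum((x - x_hat) ** 2)            # reconstruction term (minimized)
    sparsity = l1_coeff * np.sum(np.abs(f))     # sparsity term: few active features
    return recon + sparsity, f

x = rng.normal(size=d_model)                    # a toy stand-in for an LM activation
loss, feats = sae_loss(x)
```

Nothing here encodes interpretability directly; it only falls out of reconstruction plus sparsity, which is exactly the gap the proposal below targets.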
How do probes work? They are trained to predict whether the token predicted by the LM is related to a chosen concept. If one wants the feature representing a concept A, one needs to construct and label a dataset where concept A is sometimes present and sometimes not.
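A linear probe is just a classifier on activations. A minimal sketch, with synthetic activations whose labels are defined by a hidden direction (a toy stand-in for the real labeled dataset one would have to build):

```python
import numpy as np

rng = np.random.default_rng(1)
d = 8

# Toy dataset: activations where concept A is present (label 1) or absent (label 0).
concept_dir = rng.normal(size=d)
X = rng.normal(size=(200, d))
y = (X @ concept_dir > 0).astype(float)         # labels come from a hidden direction

# Plain logistic regression by gradient descent; w is the learned probe direction.
w, b = np.zeros(d), 0.0
for _ in range(500):
    p = 1.0 / (1.0 + np.exp(-(X @ w + b)))      # predicted probability of concept A
    w -= 0.5 * (X.T @ (p - y) / len(y))
    b -= 0.5 * float(np.mean(p - y))

acc = np.mean(((X @ w + b) > 0) == (y == 1))
```

The supervision is what makes the direction interpretable by construction, and the dataset-building cost is what the unsupervised SAE avoids.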
Proposal. On top of the reconstruction and sparsity losses, train SAEs with a third loss obtained by doing RLAIF on how interpretable, simple, and causal the features are. By "doing RLAIF", the procedure I have in mind is:
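One possible shape for such a grading loop, with the LM judge stubbed out. Everything here is a hypothetical sketch of my own, not a fixed procedure: `lm_grade` stands in for prompting a language model with a feature's top-activating examples and asking for a score.

```python
import numpy as np

rng = np.random.default_rng(2)

def lm_grade(top_examples):
    """Hypothetical stand-in for an LM judge. In practice this would show a
    language model the feature's top-activating contexts and ask for a 0-1
    interpretability / simplicity / causality score."""
    # Stub proxy: reward features whose top activations are concentrated.
    return float(np.max(top_examples) / (np.sum(top_examples) + 1e-8))

def interpretability_reward(feature_acts, k=5):
    """Average LM grade over features; this scalar is the RLAIF reward signal
    that would become the third loss term."""
    grades = []
    for f in feature_acts:                       # one row of activations per feature
        top = np.sort(f)[-k:]                    # the feature's top-k examples
        grades.append(lm_grade(top))
    return float(np.mean(grades))

acts = np.abs(rng.normal(size=(10, 100)))        # toy (features x examples) activations
reward = interpretability_reward(acts)
```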
Prediction. An RLAIF-SAE should find directions which are interpretable like probes, but with the unsupervised strength of the SAE. I predict that asking for simple features should mitigate problems like feature splitting, the need for meta-SAEs, and the lack of feature atomicity.
Problems. If this were implemented, the first problem I foresee is scaling: computing the RLAIF part of the loss will be costly, as one needs a validation corpus and LM calls for the grades. I don't know how well this process could be optimized, but one solution could be to first train an SAE normally and then fine-tune it with RLAIF, to align its concepts with human concepts, just as we do with LMs.
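A back-of-the-envelope cost model shows why the two-phase schedule helps. The numbers below are made up for illustration: assume an LM-graded step costs roughly `grader_cost` times a plain SAE step.

```python
def train_schedule(n_pretrain=1000, n_finetune=100, grader_cost=50):
    """Toy cost model (all numbers hypothetical): the RLAIF grader multiplies
    per-step cost, so it is only applied in a short fine-tuning phase."""
    cost_naive = (n_pretrain + n_finetune) * (1 + grader_cost)   # RLAIF from step 0
    cost_staged = n_pretrain * 1 + n_finetune * (1 + grader_cost)
    return cost_naive, cost_staged

naive, staged = train_schedule()
```

With these (made-up) numbers, staging cuts the total cost by roughly an order of magnitude, at the risk that the pretrained dictionary is a bad starting point for the RLAIF phase.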
Working on this project: I am not a mech-interp researcher and don't have the skills to execute this project on my own. I would be happy to collaborate, so consider reaching out to me (Léo Dana; I will be at NeurIPS during the Mech Interp workshop). If you just want to try this out on your own, feel free, but I'd appreciate being notified so I know whether the project works or not.
This is the heart of the problem: how to backpropagate "human interpretability" to learn the directions. I have swept this under the rug because I believe it is solved by the existence of RLHF and RLAIF for fine-tuning LMs. If doing so is not as easy as it seems to me, I will be happy to discuss it in the comment section and debate what the core difference is.
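For intuition that a non-differentiable grade can still drive learning, here is a score-function (REINFORCE/evolution-strategies-style) sketch: perturb a direction, grade the perturbations with a black-box judge, and move toward perturbations that graded well. The `grade` function is a hypothetical stub (alignment with a fixed target direction stands in for the LM judge); the point is only that no gradient of the grader is ever needed.

```python
import numpy as np

rng = np.random.default_rng(3)

def grade(direction):
    """Non-differentiable 'interpretability' grade (hypothetical stub):
    cosine-style alignment with a fixed target direction."""
    target = np.array([1.0, 0.0, 0.0, 0.0])
    return float(direction @ target / (np.linalg.norm(direction) + 1e-8))

# Score-function estimator: sample perturbations, grade them, and step toward
# the well-graded ones. This is one standard way to "backpropagate" a reward
# that we cannot differentiate through.
w = rng.normal(size=4)
start = grade(w)
sigma, lr = 0.1, 0.5
for _ in range(300):
    eps = rng.normal(size=(16, 4))                       # population of perturbations
    rewards = np.array([grade(w + sigma * e) for e in eps])
    rewards = (rewards - rewards.mean()) / (rewards.std() + 1e-8)
    w += lr * sigma * (eps.T @ rewards) / len(eps)       # ascend the estimated gradient

final = grade(w)
```

Whether this kind of estimator stays sample-efficient when each grade costs an LM call is exactly the scaling worry above.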