I hadn't seen the latter, thanks for sharing!
Agreed, it seems less elegant. But someone on Hugging Face did a rough plot of the cross-correlation, and it seems to show that the direction changes with layer: https://huggingface.co/posts/Undi95/318385306588047#663744f79522541bd971c919. Perhaps we are missing something, though.
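If "cross-correlation" here means the pairwise cosine similarity between the candidate directions extracted at each layer (an assumption; the linked post may have computed something slightly different), the data behind such a plot could be computed roughly like this. `direction_similarity` and the `(n_layers, d_model)` input layout are hypothetical, and the random directions are placeholders rather than real refusal directions.

```python
import numpy as np


def direction_similarity(directions: np.ndarray) -> np.ndarray:
    """Pairwise cosine similarity between per-layer directions.

    directions: hypothetical array of shape (n_layers, d_model), one
    candidate direction per layer (how each one is extracted is not
    specified here).
    """
    # Normalise each row to unit length so dot products become cosines
    unit = directions / np.linalg.norm(directions, axis=1, keepdims=True)
    # Entry [i, j] compares the direction at layer i with the one at layer j
    return unit @ unit.T


# Toy stand-in: 4 layers, d_model = 16, random placeholder directions
rng = np.random.default_rng(0)
dirs = rng.normal(size=(4, 16))
sims = direction_similarity(dirs)
print(np.round(sims, 2))
```

A single shared direction would show up as values near 1 everywhere; a direction that drifts with depth shows high similarity near the diagonal that falls off between distant layers.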
Idk. This shows that if you wanted to remove refusal as thoroughly as possible, you might want to do this. But really, you want to balance removing refusal against not damaging the model, and many layers are probably just irrelevant to refusal. If anything, this argues that we're both wrong, and the most surgical intervention is deleting the direction from key layers only.
Thanks! I'm personally skeptical of ablating a separate direction per block: it feels less surgical than a single direction everywhere, and we show that a single direction works fine for Llama 3 8B and 70B.
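To make the three options concrete (one direction everywhere, a separate direction per layer, or a direction at key layers only), here is a minimal sketch assuming "ablating a direction" means projecting it out of the residual stream, x ← x − (x·r̂)r̂. `ablate_direction`, `ablate_layers`, and the dict-of-cached-activations layout are hypothetical illustrations, not TransformerLens API.

```python
import numpy as np


def ablate_direction(x: np.ndarray, r: np.ndarray) -> np.ndarray:
    """Remove the component of each activation along direction r.

    x: activations of shape (..., d_model); r: direction of shape (d_model,).
    """
    r_hat = r / np.linalg.norm(r)
    # Project onto r_hat, then subtract that component from x
    return x - np.expand_dims(x @ r_hat, -1) * r_hat


def ablate_layers(resid: dict, get_direction, layers) -> dict:
    """Apply directional ablation only at the chosen layers.

    resid: hypothetical dict mapping layer index -> activations (..., d_model).
    get_direction: maps a layer index to the direction removed there, so a
    constant function gives "one direction everywhere" and a lookup gives
    "one direction per layer".
    layers: the layers to intervene on; all others are returned unchanged.
    """
    return {
        layer: ablate_direction(acts, get_direction(layer)) if layer in layers else acts
        for layer, acts in resid.items()
    }


rng = np.random.default_rng(0)
d_model, n_layers = 8, 4
resid = {layer: rng.normal(size=(3, d_model)) for layer in range(n_layers)}
r = rng.normal(size=d_model)

# One shared direction, removed at every layer
everywhere = ablate_layers(resid, lambda layer: r, layers=range(n_layers))
# The same direction, removed only at a (hypothetical) set of key layers
key_only = ablate_layers(resid, lambda layer: r, layers={1, 2})
```

This operates on a static cache, so it ignores that editing an early layer changes what later layers compute; in a real model the same projection would be applied via hooks during the forward pass.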
The TransformerLens library does not have a save feature :(
Note that you can just do torch.save(model.state_dict(), FILE_PATH), as with any PyTorch model.
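A minimal round trip, assuming the model is an ordinary `torch.nn.Module` (TransformerLens models are). A tiny `nn.Linear` stands in for the real model, and the file name is arbitrary.

```python
import os
import tempfile

import torch
import torch.nn as nn

# Stand-in model; any nn.Module exposes state_dict() / load_state_dict()
model = nn.Linear(4, 2)

path = os.path.join(tempfile.mkdtemp(), "model_weights.pt")
# Argument order matters: the object to save comes first, then the path
torch.save(model.state_dict(), path)

# To restore, build a model with the same architecture and load the weights
restored = nn.Linear(4, 2)
restored.load_state_dict(torch.load(path))
```

Saving the `state_dict` rather than the whole module keeps the file to plain tensors, so loading it only requires being able to reconstruct the architecture.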
Thanks for making these! How expensive is it?
Makes sense! Sounds like a fairly good fit
It just seems intuitively like a natural fit: Everyone in mech interp needs to inspect models. This tool makes it easier to inspect models.
Another way of framing it: try to write your paper in such a way that a mech interp researcher reading it says "huh, I want to go and use this library for my research". E.g., give examples of things that were previously hard and are now easy.
Looks relevant to me on a skim! I'd probably want to see some arguments in the submission for why this is useful tooling for mech interp people specifically (though being useful to non mech interp people too is a bonus!)
That's awesome, and insanely fast! Thanks so much, I really appreciate it
Nope to both of those, though I think both could be interesting directions!
Nah, I think it's pretty sketchy. I personally prefer mean ablation, especially for residual stream SAEs, where zero ablation is super damaging. But even there I agree. Measuring the compute efficiency hit would be nice, though it's a pain to get the scaling laws precise enough.
For our paper, though, this is irrelevant IMO, because we're comparing gated and normal SAEs, and I think this is just scaling by a constant? It's at least monotonic in CE loss degradation.
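Assuming the metric being discussed is the common "fraction of loss recovered" score (an assumption; the thread doesn't name it), it normalises the CE increase from splicing in the SAE by the CE increase from an ablation baseline: 1 − (CE_spliced − CE_clean) / (CE_ablated − CE_clean). The function name and all numbers below are hypothetical, for illustration only, and are not results from any paper.

```python
def loss_recovered(ce_clean: float, ce_spliced: float, ce_ablated: float) -> float:
    """Fraction of the CE loss gap closed by the SAE reconstruction.

    ce_clean: CE loss of the unmodified model.
    ce_spliced: CE loss with the SAE reconstruction spliced in.
    ce_ablated: CE loss with the activation zero- or mean-ablated (the baseline).
    """
    return 1.0 - (ce_spliced - ce_clean) / (ce_ablated - ce_clean)


# Hypothetical numbers, purely illustrative
ce_clean = 3.0
sae_losses = {"sae_a": 3.2, "sae_b": 3.4}  # CE with each SAE spliced in
baselines = {"zero": 10.0, "mean": 5.0}  # CE under each ablation baseline

for name, ce_ablated in baselines.items():
    scores = {sae: loss_recovered(ce_clean, ce, ce_ablated) for sae, ce in sae_losses.items()}
    print(name, {sae: round(s, 3) for sae, s in scores.items()})
```

With the baseline held fixed, the score is 1 − ΔCE/D for a constant D, i.e. a decreasing affine function of the CE degradation. Switching between zero and mean ablation therefore rescales the gap between two SAEs at the same site but cannot reorder them.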
nnsight, pyvene, Inseq, and TorchLens are other libraries that come to mind and that it would be good to discuss in a related work section. Also Penzai in JAX.