Stefan Heimersheim. Research Scientist at Apollo Research, Mechanistic Interpretability.
Paper link: https://arxiv.org/abs/2407.20311
(I have neither watched the video nor read the paper yet, just in case someone else was looking for the non-video version)
Thanks! I'll edit it
[…] no reason to be concentrated in any one spot of the network (whether activation-space or weight-space). So studying weights and activations is pretty doomed.
I find myself really confused by this argument. Shards (or anything) do not need to be “concentrated in one spot” for studying them to make sense?
As Neel and Lucius say, you might study SAE latents or abstractions built on the weights; no one requires (or assumes) that things are concentrated in one spot.
Or to make another analogy, one can study neuroscience even though things are not concentrated in individual cells or atoms.
If we still disagree, it’d help me if you clarified how the “So […]” part of your argument follows.
Edit: “The real thinking happens in the scaffolding” is a reasonable argument (and current mech interp doesn’t address this), but that’s a different argument (and it just means that mech interp only lets us understand individual forward passes).
Even after reading this (2 weeks ago), I couldn't find the comment link today and ended up scrolling down manually. I later noticed it (at the bottom left), but it's so far away from everything else. I think putting it somewhere at the top, near the rest of the UI, would make it much easier for me to find.
I would like the following subscription: All posts with certain tags, e.g. all [AI] posts or all [Interpretability (ML & AI)] posts.
I just noticed (and enabled) a “subscribe” feature in the page for the tag, it says “Get notifications when posts are added to this tag.” — I’m unsure if those are emails, but assuming they are, my problem is solved. I never noticed this option before.
And here's the code to do it with replacing the LayerNorms with identities completely:
import torch
from transformers import GPT2LMHeadModel
from transformer_lens import HookedTransformer

model = GPT2LMHeadModel.from_pretrained("apollo-research/gpt2_noLN").to("cpu")

# Undo my hacky LayerNorm removal
for block in model.transformer.h:
    block.ln_1.weight.data = block.ln_1.weight.data / 1e6
    block.ln_1.eps = 1e-5
    block.ln_2.weight.data = block.ln_2.weight.data / 1e6
    block.ln_2.eps = 1e-5
model.transformer.ln_f.weight.data = model.transformer.ln_f.weight.data / 1e6
model.transformer.ln_f.eps = 1e-5

# Properly replace LayerNorms by Identities
class HookedTransformerNoLN(HookedTransformer):
    def removeLN(self):
        for i in range(len(self.blocks)):
            self.blocks[i].ln1 = torch.nn.Identity()
            self.blocks[i].ln2 = torch.nn.Identity()
        self.ln_final = torch.nn.Identity()

hooked_model = HookedTransformerNoLN.from_pretrained("gpt2", hf_model=model, fold_ln=True, center_unembed=False).to("cpu")
hooked_model.removeLN()
hooked_model.cfg.normalization_type = None

prompt = torch.tensor([1,2,3,4], device="cpu")
logits = hooked_model(prompt)
print(logits.shape)
print(logits[0, 0, :10])
Here's a quick snippet to load the model into TransformerLens!
import torch
from transformers import GPT2LMHeadModel
from transformer_lens import HookedTransformer

model = GPT2LMHeadModel.from_pretrained("apollo-research/gpt2_noLN").to("cpu")
hooked_model = HookedTransformer.from_pretrained("gpt2", hf_model=model, fold_ln=False, center_unembed=False).to("cpu")

# Kill the LayerNorms because TransformerLens overwrites eps
for block in hooked_model.blocks:
    block.ln1.eps = 1e12
    block.ln2.eps = 1e12
hooked_model.ln_final.eps = 1e12

# Make sure the outputs are the same
prompt = torch.tensor([1,2,3,4], device="cpu")
logits = hooked_model(prompt)
logits2 = model(prompt).logits
print(logits.shape, logits2.shape)
print(logits[0, 0, :10])
print(logits2[0, :10])
I really like the investigation into properties of SAE features, especially the angle of testing whether SAE features have particular properties that other (random) directions don't have!
Random directions as a baseline: Based on my experience here, I expect random directions to be a weak baseline. For example, the covariance matrix of model activations (or SAE features) is very non-uniform. I'd second @Hoagy's suggestion of using linear combinations of SAE features, or directions towards other model activations, as I used here.
Ablation vs functional FT-LLC: I found the comparison between your LLC measure (weights before the feature), and the ablation effect (effect of this feature on the output) interesting, and I liked that you give some theories, both very interesting! Do you think @jake_mendel's error correction theory is related to these in any way?
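To make the baseline point above concrete, here is a minimal sketch (not the method from the linked posts) contrasting isotropic random directions with directions whose distribution follows the empirical covariance of a set of activations. The function names and the toy 3-dimensional "activations" are hypothetical. The covariance-matched sampler uses the fact that a Gaussian-weighted sum of mean-centered activation vectors has, up to scale, the same covariance as the activations themselves.

```python
import math
import random


def normalize(v):
    """Scale a vector to unit L2 norm."""
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]


def isotropic_direction(dim, rng):
    """Uniformly random unit vector: the naive baseline."""
    return normalize([rng.gauss(0.0, 1.0) for _ in range(dim)])


def covariance_matched_direction(activations, rng):
    """Random unit vector built as a Gaussian-weighted sum of centered activations.

    Before normalization, the sum has covariance proportional to the
    empirical covariance of `activations`.
    """
    n, dim = len(activations), len(activations[0])
    mean = [sum(row[j] for row in activations) / n for j in range(dim)]
    weights = [rng.gauss(0.0, 1.0) for _ in range(n)]
    combo = [
        sum(w * (row[j] - mean[j]) for w, row in zip(weights, activations))
        for j in range(dim)
    ]
    return normalize(combo)


# Toy stand-in for model activations (assumption): axis 0 has far more variance.
rng = random.Random(0)
toy_activations = [
    [rng.gauss(0.0, 10.0), rng.gauss(0.0, 1.0), rng.gauss(0.0, 1.0)]
    for _ in range(200)
]

iso = [isotropic_direction(3, rng) for _ in range(200)]
matched = [covariance_matched_direction(toy_activations, rng) for _ in range(200)]

# Average alignment with the high-variance axis for each kind of baseline.
mean_iso = sum(abs(d[0]) for d in iso) / len(iso)
mean_matched = sum(abs(d[0]) for d in matched) / len(matched)
print(f"isotropic: {mean_iso:.2f}, covariance-matched: {mean_matched:.2f}")
```

In this toy setup the covariance-matched directions concentrate along the high-variance axis while the isotropic ones do not, which is the sense in which isotropic random directions can be an unrepresentative comparison for directions the model actually uses.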
I like this idea! I'd love to see checks of this on the SOTA models which tend to have lots of layers (thanks @Joseph Miller for running the GPT2 experiment already!).
I notice this line of argument would also imply that the embedding information can only be accessed up to a certain layer, after which it will be washed out by the high-norm outputs of layers. (And the same for early MLP layers which are rumoured to act as extended embeddings in some models.) -- this seems unexpected.
Additionally, they would be further evidence (but not conclusive[2]) towards hypotheses Residual Networks Behave Like Ensembles of Relatively Shallow Networks.
I have the opposite expectation: Effective layer horizons enforce a lower bound on the number of modules involved in a path. Consider the shallow path
If the effective layer horizon is 25, then this path cannot work because the output of MLP10 gets lost. In fact, no path with fewer than 3 modules is possible because there would always be a gap > 25.
Only less-shallow paths would manage to influence the output of the model.
This too seems counterintuitive, not sure what to make of this.
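The gap argument above can be sketched as a small calculation. This is only an illustration under stated assumptions: the 80-layer depth is hypothetical (the model depth is not given here), the embedding is placed at position 0, the unembedding at position `n_layers + 1`, and a module can read an earlier output only if it is at most `horizon` layers away.

```python
import math


def min_modules(n_layers: int, horizon: int) -> int:
    """Smallest number of intermediate modules on an embed -> unembed path.

    Assumed convention: embedding at position 0, modules at 1..n_layers,
    unembedding at n_layers + 1; a hop from a to b works only if b - a <= horizon.
    """
    total_distance = n_layers + 1
    # k modules split the path into k + 1 hops, each spanning at most `horizon` layers.
    return max(0, math.ceil(total_distance / horizon) - 1)


def path_is_viable(module_layers, n_layers: int, horizon: int) -> bool:
    """Check that no hop along embed -> modules -> unembed exceeds the horizon."""
    positions = [0] + sorted(module_layers) + [n_layers + 1]
    return all(b - a <= horizon for a, b in zip(positions, positions[1:]))


# Hypothetical 80-layer model with an effective layer horizon of 25.
print(min_modules(80, 25))                   # 3
print(path_is_viable([10], 80, 25))          # False: the hop from layer 10 to the unembed is too long
print(path_is_viable([25, 50, 75], 80, 25))  # True
```

Under these assumptions, a horizon acts as a lower bound on path length rather than an upper bound, which is the tension with the "ensembles of relatively shallow networks" reading.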
Thanks for the nice writeup! I'm confused about why you can get away without interpretation of what the model components are:
In cases where we worry that our model learned a human-simulator / camera-simulator rather than actually predicting whether the diamond exists, wouldn't circuit discovery simply give us the human-simulator circuit? (And thus causal scrubbing doesn't save us.) I'm thinking in particular of cases where the human-simulator is easier to learn than the intended solution.
Of course if you had good interpretability, a way to realise whether your explanation is the human simulator is to look for suspicious human-simulator-related features. I would like to get away without interpretation, but it's not clear to me that this works.