While lurking on LessWrong, I read Apple's "The Super Weight in Large Language Models" paper and OpenAI's "Weight-sparse transformers have interpretable circuits" paper. My question was simple: could the core ideas of the two papers be bridged to explore a new direction, namely:
If I destroy the model's ability to generate responses by ablating a "superweight", can I then fix it with a tiny row-level patch?
I do not own a GPU, so I tested this on my MacBook (running it overnight!). Rather than focusing on a toy model, I used a real one (OLMo-1B), following @Neel Nanda's advice to get messy with real weights.
I used AllenAI's OLMo-1B-0724-hf, running the model on CPU with 16 GB of RAM. Inspired by the LLMSuperWeight repository, I nabbed the two OLMo-1B superweight coordinates featured there and zeroed out one of them: model.layers.1.mlp.down_proj.weight[1764, 1710].
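For reference, here is a minimal sketch of the setup, assuming the standard Hugging Face checkpoint and loading path (the repo's actual code may differ):

```python
# Sketch: load OLMo-1B and zero out the superweight identified by LLMSuperWeight.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "allenai/OLMo-1B-0724-hf"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)  # fp32 on CPU
model.eval()

# Ablate the superweight: layer 1, down_proj, coordinates [1764, 1710].
with torch.no_grad():
    model.model.layers[1].mlp.down_proj.weight[1764, 1710] = 0.0
```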
Trying out the ablated model on a Wikitext-2 slice yielded the following results:
OLMo Prompt & Outputs (Ablated)
Prompt: "Paris is in France. Tokyo is in".
Output: "Paris is in France. Tokyo is in the ocean. The ocean. The ocean. the. the...".
Prompt: "The capital of Canada is".
Output: "The capital of Canada is a lot of people. leaves of of of of...".
Crunching the numbers (NLL, PPL, KL), I found that perplexity skyrocketed from 17.4 to 2884.3 and that the model spiralled into pathological behaviour with unusual confabulations. Most noticeably, when asked about Tokyo, the model claimed it was in the ocean, on top of outputting "mar, mar, mar" many times in succession.
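For completeness, the perplexity numbers can be computed on a Wikitext-2 slice roughly as follows (the slice size, context length, and split here are illustrative, not necessarily what the repo uses):

```python
# Sketch: perplexity over a Wikitext-2 slice, chunked into fixed-length windows.
import math
import torch
from datasets import load_dataset

def perplexity(model, tokenizer, max_chars=50_000, window=512):
    text = "\n\n".join(
        t for t in load_dataset("wikitext", "wikitext-2-raw-v1", split="test")["text"]
        if t.strip()
    )[:max_chars]
    ids = tokenizer(text, return_tensors="pt").input_ids
    total_nll, n_tokens = 0.0, 0
    for i in range(0, ids.size(1), window):
        chunk = ids[:, i : i + window]
        if chunk.size(1) < 2:
            break
        with torch.no_grad():
            out = model(chunk, labels=chunk)  # HF shifts labels internally
        total_nll += out.loss.item() * (chunk.size(1) - 1)
        n_tokens += chunk.size(1) - 1
    return math.exp(total_nll / n_tokens)
```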
To fix the broken model, I froze the entire model and introduced a single trainable vector, added to row 1764 of the down-projection matrix (the superweight's row), like so:
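(A minimal sketch of one way to wire this in; the repo's actual implementation may differ.)

```python
# Sketch: a frozen down_proj plus a trainable delta on row 1764 of its weight.
import torch
import torch.nn as nn

class RowPatch(nn.Module):
    """Frozen Linear layer plus a trainable delta added to one row of its weight."""
    def __init__(self, linear: nn.Linear, row: int):
        super().__init__()
        self.linear = linear
        for p in self.linear.parameters():
            p.requires_grad_(False)
        # delta has one entry per input feature; it plays the role of W[row, :] += delta.
        self.delta = nn.Parameter(torch.zeros(linear.in_features))
        one_hot = torch.zeros(linear.out_features)
        one_hot[row] = 1.0
        self.register_buffer("row_one_hot", one_hot)

    def forward(self, x):
        # (W + e_row delta^T) x = W x + (delta . x) e_row
        return self.linear(x) + (x @ self.delta).unsqueeze(-1) * self.row_one_hot

# Freeze everything, then swap the patched layer into the ablated model from above.
for p in model.parameters():
    p.requires_grad_(False)
mlp = model.model.layers[1].mlp
mlp.down_proj = RowPatch(mlp.down_proj, row=1764)
```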
After doing this, I trained with the original model as teacher and the broken model (plus patch) as student, on train[:2%] of the Wikitext-2-raw-v1 dataset, filtering out empty lines. The loss was the token-shifted KL divergence between teacher and student logits. Severe compute constraints (CPU only) limited me to two epochs of training with batch size 2 (thanks to my RAM) and 200 steps per epoch, totalling 400 optimization steps. This produced a nice downward trend in the loss (see the top-right graph below).
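A sketch of such a distillation loop, with illustrative hyperparameters (`teacher` here is a fresh, unmodified copy of OLMo-1B and `student` is the ablated model carrying the RowPatch from above):

```python
# Sketch: KL-distillation of the patched student towards the original teacher.
import torch
import torch.nn.functional as F
from datasets import load_dataset
from torch.utils.data import DataLoader
from transformers import AutoModelForCausalLM

teacher = AutoModelForCausalLM.from_pretrained(model_name)  # unmodified copy
teacher.eval()
student = model  # ablated model with the RowPatch

lines = [t for t in load_dataset("wikitext", "wikitext-2-raw-v1", split="train[:2%]")["text"]
         if t.strip()]

if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token  # needed for padded batches

def collate(batch):
    return tokenizer(batch, return_tensors="pt", padding=True,
                     truncation=True, max_length=128)

loader = DataLoader(lines, batch_size=2, shuffle=True, collate_fn=collate)
optimizer = torch.optim.Adam([p for p in student.parameters() if p.requires_grad], lr=1e-3)

for epoch in range(2):
    for step, batch in enumerate(loader):
        if step >= 200:
            break
        with torch.no_grad():
            t_logits = teacher(**batch).logits[:, :-1]  # drop the final position
        s_logits = student(**batch).logits[:, :-1]
        # KL(teacher || student) per position, ignoring padded tokens.
        kl = F.kl_div(F.log_softmax(s_logits, dim=-1),
                      F.log_softmax(t_logits, dim=-1),
                      log_target=True, reduction="none").sum(-1)
        mask = batch["attention_mask"][:, :-1].float()
        loss = (kl * mask).sum() / mask.sum()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
```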
Surprisingly, it worked! Perplexity dropped drastically (from 2884.3 to 25.2), approaching the original model's value, and the Tokyo prompt was fixed, as seen in the example outputs below. All in all, I observed approximately 93% recovery.
OLMo Prompt & Outputs (Patched)
My hypothesis was that the patch would simply relearn the original weight I had deleted. That turned out not to be the case: the cosine similarity between my patch and the original row was only around 0.13, and the patch learned a genuinely new direction (norm of about 3.0) to compensate for the zeroed value. When I sparsified the patch (keeping only the top-16 entries), performance degraded significantly, suggesting the patch acts as a distributed circuit, spreading the repair across many non-zero entries rather than through a single super scalar.
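This analysis can be reproduced with something along these lines (reusing `teacher`, `student`, `tokenizer`, and the `perplexity` helper from the earlier sketches):

```python
# Sketch: compare the learned delta to the original row, then sparsify to top-16 entries.
import torch
import torch.nn.functional as F

original_row = teacher.model.layers[1].mlp.down_proj.weight[1764].detach()
delta = student.model.layers[1].mlp.down_proj.delta.detach()

cos = F.cosine_similarity(delta, original_row, dim=0).item()
print(f"cosine(delta, original row) = {cos:.3f}, ||delta|| = {delta.norm().item():.2f}")

# Keep only the 16 largest-magnitude entries of the patch and re-measure perplexity.
keep = torch.topk(delta.abs(), k=16).indices
sparse_delta = torch.zeros_like(delta)
sparse_delta[keep] = delta[keep]
with torch.no_grad():
    student.model.layers[1].mlp.down_proj.delta.copy_(sparse_delta)
print("PPL with top-16 patch:", perplexity(student, tokenizer))
```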
I also dug into the broken model's weird behaviour by looking at the tokens themselves. They were seemingly all marine biology terms (lobster, North Sea, etc.), leading me to believe I had ablated a "water neuron" that the patch then attempted to rebuild. This would explain the odd "mar, marina, maritimes" outputs from the broken model.
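One simple way to do this kind of token-level check is to compare next-token logits between the original model and an ablated-but-unpatched copy, and decode the tokens whose logits rise the most (a sketch; the actual analysis in the repo may differ):

```python
# Sketch: which tokens does the ablation boost? Compare logits of an ablated
# (unpatched) copy against the original model on a neutral prompt.
import torch
from transformers import AutoModelForCausalLM

ablated = AutoModelForCausalLM.from_pretrained(model_name)
with torch.no_grad():
    ablated.model.layers[1].mlp.down_proj.weight[1764, 1710] = 0.0
ablated.eval()

prompt = "Paris is in France. Tokyo is in"
ids = tokenizer(prompt, return_tensors="pt").input_ids
with torch.no_grad():
    shift = ablated(ids).logits[0, -1] - teacher(ids).logits[0, -1]
top_ids = torch.topk(shift, k=20).indices
print([tokenizer.decode([i]) for i in top_ids.tolist()])
```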
This whole experiment was done independently on a CPU in 16 hours while applying to MATS. As a recent graduate, I am currently looking for Research Engineer roles; if you or anyone you know is seeking someone who enjoys digging into model internals (even without a GPU), don't hesitate to reach out!
Code: https://github.com/sunmoonron/super-weight-circuit-patching