I wonder if this is relevant to the SGTM paper Anthropic put out recently? Could this be done to reverse the ablation they did? This method seems not to rely much on the old information, which Anthropic asserts they destroyed. Separately, I wonder what happens when you ablate and patch serially several times? Hundreds of times?
Also, the nature of the patch here strongly reminds me of neuroplasticity, and how when a region of the brain is damaged, other regions adjust to pick up the slack.
Jenna, thank you for commenting and sharing that paper. I read through it and it is quite closely related to my solo work (albeit heading in the opposite direction). Indeed, if the patch does not utilize the old information, then Anthropic's SGTM safety measures might be less permanent than they think. My experiment suggests the new circuit is orthogonal/distributed rather than a restoration of the old weight. I suspect what you're hinting at is a new red-team attack vector: instead of healing the ablated damage (which SGTM defends against), an attacker could use this method to graft a cheap, sparse patch that bypasses the damage entirely.
As for your latter curiosity, I believe the default outcome would be total model collapse once the model runs out of lazy/spare neurons to repurpose. Of course, I could be wrong, and instead, following your intuition, a hyper-robust model could emerge. I also really like your neuroplasticity analogy! It is very fitting, given that the model does not fix the dead tissue; it reroutes the function to new areas of the brain.
P.S. Unfortunately, I am without a GPU cluster, so perhaps you can carry forth the torch :)
Motivation
While lurking LessWrong, I read Apple's "The Super Weight in Large Language Models" paper and OpenAI's "Weight-sparse transformers have interpretable circuits" paper. My curiosity was simple: could the core ideas of the two papers be bridged to explore a new direction, namely:
If I destroy the model's ability to generate responses by ablating a "superweight", can I then fix it with a tiny row-level patch?
I do not own a GPU, so I tested this on my MacBook (running it overnight!). Rather than focusing on a toy model, I used a real one (OLMo-1B), following @Neel Nanda's advice to get messy with real weights.
Setup
I used AllenAI's OLMo-1B-0724-hf, running the model on my CPU with 16GB of RAM. Guided by the LLMSuperWeight repository, I took the two OLMo-1B superweights featured there and chose one to zero out: model.layers.1.mlp.down_proj.weight[1764, 1710].
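For concreteness, here is a minimal sketch of the ablation step. It assumes the Hugging Face transformers API; the attribute path below is how the HF OLMo checkpoint exposes that parameter and may differ slightly across library versions.

```python
# Minimal sketch of the ablation: load OLMo-1B and zero out the superweight entry.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

MODEL_NAME = "allenai/OLMo-1B-0724-hf"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, torch_dtype=torch.float32)
model.eval()

with torch.no_grad():
    down_proj = model.model.layers[1].mlp.down_proj.weight
    original_row = down_proj[1764].clone()  # keep the pre-ablation row for later analysis
    down_proj[1764, 1710] = 0.0             # zero the superweight
```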
Initial Findings (Validating Apple's Results)
Trying out the ablated model on a Wikitext-2 slice yielded the following results:
OLMo Prompt & Outputs (Ablated)
Prompt: "Paris is in France. Tokyo is in".
Output: "Paris is in France. Tokyo is in the ocean. The ocean. The ocean. the. the...".
Prompt: "The capital of Canada is".
Output: "The capital of Canada is a lot of people. leaves of of of of...".
Crunching the numbers (NLL, PPL, KL), I observed that perplexity skyrocketed from 17.4 to 2884.3 and that the model spiralled into pathological behaviour with unusual confabulations. Most noticeably, when asked about Tokyo, the model claimed it is in the ocean, on top of outputting "mar, mar, mar" many times in succession.
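For reference, the perplexity numbers come from a quick pass along these lines (a sketch rather than the exact evaluation code in the repo; the split, line count, and sequence length here are illustrative choices):

```python
# Rough Wikitext-2 perplexity check for whichever model is passed in.
import math
import torch
from datasets import load_dataset

@torch.no_grad()
def wikitext_ppl(model, tokenizer, n_lines=200, max_len=256):
    data = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
    texts = [t for t in data["text"] if t.strip()][:n_lines]
    total_nll, total_tokens = 0.0, 0
    for text in texts:
        enc = tokenizer(text, return_tensors="pt", truncation=True, max_length=max_len)
        if enc["input_ids"].shape[1] < 2:
            continue
        out = model(**enc, labels=enc["input_ids"])
        n_tok = enc["input_ids"].shape[1] - 1  # labels are shifted by one position
        total_nll += out.loss.item() * n_tok
        total_tokens += n_tok
    return math.exp(total_nll / total_tokens)

print("ablated PPL:", wikitext_ppl(model, tokenizer))
```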
The Patch
To fix the broken model, I froze the entire model and introduced a single trainable vector Δrow, added to the superweight's row 1764 of the down-projection matrix, like so:
$$W' = W_{\text{broken}} + e_{\text{row}} \otimes \Delta_{\text{row}}$$
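In code, a sketch of the patch looks like wrapping the frozen down-projection so that only Δrow is trainable (the class and variable names here are mine, not necessarily the repo's):

```python
# Rank-1 row patch: the whole model stays frozen; only delta_row is trainable,
# and it is added to output row 1764 of the ablated down-projection.
import torch
import torch.nn as nn

class RowPatchedLinear(nn.Module):
    def __init__(self, frozen_linear: nn.Linear, row_idx: int):
        super().__init__()
        self.base = frozen_linear
        for p in self.base.parameters():
            p.requires_grad_(False)  # W_broken stays fixed
        self.row_idx = row_idx
        self.delta_row = nn.Parameter(torch.zeros(frozen_linear.in_features))

    def forward(self, x):
        base_out = self.base(x)      # x @ W_broken.T
        patch = x @ self.delta_row   # extra contribution to the patched output channel
        one_hot = torch.zeros(base_out.shape[-1], dtype=x.dtype, device=x.device)
        one_hot[self.row_idx] = 1.0
        # Equivalent to applying W' = W_broken + e_row ⊗ Δ_row without copying W.
        return base_out + patch.unsqueeze(-1) * one_hot

for p in model.parameters():         # freeze everything else
    p.requires_grad_(False)
mlp = model.model.layers[1].mlp
mlp.down_proj = RowPatchedLinear(mlp.down_proj, row_idx=1764)
```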
I then trained the patch using the original model as teacher and the broken model (plus patch) as student on train[:2%] of the Wikitext-2-raw-v1 dataset, filtering out empty lines. The loss was the token-shifted KL divergence between teacher and student logits. Severe compute constraints (CPU only) limited me to two epochs of training with batch size 2 (courtesy of my RAM) and 200 steps per epoch, totalling 400 optimization steps. This led to a nice downward trend in the loss (see the top-right graph below).
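A sketch of that distillation loop, building on the snippets above (the learning rate, optimizer, and max sequence length are my guesses, not the exact settings in the repo):

```python
# Distillation: teacher = untouched OLMo-1B, student = ablated model + row patch.
# Loss = KL(teacher || student) over logits, dropping the final position, whose
# prediction has no target token. Padding positions are included for simplicity.
import torch
import torch.nn.functional as F
from datasets import load_dataset
from torch.utils.data import DataLoader
from transformers import AutoModelForCausalLM

teacher = AutoModelForCausalLM.from_pretrained(MODEL_NAME, torch_dtype=torch.float32)
teacher.eval()
student = model  # the ablated model with RowPatchedLinear spliced in

texts = [t for t in load_dataset("wikitext", "wikitext-2-raw-v1",
                                 split="train[:2%]")["text"] if t.strip()]
loader = DataLoader(texts, batch_size=2, shuffle=True)

tokenizer.pad_token = tokenizer.pad_token or tokenizer.eos_token
optimizer = torch.optim.Adam([mlp.down_proj.delta_row], lr=1e-3)  # lr is a guess

for epoch in range(2):
    for step, batch in enumerate(loader):
        if step >= 200:  # 200 steps per epoch, 400 in total
            break
        enc = tokenizer(list(batch), return_tensors="pt", padding=True,
                        truncation=True, max_length=256)
        with torch.no_grad():
            t_logits = teacher(**enc).logits[:, :-1]
        s_logits = student(**enc).logits[:, :-1]
        loss = F.kl_div(F.log_softmax(s_logits, dim=-1),
                        F.log_softmax(t_logits, dim=-1),
                        log_target=True, reduction="batchmean")
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
```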
Surprisingly, it worked! Perplexity dropped drastically (from 2884.3 to 25.2), nearing the original model's value, and the Tokyo prompt was fixed, as seen in the example outputs below. All in all, I observed approximately 93% recovery.
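One way to put a number on this, consistent with the figures above, is recovery measured on a log-perplexity (NLL) scale:

$$\text{recovery} = \frac{\log \mathrm{PPL}_{\text{ablated}} - \log \mathrm{PPL}_{\text{patched}}}{\log \mathrm{PPL}_{\text{ablated}} - \log \mathrm{PPL}_{\text{original}}} = \frac{\log 2884.3 - \log 25.2}{\log 2884.3 - \log 17.4} \approx 0.93$$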
OLMo Prompt & Outputs (Patched)
Gaining Insights
My hypothesis was that the patch would simply relearn the original weight I deleted. This turned out not to be the case: the cosine similarity between my patch and the original row was only around 0.13. Instead, the patch learned a completely new direction (with a norm of about 3.0) to compensate for the zeroed value. When I sparsified the patch (keeping only its Top-16 entries), performance degraded significantly, suggesting that the patch acts as a distributed circuit: it spreads the repair across many non-zero entries rather than through a single super scalar.
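A sketch of that analysis, reusing the original_row saved before ablation and the wikitext_ppl helper from earlier (variable names are mine):

```python
# Compare the learned patch to the original (pre-ablation) row, then sparsify it
# to its Top-16 entries by magnitude and re-measure perplexity.
import torch
import torch.nn.functional as F

delta = mlp.down_proj.delta_row.detach()
cos = F.cosine_similarity(delta, original_row, dim=0)
print(f"cosine(delta, original row) = {cos.item():.2f}, ||delta|| = {delta.norm().item():.2f}")

k = 16
topk = torch.topk(delta.abs(), k).indices
sparse_delta = torch.zeros_like(delta)
sparse_delta[topk] = delta[topk]

with torch.no_grad():
    mlp.down_proj.delta_row.copy_(sparse_delta)  # keep only the Top-16 entries
print("Top-16 patch PPL:", wikitext_ppl(student, tokenizer))
```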
I also analysed the broken model's weird behaviour by checking out the tokens themselves. They were seemingly all marine terms (e.g. lobster, North Sea), leading me to believe I had ablated the "water neuron" that the patch then attempted to rebuild. This would also explain the odd "mar, marina, maritimes" outputs from the broken model.
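A rough way to surface those tokens (a reconstruction of the spirit of the check, not the repo's code; ablated_model stands for a fresh copy of OLMo-1B with the superweight re-zeroed as in the Setup sketch, before any patching):

```python
# Compare next-token distributions of the original and ablated (unpatched) models
# and list the tokens whose probability rises the most after ablation.
import torch

prompt = "The capital of Canada is"
enc = tokenizer(prompt, return_tensors="pt")
with torch.no_grad():
    p_original = torch.softmax(teacher(**enc).logits[0, -1], dim=-1)
    p_ablated = torch.softmax(ablated_model(**enc).logits[0, -1], dim=-1)

boosted = (p_ablated - p_original).topk(20).indices
print(tokenizer.convert_ids_to_tokens(boosted.tolist()))
```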
Conclusion
This whole experiment was done independently on a CPU in 16 hours while applying to MATS. As a recent graduate, I am currently looking for Research Engineer roles; if you or anyone you know is seeking someone who enjoys digging into model internals (despite having no GPU), don't hesitate to reach out!
Code: https://github.com/sunmoonron/super-weight-circuit-patching