CRG
Sid Black*, Lee Sharkey*, Leo Grinsztajn, Eric Winsor, Dan Braun, Jacob Merizian, Kip Parker, Carlos Ramón Guevara, Beren Millidge, Gabriel Alfour, Connor Leahy
*equal contribution
Research from Conjecture.
This post benefited from feedback from many staff at Conjecture, including Adam Shimi, Nicholas Kees Dupuis, Dan Clothiaux, and Kyle McDonell. The post also benefited from input from Jessica Cooper, Eliezer Yudkowsky, Neel Nanda, Andrei Alexandru, Ethan Perez, Jan Hendrik Kirchner, Chris Olah, Nelson Elhage, David Lindner, Evan R Murphy, Tom McGrath, Martin Wattenberg, Johannes Treutlein, Spencer Becker-Kahn, Leo Gao, John Wentworth, and Paul Christiano, and from discussions with many other colleagues working on interpretability.
Mechanistic interpretability aims to explain what a neural network has learned at a nuts-and-bolts level. What are the fundamental primitives of neural network representations? What basic objects should we...
This is a great approach imo. I've tried something similar in transformers, using the singular vectors of the embedding matrix (the d_model x d_model matrix) to rotate the matrices connected to the residual stream. This seemed to induce sparsity in the weights close to the first layer, with the effect decreasing deeper into the model. I tried this with CLIP ViT-B and GPT-J, and the effect was a lot weaker in GPT-J. Also, some of the singular vectors of the embeddings were easily interpretable: the top component was related to raw token frequency, there were interesting directions in GPT-J such as (religion - technology) and (positive - negative valence), and the top components of CLIP were color and frequency filters.
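To make the rotation concrete, here is a minimal NumPy sketch on random toy matrices. It assumes the rotation uses the right singular vectors of the embedding matrix, and that matrices reading from the residual stream are multiplied by V^T while matrices writing to it are multiplied by V; the names E, W_in, W_out and the sizes are made up for illustration, and layernorm is ignored entirely.

```python
import numpy as np

rng = np.random.default_rng(0)
vocab, d_model, d_hidden = 50, 8, 16  # toy sizes, chosen only for illustration

# Hypothetical weights: an embedding, one matrix reading from the residual
# stream, and one matrix writing back into it.
E = rng.normal(size=(vocab, d_model))
W_in = rng.normal(size=(d_model, d_hidden))
W_out = rng.normal(size=(d_hidden, d_model))

# SVD of the embedding; V is a d_model x d_model orthogonal matrix whose
# columns are the right singular vectors of E.
_, S, Vt = np.linalg.svd(E, full_matrices=False)
V = Vt.T

# Express the residual stream in the singular-vector basis.
E_rot = E @ V            # embeddings now live in the rotated basis
W_in_rot = V.T @ W_in    # readers undo the rotation before reading
W_out_rot = W_out @ V    # writers emit directly in the rotated basis

# The computation is unchanged: reading rotated embeddings with the
# rotated reader gives the same hidden activations as before.
assert np.allclose(E_rot @ W_in_rot, E @ W_in)
```

In this basis the columns of E_rot have norms equal to the singular values S, so the first coordinates of the residual stream carry most of the embedding variance. Whether the rotated weights of a real model actually look sparser is an empirical question that this toy example does not address.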
WD is not really about regularisation nowadays, so it's not surprising that it helps at all model sizes. Layernorm in transformers means WD mostly affects the effective LR of the weights (except for the final linear layer, the absolute scale of the weights doesn't matter, since you have a final LN), so the actual effect of WD is keeping the update/weight ratio bigger over training. (In fact, in normed nets you can replace WD with an exponentially increasing LR schedule.)
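A small numerical sketch of why weight scale only matters through the effective LR when a layer feeds into a normalisation. It assumes a toy setup that is not from the comment: a single weight matrix followed by a parameter-free normalisation and a squared-error loss, with gradients taken by finite differences. The helper names are hypothetical.

```python
import numpy as np

def normalize(z):
    # Parameter-free normalisation: invariant to rescaling z by any c > 0.
    return (z - z.mean()) / z.std()

def loss(W, x, target):
    # Toy loss on the normalised output of a linear layer.
    return float(np.sum((normalize(W @ x) - target) ** 2))

def numerical_grad(W, x, target, h=1e-6):
    # Central finite-difference gradient of the loss with respect to W.
    grad = np.zeros_like(W)
    for idx in np.ndindex(*W.shape):
        W_plus, W_minus = W.copy(), W.copy()
        W_plus[idx] += h
        W_minus[idx] -= h
        grad[idx] = (loss(W_plus, x, target) - loss(W_minus, x, target)) / (2 * h)
    return grad

rng = np.random.default_rng(0)
W = rng.normal(size=(4, 3))
x = rng.normal(size=3)
target = rng.normal(size=4)

c = 2.0
g = numerical_grad(W, x, target)
g_scaled = numerical_grad(c * W, x, target)

# Rescaling the weights leaves the loss unchanged...
assert np.isclose(loss(W, x, target), loss(c * W, x, target))
# ...but shrinks the gradient by a factor of c.
assert np.allclose(g_scaled, g / c, rtol=1e-4, atol=1e-6)
```

With plain SGD the relative update is lr * |grad| / |W|, so in this toy setting doubling the weight norm cuts it by a factor of four. Letting the weights grow therefore acts like decaying the LR, and WD counteracts that by keeping the norm down.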
Yeah, it's not really clear how to apply that specific kind of data pruning (straightforward for an image classifier) to the case of causally modelling text tokens in full context windows or any other dense task like that.
The layernorm does in fact have parameters: each one has a scale and a shift vector, both of size d_model. This adds 2 x d_model parameters per block, plus an extra 2 x d_model for the final layernorm at the unembedding.
LN(x) = (x-mean(x))/std(x) * scale + shift
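The formula and the parameter count can be sketched in NumPy as follows. The eps term is an assumption added for numerical stability and is not part of the formula above; the function names, the default of one layernorm per block, and the example sizes are illustrative.

```python
import numpy as np

def layer_norm(x, scale, shift, eps=1e-5):
    # LN(x) = (x - mean(x)) / std(x) * scale + shift, over the d_model axis.
    # eps is an assumed stabiliser so that std is never exactly zero.
    mean = x.mean(axis=-1, keepdims=True)
    std = np.sqrt(x.var(axis=-1, keepdims=True) + eps)
    return (x - mean) / std * scale + shift

def layer_norm_param_count(d_model, n_blocks, ln_per_block=1):
    # Each layernorm holds a scale and a shift vector of size d_model.
    # ln_per_block=1 matches the "2 x d_model per block" count in the text;
    # architectures with two layernorms per block would pass 2.
    per_block = 2 * d_model * ln_per_block
    final_ln = 2 * d_model  # the final layernorm before the unembedding
    return n_blocks * per_block + final_ln

d_model = 4
x = np.arange(12.0).reshape(3, d_model)
out = layer_norm(x, scale=np.ones(d_model), shift=np.zeros(d_model))
```

With unit scale and zero shift, each row of out has mean close to 0 and standard deviation close to 1. For example sizes d_model = 4096 and n_blocks = 28, layer_norm_param_count gives 2 x 4096 x 29 = 237568 parameters.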