TL;DR
One major obstacle to interpretability is that complicated neural nets don't tell you where or how they represent important concepts, and our methods for finding these representations are imperfect.
This problem is less severe in simple neural networks, so one natural idea is to initialize a complicated neural net from a much simpler, interpretable one and hope that this makes the complicated net more interpretable without damaging its capacity.
I ran a test that could have ruled this out: I checked whether even the representation persists under this initialization scheme, because if it doesn't, there's not much hope that the circuits do. I found a small effect in the predicted direction, but couldn't rule out other explanations, so I'm pretty unsure whether the underlying mechanism is favorable to interpretability.
Hypothesis
We usually think of transfer learning as a way of taking a big powerful model and making it very good at a specific type of task, but we might also want to take a weak model and use it as a starting point to train a bigger, more powerful model, as in Net2Net knowledge transfer;[1]essentially, take your small model, do some math to find a way to add parameters to it without changing what it does, then train those new parameters in conjunction with the old ones, typically at a lower learning rate. But this doesn't help with interpretability - the big powerful model is already hard to understand, so we've traded a hard problem for a hard problem. What can we do?
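For readers who haven't seen Net2Net, here is a minimal sketch of the idea behind its function-preserving "wider" operation for a single ReLU hidden layer, in plain Python (the function names are illustrative, not taken from the paper's code). Each new hidden unit copies a randomly chosen existing unit, and each unit's outgoing weights are split evenly among its copies, so the widened network computes exactly the same function. This is background only: the embedding used in this experiment works differently, by zeroing the new components' output weights.

```python
import random


def relu(v):
    return [max(0.0, x) for x in v]


def matvec(W, x):
    # W is a list of rows; returns W @ x.
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]


def mlp(W1, b1, W2, x):
    """One ReLU hidden layer: y = W2 @ relu(W1 @ x + b1)."""
    h = relu([a + b for a, b in zip(matvec(W1, x), b1)])
    return matvec(W2, h)


def net2wider(W1, b1, W2, new_width, rng):
    """Widen the hidden layer to new_width units while computing the same function.

    Shapes: W1 is hidden x input, b1 has length hidden, W2 is output x hidden.
    """
    old_width = len(W1)
    # Wide unit j copies narrow unit mapping[j]: the first old_width units map
    # to themselves, and the extra units copy randomly chosen existing units.
    mapping = list(range(old_width)) + [
        rng.randrange(old_width) for _ in range(new_width - old_width)
    ]
    counts = [mapping.count(j) for j in range(old_width)]
    W1_new = [list(W1[j]) for j in mapping]
    b1_new = [b1[j] for j in mapping]
    # Split each unit's outgoing weights evenly among its copies so that their
    # contributions to the next layer sum to the original contribution.
    W2_new = [[row[j] / counts[j] for j in mapping] for row in W2]
    return W1_new, b1_new, W2_new


rng = random.Random(0)
W1 = [[rng.gauss(0, 1) for _ in range(3)] for _ in range(4)]
b1 = [rng.gauss(0, 1) for _ in range(4)]
W2 = [[rng.gauss(0, 1) for _ in range(4)] for _ in range(2)]
x = [0.5, -1.0, 2.0]
W1w, b1w, W2w = net2wider(W1, b1, W2, new_width=7, rng=rng)
print(mlp(W1, b1, W2, x))
print(mlp(W1w, b1w, W2w, x))  # same outputs, up to floating-point rounding
```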
Say I want to train a model on some task I know to be pretty difficult. Say I have a guess for an instrumentally useful, easier, but still nontrivial subtask. I know, because I've learned the Bitter Lesson[2], that I shouldn't put a loss term anywhere in my model for this subtask - this will hurt performance in the long run. But what if I train a small model on the subtask, embed that small model into a large model somehow, and train the large model on the main task? We usually think of transfer learning as helping us specialize generalist networks, but there's no reason it can't work the other way around.
The effect, we hope, is this: the smaller network has developed circuits that are useful for understanding the domain, so subnetworks that include the smaller network are much more likely to be good at the main task. The weights we overwrote were random junk, and we replaced them with something that's at least plausibly not junk. Usually this should make the model better than it would be with random initialization, even if the subtask is not perfectly instrumental.
What might this get us? In terms of capabilities, we might get faster convergence (this is basically just saying that transfer learning works) and mildly better performance at convergence (the original lottery ticket hypothesis paper[3] finds evidence that better initialization can induce better long-term performance). We're spending compute training the smaller network, though, and on average we're probably better off putting all of that compute into the main model rather than doing some sort of matryoshka scheme, so we shouldn't expect to unlearn the Bitter Lesson with this approach.
In terms of interpretability, we can hope for more. Imagine, for example, training a small text transformer to perform sentiment analysis, then embedding that transformer into a larger text model for next token prediction. For combinatorial reasons, the model is likely to build circuits that factor through the circuits we've just given it - training builds circuits out of things that already somewhat resemble circuits, and having small parts that are guaranteed to resemble circuits makes this significantly easier. For proximity reasons, the large model is now more likely to put its own sentiment analysis right where the embedding ends. After all, it's already using those circuits and they're already well-adapted to that subtask! There are many things that could go wrong in this story, but my hypothesis is that they don't need to go wrong, and at least in some cases we can influence a large model's representation of a concept we care about using this approach.
Unfortunately, finding circuits is hard, so this experiment is designed to avoid doing the hard thing if it's unnecessary. Say I train the smaller model to do the task of the larger model, but with some easy-to-compute quantity linearly encoded somewhere in its representation space. If I embed that model and train without the linear encoding constraint, then, if this approach can work at all, I should expect some amount of that linear encoding to persist in the residual stream at the corresponding point. If this doesn't happen, then either the large model completely ignored the smaller model or it repurposed the smaller model's circuits for an entirely different task, and either way we can't hope for any interpretability gains. On the other hand, if there is a persistent difference in the linear encoding of the relevant quantity, more work on interpretability proper is justified.
Experiment
The domain is the combinatorial game Domineering[4] on a 16×16 board. I'm using Domineering for three reasons. One, I already had a fast implementation lying around, so I saved myself some work. Two, the game isn't that complicated, and I wanted to write this up on a relatively short timeframe so I can include it in my applications for summer research programs. (I had initially planned to do this and other AI interpretability work over the summer on my own, but decided recently that I'd get better faster, and probably produce better work, if I applied to things.) Three, it was easy to think of an auxiliary task that is plausibly useful, easy to compute, and seems to promote particular ways of structuring the representation that we might have some hope of detecting.
The Auxiliary Task
We divide the board into a 4×4 grid of 4×4 sectors. For each sector, the auxiliary target is the difference between the number of legal vertical moves and the number of legal horizontal moves in that sector (where a move is "in a sector" if the top-left square it covers is in that sector). The small network is trained to predict these sector values alongside the main value and policy objectives. The large network is not trained on this task - we only probe it to see whether the representation persists from the embedding.
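A minimal sketch of the sector targets as described above, assuming a hypothetical board encoding as a 16×16 grid of booleans (True = occupied). It counts each player's available moves regardless of whose turn it is, which is an assumption about what "legal" means here; the function name and encoding are illustrative, not taken from the project's code.

```python
SIZE = 16    # board is SIZE x SIZE
SECTOR = 4   # sectors are SECTOR x SECTOR, giving a 4x4 grid of sectors


def sector_targets(board):
    """Per-sector (#legal vertical moves - #legal horizontal moves).

    board[r][c] is True when the square is occupied (hypothetical encoding).
    A move counts toward the sector containing its top-left square.
    """
    n = SIZE // SECTOR
    targets = [[0] * n for _ in range(n)]
    for r in range(SIZE):
        for c in range(SIZE):
            if board[r][c]:
                continue
            sr, sc = r // SECTOR, c // SECTOR
            if r + 1 < SIZE and not board[r + 1][c]:  # vertical domino on (r, c), (r + 1, c)
                targets[sr][sc] += 1
            if c + 1 < SIZE and not board[r][c + 1]:  # horizontal domino on (r, c), (r, c + 1)
                targets[sr][sc] -= 1
    return targets


empty = [[False] * SIZE for _ in range(SIZE)]
for row in sector_targets(empty):
    print(row)
# [0, 0, 0, 4] three times, then [-4, -4, -4, 0]
```

On an empty board, interior sectors come out balanced; bottom-row sectors lose vertical moves and right-column sectors lose horizontal moves, because dominoes can't hang off the edge of the board.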
Data
Data was generated by self-play from a weak model, trained to predict the value of a given position, with 1-ply lookahead as the search. I bootstrapped this model with some randomly-generated games. This is not a particularly high-quality dataset, it was just what I could generate for the board size I wanted with the amount of time and compute I was willing to dedicate to this project. It's possible the results would change with higher-quality data.
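A rough sketch of this kind of data-generation loop: self-play where each move is chosen by 1-ply lookahead against a value function, after a short random opening. Everything here is a hypothetical stand-in: `mobility_value` replaces the trained value network, the negamax sign convention and the `random_opening` length are assumptions, and having Vertical move first is an arbitrary choice. The function returns the visited positions and the winner, which could serve as value targets (also an assumption about how labels were formed). Under Domineering's normal-play rule, the player who cannot move loses.

```python
import random


def legal_moves(board, vertical):
    """Top-left squares (r, c) where the player can place a domino."""
    n = len(board)
    moves = []
    for r in range(n):
        for c in range(n):
            r2, c2 = (r + 1, c) if vertical else (r, c + 1)
            if r2 < n and c2 < n and not board[r][c] and not board[r2][c2]:
                moves.append((r, c))
    return moves


def play(board, move, vertical):
    """Return a new board with the domino placed; the input board is not modified."""
    r, c = move
    new = [row[:] for row in board]
    new[r][c] = True
    if vertical:
        new[r + 1][c] = True
    else:
        new[r][c + 1] = True
    return new


def mobility_value(board, vertical):
    """Hypothetical stand-in for the value network: my moves minus the opponent's."""
    return len(legal_moves(board, vertical)) - len(legal_moves(board, not vertical))


def choose_move_1ply(board, vertical, value_fn, rng):
    """1-ply lookahead: pick the move whose resulting position is worst for the opponent."""
    best, best_score = [], None
    for move in legal_moves(board, vertical):
        # value_fn scores a position for the player to move, so negate it.
        score = -value_fn(play(board, move, vertical), not vertical)
        if best_score is None or score > best_score:
            best, best_score = [move], score
        elif score == best_score:
            best.append(move)
    return rng.choice(best) if best else None


def self_play_game(value_fn, rng, size=16, random_opening=4):
    """Play one game; return ([(board, vertical_to_move), ...], vertical_won)."""
    board = [[False] * size for _ in range(size)]
    vertical, positions, ply = True, [], 0
    while True:
        moves = legal_moves(board, vertical)
        if not moves:
            return positions, not vertical  # the player to move is stuck and loses
        positions.append((board, vertical))
        if ply < random_opening:
            move = rng.choice(moves)  # random opening to diversify games
        else:
            move = choose_move_1ply(board, vertical, value_fn, rng)
        board = play(board, move, vertical)
        vertical = not vertical
        ply += 1


rng = random.Random(0)
positions, vertical_won = self_play_game(mobility_value, rng, size=6)
print(len(positions), "positions; vertical won:", vertical_won)
```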
The Embedding
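Given a trained small network and a randomly initialized large network, we copy the small network into layers 0, 1, and 2 of the large network. The tricky part is the fresh components: the new attention heads and MLP neurons that the large network has in each of those layers. Left at their random initialization, they would write noise into the residual stream and disrupt the copied circuits.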
To fix this, we set the relevant output weights to 0. Specifically, for fresh attention heads we zero their part of W_O, and for fresh MLP neurons we zero the corresponding columns of W_out. The input weights (W_Q, W_K, W_V, W_in) stay random.
Why does this work? The residual stream through the embedded layers is now exactly the same as in the small network, because the fresh components contribute nothing. LayerNorm sees the statistics it was trained on, and the copied circuits receive the inputs they expect. But the zeroed output weights still receive gradients (the fresh components' activations are generally nonzero), so the fresh components can wake up and learn during training.
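A minimal sketch of the zero-output-weight embedding for an MLP sublayer, in plain Python, assuming the small and large networks share the residual width d_model. The layer sizes, init scale, and function names are illustrative, and biases are omitted. It checks that the widened MLP writes exactly the same vector into the residual stream, and computes gradients for a linear stand-in loss L = g · output. Fresh attention heads work the same way, with their part of W_O playing the role of the zeroed W_out columns.

```python
import random


def relu(v):
    return [max(0.0, x) for x in v]


def matvec(W, x):
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]


def mlp_forward(W_in, W_out, x):
    """MLP contribution to the residual stream: W_out @ relu(W_in @ x)."""
    return matvec(W_out, relu(matvec(W_in, x)))


def embed_mlp(W_in_small, W_out_small, d_mlp_large, rng):
    """Copy a small MLP into a wider one (W_in: d_mlp x d_model, W_out: d_model x d_mlp).

    Fresh neurons get random input weights and zero output weights.
    """
    d_model = len(W_in_small[0])
    n_fresh = d_mlp_large - len(W_in_small)
    W_in = [row[:] for row in W_in_small] + [
        [rng.gauss(0, d_model ** -0.5) for _ in range(d_model)] for _ in range(n_fresh)
    ]
    W_out = [row[:] + [0.0] * n_fresh for row in W_out_small]
    return W_in, W_out


def mlp_grads(W_in, W_out, x, g):
    """Gradients of L = g . mlp_forward(W_in, W_out, x) w.r.t. W_in and W_out."""
    pre = matvec(W_in, x)
    h = relu(pre)
    grad_W_out = [[gi * hj for hj in h] for gi in g]
    grad_W_in = []
    for j, p in enumerate(pre):
        # Backprop through column j of W_out, then through the ReLU.
        dh = sum(g[i] * W_out[i][j] for i in range(len(g))) if p > 0 else 0.0
        grad_W_in.append([dh * xk for xk in x])
    return grad_W_in, grad_W_out


rng = random.Random(0)
d_model, d_small, d_large = 4, 3, 8
W_in_s = [[rng.gauss(0, 1) for _ in range(d_model)] for _ in range(d_small)]
W_out_s = [[rng.gauss(0, 1) for _ in range(d_small)] for _ in range(d_model)]
W_in_l, W_out_l = embed_mlp(W_in_s, W_out_s, d_large, rng)
x = [rng.gauss(0, 1) for _ in range(d_model)]
print(mlp_forward(W_in_s, W_out_s, x) == mlp_forward(W_in_l, W_out_l, x))  # True

g = [rng.gauss(0, 1) for _ in range(d_model)]  # stand-in upstream gradient dL/d(output)
grad_W_in, grad_W_out = mlp_grads(W_in_l, W_out_l, x, g)
print([any(v != 0.0 for v in grad_W_in[j]) for j in range(d_small, d_large)])
# [False, False, False, False, False]
```

One detail the sketch makes visible: on the first step, a fresh neuron's output weights get nonzero gradients whenever the neuron is active, but its input weights get exactly zero gradient, since that gradient flows back through the zeroed output weights. The input weights only start learning once the output weights have moved away from zero.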
It's plausible that there are ways to make this work without zeroing the output weights, but letting randomly initialized components write into the residual stream from the start would disrupt lots of the copied circuits. It's also plausible that we could embed somewhere other than the front of the model, but this would mess with the learned embeddings, so I just did the thing that I knew wouldn't cause extra problems. Among the options I thought of and had confidence in, this was the minimal set of changes to the big network's initialization.
What We're Testing
We train 5 model types across 3 random seeds:
Small aux: trained with sector loss
Small noaux: trained without sector loss
Large baseline: random init, no embedding
Large embed(aux): Small aux embedded into large network
Large embed(noaux): Small noaux embedded into large network
Large models are never trained with the sector loss. We measure validation loss curves and probe accuracy (R² of a ridge probe predicting sector targets from CLS activations at each layer).
The key question: at layer 2 (the last embedded layer), does the sector representation persist in Large embed(aux) even without direct supervision? My guess is that the network should route computation through the inherited circuits, so the learned representation should have some sort of compatibility with the sector representation. This does not mean that the model will actually use the sector representation as-is, and I don't think we have reason to expect a causal difference along these lines.
Results
[Figure: loss curves] Loss curves on training data and on seed-matched quick samples of the validation data. On the validation chart, Xs mark loss values computed from the full validation set.
[Figure: layer-2 probe R²] R² values for a ridge probe at layer 2 trained to extract the sector difference. Transparent lines show individual training runs; opaque lines show the average.
I was careful about data leakage: the games in the training set and the games in the test set are completely disjoint, and each game gets a random opening to prevent resampling issues. The model appears to generalize fairly well. I was also careful with the quick validation samples, so models from the same seed were evaluated on the same positions at the same point in training. The probe here is a ridge probe with α=1; this choice of α was not tuned but does not seem to matter.
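A minimal pure-Python sketch of a ridge probe with α=1 and its R², using synthetic data as a stand-in for layer-2 CLS activations and sector targets. Assumptions not stated in the post: the intercept is unpenalized, R² is averaged uniformly over the target dimensions, and R² is measured on held-out positions. Function names are illustrative.

```python
import random


def solve(A, B):
    """Solve A X = B (A: n x n, B: n x m) by Gauss-Jordan elimination with partial pivoting."""
    n = len(A)
    M = [A[i][:] + B[i][:] for i in range(n)]
    for col in range(n):
        piv = max(range(col, n), key=lambda r: abs(M[r][col]))
        M[col], M[piv] = M[piv], M[col]
        p = M[col][col]
        M[col] = [v / p for v in M[col]]
        for r in range(n):
            if r != col:
                f = M[r][col]
                M[r] = [a - f * b for a, b in zip(M[r], M[col])]
    return [row[n:] for row in M]


def col_means(rows):
    return [sum(col) / len(rows) for col in zip(*rows)]


def fit_ridge(X, Y, alpha=1.0):
    """Minimize ||Y - X W - b||^2 + alpha * ||W||^2; the intercept b is not penalized."""
    xm, ym = col_means(X), col_means(Y)
    Xc = [[v - m for v, m in zip(row, xm)] for row in X]
    Yc = [[v - m for v, m in zip(row, ym)] for row in Y]
    d, k = len(xm), len(ym)
    XtX = [[sum(r[i] * r[j] for r in Xc) + (alpha if i == j else 0.0) for j in range(d)]
           for i in range(d)]
    XtY = [[sum(rx[i] * ry[o] for rx, ry in zip(Xc, Yc)) for o in range(k)] for i in range(d)]
    W = solve(XtX, XtY)
    b = [ym[o] - sum(xm[i] * W[i][o] for i in range(d)) for o in range(k)]
    return W, b


def predict(W, b, X):
    return [[b[o] + sum(x[i] * W[i][o] for i in range(len(x))) for o in range(len(b))] for x in X]


def r2_uniform(Y_true, Y_pred):
    """R^2 = 1 - SS_res / SS_tot for each output, averaged uniformly over outputs."""
    ym = col_means(Y_true)
    scores = []
    for o in range(len(ym)):
        ss_res = sum((yt[o] - yp[o]) ** 2 for yt, yp in zip(Y_true, Y_pred))
        ss_tot = sum((yt[o] - ym[o]) ** 2 for yt in Y_true)
        scores.append(1.0 - ss_res / ss_tot)
    return sum(scores) / len(scores)


# Synthetic stand-in data: 5-dim "activations", 2 targets, small noise.
rng = random.Random(0)
W_true = [[rng.gauss(0, 1) for _ in range(2)] for _ in range(5)]
X = [[rng.gauss(0, 1) for _ in range(5)] for _ in range(200)]
Y = [[sum(x[i] * W_true[i][o] for i in range(5)) + rng.gauss(0, 0.1) for o in range(2)] for x in X]
W, b = fit_ridge(X[:150], Y[:150], alpha=1.0)      # fit on "train" positions
print(r2_uniform(Y[150:], predict(W, b, X[150:])))  # held-out R^2, close to 1 here
```

For real activation sizes, scikit-learn's `Ridge(alpha=1.0)` together with `r2_score` (whose default `multioutput="uniform_average"` matches the averaging here) does the same job much faster.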
What can we see from these results?
The first chart shows that embedding a trained subnetwork makes the large network improve faster. This shouldn't be too surprising: one good proxy for model strength is the FLOP count used to train it, and models with an embedded submodule have extra training compute baked into them, so unless this method of embedding is extraordinarily wasteful, this is predictable.
The second chart shows fairly consistent ordering: probes on the embedded aux model explain more of the variance in sector labels at layer 2 than probes on the baseline model, and probes on the embedded no-aux model explain less. This makes sense under our hypothesis: even at loss-matched (and even compute-matched) points in training, the embedded aux model's representation is noticeably more compatible with the auxiliary task! On the other hand, the gap shrinks throughout training and the R² values are low. I ran ridge regressions on the models after the full training run and found that, on average, the baseline models explain around 28% of the sector-count variance at layer 2, while the embedded aux models explain around 33%. That is to say, neither model ends up with a representation that's strongly compatible with the task, even though the embedded aux model started from one that was.
Did we actually induce fundamentally different representations, or is the gap just leftover from initialization inertia? That is, should we expect the gap in R2 values at this layer to decay to 0? Well . . .
[Figure: power law + offset fit to the R² gap] A power law fits the decay fine, performs well on the first half of the data, and doesn't predict a persistent gap. But its distribution of guesses for the true gap value is really weird: centered at 0, but with values as low as -0.2 in its 95% confidence interval. Power law + offset is a tricky model to fit because there's significant parameter interference.
[Figure: exponential + offset fit to the R² gap] An exponential also fits the decay fine, performs well on the second half of the data, and predicts a persistent gap. But isn't it well known that, on basically any decay problem, an exponential will predict that progress stops where the data stop? To me this fit looks better, and the errors technically confirm this, but it's close.
[Figure: extrapolation error by data prefix] Power law models are better at predicting the data from the first 20% of training steps; exponentials are better when given the first 60%. The crossover point is roughly a 50% data prefix. Note that the data are noisier in the last few steps, especially in relative terms, so a low average error on the last 40% of the data is arguably more impressive than a low average error on the last 60%, since the former doesn't benefit from predicting the "easiest" datapoints.
This question is hard to answer robustly. The data are inherently noisy, and different plausible models give different predictions about long-term behavior (most relevantly, power law + offset and exponential + offset disagree about whether the offset is different from 0). I tried lots of things to fix this but ultimately could not convince myself that I had a robust way of estimating the gap after more training; the plots above reflect my confusion. My guess is that the gap will not train away and will settle somewhat above 0.04 with my data and training scheme, which is what my bootstrapping scheme predicts when modeling the gap as a single exponential with an offset, but this should only be taken as a guess. If this doesn't happen, my expectation is that the gap will decay to nothing, making this result much less interesting. I would be surprised to see an in-between result.
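A minimal sketch of fitting the two offset models and running the prefix comparison described in the captions above. It is not the code behind the plots, and it omits the bootstrapping step. It uses variable projection: once the rate k is fixed, each model is linear in a and c, so those are solved exactly and k is chosen by grid search instead of a general nonlinear optimizer. The grids, the mean-absolute-error metric, and the synthetic curve are illustrative assumptions.

```python
import math


def fit_a_c(basis, y):
    """Ordinary least squares for y ≈ c + a * basis; returns (a, c, squared error)."""
    n = len(y)
    mb, my = sum(basis) / n, sum(y) / n
    a = sum((b - mb) * (v - my) for b, v in zip(basis, y)) / sum((b - mb) ** 2 for b in basis)
    c = my - a * mb
    return a, c, sum((v - c - a * b) ** 2 for b, v in zip(basis, y))


def shape(kind, k, t):
    # kind="exp": exponential + offset; kind="power": power law + offset (needs t > 0).
    return math.exp(-k * t) if kind == "exp" else t ** (-k)


def fit_offset_model(t, y, kind, rates):
    """Fit y ≈ c + a * shape(kind, k, t): grid search over k, exact least squares for (a, c)."""
    best = None
    for k in rates:
        a, c, sse = fit_a_c([shape(kind, k, ti) for ti in t], y)
        if best is None or sse < best[0]:
            best = (sse, a, k, c)
    _, a, k, c = best
    return a, k, c  # c is the model's predicted long-run gap


def predict(kind, params, t):
    a, k, c = params
    return [c + a * shape(kind, k, ti) for ti in t]


def prefix_error(t, y, kind, rates, frac):
    """Fit on the first `frac` (0 < frac < 1) of the curve; mean absolute error on the rest."""
    cut = max(3, int(len(t) * frac))
    params = fit_offset_model(t[:cut], y[:cut], kind, rates)
    rest = predict(kind, params, t[cut:])
    return sum(abs(p - v) for p, v in zip(rest, y[cut:])) / len(rest)


# Synthetic, noiseless "gap" curve (not the real data): 0.04 + 0.1 * exp(-0.08 t).
t = list(range(1, 61))
gap = [0.04 + 0.1 * math.exp(-0.08 * ti) for ti in t]
exp_rates = [i / 1000 for i in range(1, 1001)]
pow_rates = [i / 100 for i in range(1, 301)]
print(fit_offset_model(t, gap, "exp", exp_rates))    # recovers a ≈ 0.1, k ≈ 0.08, c ≈ 0.04
print(fit_offset_model(t, gap, "power", pow_rates))  # power-law fit to the same curve; compare c
for frac in (0.2, 0.5, 0.8):
    print(frac, prefix_error(t, gap, "exp", exp_rates, frac),
          prefix_error(t, gap, "power", pow_rates, frac))
```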
Remaining Questions
Does the representation gap actually persist? The most straightforward way to test this is to just throw more compute at the problem, and I plan to do this at some point.
What's the causal relationship here? Phrased another way, what representations did the models actually learn and why is one more compatible with the sector task than the other (while still not being especially compatible)? Similarly, can we track what happened to previously-identified circuits from the small model?
How do approaches like this behave with different auxiliary concepts? My guess would be that highly instrumental concepts exhibit bigger and more persistent gaps, and moreover, that we get better improvements on the loss value when the concept is more useful, although this second effect is probably subtle.
Does this work on language models? There's a lot of work already on finding primitive concepts in language models, so maybe it's easier to choose a particularly "good" auxiliary target in that domain.
How does this scale? Lottery ticket intuitions say that as scale increases and the task gets harder, the small model should make a noticeable difference even as it takes up smaller and smaller fractions of the parameter space.
How does embedding depth matter? If the auxiliary task is useful but naturally lives deeper in the optimal computation, then embedding the small model in the later layers of the large model might perform better than embedding it right at the beginning.
How much of the smaller model do we actually need to embed? If it had six layers, could we embed the middle four? I'm thinking of Paul Bach-y-Rita's famous work on neuroplasticity,[5] which I interpret as suggesting that certain computational structures (in his case the visual cortex) are especially well suited to processing certain kinds of data (in his case 3D information), even when that data is filtered through different modalities (in his case tactile rather than visual perception).
Code
Code can be found at https://github.com/speck2993/domineering_embedding_project.
--
[1] Net2Net: Accelerating Learning via Knowledge Transfer - Chen, Goodfellow, Shlens (ICLR 2016)
[2] The Bitter Lesson - Rich Sutton (2019)
[3] The Lottery Ticket Hypothesis: Finding Sparse, Trainable Neural Networks - Frankle, Carbin (2018)
[4] Domineering - Wikipedia article on the game
[5] Vision substitution by tactile image projection - Bach-y-Rita et al. (1969)