Rohin Shah

Research Scientist at DeepMind. Creator of the Alignment Newsletter.


Value Learning
Alignment Newsletter


Good point on the rotational symmetry, that makes sense now.

I still think that this assumption is fairly realistic because in practice, most pairs of unrelated features would co-occur only very rarely, and I expect the winner-take-all dynamic to dominate most of the time. But I agree that it would be nice to quantify this and test it out.

Agreed that's a plausible hypothesis. I mostly wish that in this toy model you had a hyperparameter for the frequency of co-occurrence of features, and identified how it affects the rate of incidental polysemanticity.

I think I agree with all of that (with the caveat that it's been months and I only briefly skimmed the past context, so further thinking is unusually likely to change my mind).

My guess is that this result is very sensitive to the design of the training dataset:

the input/output data pairs are (e_i, e_i) for each i, where e_i is the i-th basis vector.

In particular, I think it is likely very sensitive to the implicit assumption that feature i and feature j never co-occur on a single input. I'd be interested to see experiments where each feature is turned on with some (not too small) probability, independently of all other features, similarly to the original toy models setting. This would result in some inputs where feature i and j are on simultaneously. My prediction would be that polysemanticity goes down very significantly (probably to zero if the probabilities are high enough and the training is done for long enough).
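The proposed experiment can be sketched as a change to the data generator. This is a minimal illustration, not the post's code: the function names, the parameter `p`, and the use of NumPy are all assumptions made here.

```python
import numpy as np

def one_hot_inputs(n_features):
    """Dataset as described in the post: exactly one feature is active per input."""
    return np.eye(n_features)

def independent_inputs(n_samples, n_features, p, rng):
    """Hypothetical variant: each feature is on independently with probability p."""
    return (rng.random((n_samples, n_features)) < p).astype(float)

def cooccurrence_rate(x, i, j):
    """Fraction of inputs on which features i and j are both active."""
    return float(np.mean((x[:, i] > 0) & (x[:, j] > 0)))

rng = np.random.default_rng(0)
x_onehot = one_hot_inputs(8)
x_indep = independent_inputs(20_000, 8, p=0.5, rng=rng)

# Never co-occur under the one-hot dataset; roughly p**2 under the variant.
rate_onehot = cooccurrence_rate(x_onehot, 0, 1)
rate_indep = cooccurrence_rate(x_indep, 0, 1)
```

Sweeping `p` would give the co-occurrence hyperparameter discussed above, under the assumption that features are independent.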

I also don't understand why L1 regularization on activations is necessary to show incidental polysemanticity given your setup. Even if you remove the L1 regularization on activations, it is still the case that "benign collisions" impose no cost on the model, since feature i and feature j are never simultaneously present in a given input. So if you do get a benign collision, what causes it to go away? Overall my expectation would be that without the L1 regularization on activations (and with the training dataset as described in this post), you'd get a complicated mess where every neuron is highly polysemantic, i.e. even more polysemanticity than described in this post. Why is that wrong?
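A small numerical sketch of the "benign collision" argument. It assumes a tied-weight autoencoder with a ReLU on the output, h = Wᵀx and y = ReLU(W h), which is one reading of the setup and may not match the post exactly; the weight matrix and the `l1_coeff` parameter are hypothetical.

```python
import numpy as np

def forward(W, x):
    """Assumed architecture: h = W^T x, y = ReLU(W h), with tied weights."""
    h = W.T @ x
    y = np.maximum(W @ h, 0.0)
    return h, y

def loss(W, x, l1_coeff=0.0):
    """Squared reconstruction error plus an optional L1 penalty on activations."""
    h, y = forward(W, x)
    return float(np.sum((y - x) ** 2) + l1_coeff * np.sum(np.abs(h)))

# Two features share a single neuron with opposite signs.
W = np.array([[1.0], [-1.0]])
e0 = np.array([1.0, 0.0])
e1 = np.array([0.0, 1.0])

loss_e0 = loss(W, e0)          # 0.0: the collision costs nothing on a one-hot input
loss_e1 = loss(W, e1)          # 0.0
loss_both = loss(W, e0 + e1)   # 2.0: it does cost something once both features are on
```

With `l1_coeff=0` the shared neuron reconstructs each one-hot input perfectly, so nothing in the loss pushes the collision away; the cost only appears on an input where both features are active.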

One piece missing here, insofar as current methods don't get to 99% of loss recovered, is repeatedly drilling into the residual until they do get to 99%.

When you do that using existing methods, you lose the sparsity (e.g. for circuit finding you have to include a large fraction of the model to get to 99% loss recovered).

It's of course possible that this is because the methods are bad, though my guess is that at the 99% standard this is reflecting non-sparsity / messiness in the territory (and so isn't going to go away with better methods). I do expect we can improve; we're very far from the 99% standard. But the way we improve won't be by "drilling into the residual"; that has been tried and is insufficient. EDIT: Possibly by "drill into the residual" you mean "understand why the methods don't work and then improve them" -- if so I agree with that but also think this is what mech interp researchers want to do.

(Why am I still optimistic about interpretability? I'm not convinced that the 99% standard is required for downstream impact -- though I am pretty pessimistic about the "enumerative safety" story of impact, basically for the same reasons as Buck and Ryan afaict.)

This seems like exactly what mech interp is doing? Circuit finding is all about finding sparse subgraphs. It continues to work with large models, when trying to explain a piece of the behavior of the large model. SAE stands for sparse autoencoder: the whole point is to find the basis in which you get sparsity. I feel like a lot of mech interp has been almost entirely organized around the principle of modularity / sparsity, and the main challenge is that it's hard (you don't get to 99% of loss recovered, even on pieces of behavior, while still being meaningfully sparse).
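For concreteness, one common way to compute "loss recovered" is sketched below. The comments above do not give a definition, so the formula, the choice of baseline, and the function name are assumptions.

```python
def fraction_loss_recovered(clean_loss, ablated_loss, explained_loss):
    """Share of the gap between an ablation baseline and the full model
    that the explanation (circuit, SAE reconstruction, ...) closes.

    clean_loss:     loss of the unmodified model
    ablated_loss:   loss with the component ablated (e.g. zero or mean ablation)
    explained_loss: loss with the component replaced by the explanation
    """
    gap = ablated_loss - clean_loss
    if gap <= 0:
        raise ValueError("ablation baseline must be worse than the clean model")
    return (ablated_loss - explained_loss) / gap

# Made-up numbers, chosen so the explanation recovers 99% of the gap.
example = fraction_loss_recovered(clean_loss=2.0, ablated_loss=6.0, explained_loss=2.04)
```

The result depends heavily on the ablation baseline, so figures computed with different baselines are not directly comparable.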

This has been stated many times before (I believe I heard it in Chris Olah’s 80k episode first) but worth reiterating.

The reference I like best is

Unless by "shrugs" you mean the details of what the partial hypothesis says in this particular case are still being worked out.

Yes, that's what I mean.

I do agree that it's useful to know whether a partial hypothesis says anything or not; overall I think this is good info to know / ask for. I think I came off as disagreeing more strongly than I actually did, sorry about that.

Do you have any plans to do this?

No, we're moving on to other work: this took longer than we expected, and was less useful for alignment than we hoped (though that part wasn't that unexpected; from the start we expected "science of deep learning" to be more hits-based, or to require significant progress before it actually became useful for practical proposals).

How much time do you think it would take?

Actually running the experiments should be pretty straightforward, I'd expect we could do them in a week given our codebase, possibly even a day. Others might take some time to set up a good codebase but I'd still be surprised if it took a strong engineer longer than two weeks to get some initial results. This gets you observations like "under the particular settings we chose, D_crit tends to increase / decrease as the number of layers increases".

The hard part is then interpreting those results and turning them into something more generalizable -- including handling confounds. For example, maybe for some reason the principled thing to do is to reduce the learning rate as you increase layers, and once you do that your observation reverses -- this is a totally made up example but illustrates the kind of annoying things that come up when doing this sort of research, that prevent you from saying anything general. I don't know how long it would take if you want to include that; it could be quite a while (e.g. months or years).

And do you have any predictions for what should happen in these cases?

Not really. I've learned from experience not to try to make quantitative predictions yet. We tried to make some theory-inspired quantitative predictions in the settings we studied, and they fell pretty flat.

For example, in our minimal model in Section 3 we have a hyperparameter that determines how param norm and logits scale together -- initially, that was our guess of what would happen in practice (i.e. we expected circuit param norm and circuit logits to obey a power law relationship in actual grokking settings). But basically every piece of evidence we got seemed to falsify that hypothesis (e.g. Figure 3 in the paper).

(I say "seemed to falsify" because it's still possible that we're just failing to deal with confounders in some way, or measuring something that isn't exactly what we want to measure. For example, Figure 3 logits are not of the Mem circuit in actual grokking setups, but rather the logits produced by networks trained on random labels -- maybe there's a relevant difference between these.)
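One simple way to test a power-law hypothesis of this kind is a straight-line fit in log-log space. The sketch below is not the paper's analysis: the function name is hypothetical and the data are synthetic, generated to follow an exact power law.

```python
import numpy as np

def fit_power_law(param_norms, logits):
    """Fit logits ~ prefactor * param_norm ** exponent by least squares in log-log space.

    Returns (exponent, prefactor, max_abs_log_residual). A large residual
    suggests the relationship is not well described by a power law.
    """
    log_x = np.log(np.asarray(param_norms, dtype=float))
    log_y = np.log(np.asarray(logits, dtype=float))
    exponent, intercept = np.polyfit(log_x, log_y, 1)
    residuals = log_y - (exponent * log_x + intercept)
    return float(exponent), float(np.exp(intercept)), float(np.max(np.abs(residuals)))

# Synthetic data following logits = 3 * norm ** 2 exactly.
norms = np.array([1.0, 2.0, 4.0, 8.0])
exponent, prefactor, max_resid = fit_power_law(norms, 3.0 * norms ** 2)
```

On real measurements the residuals would be the thing to inspect, since systematic curvature in log-log space is what falsifies the power-law form.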

Which of these theories [...] can predict the same "four novel predictions about grokking" yours did? The relative likelihoods are what matters for updates after all.

I disagree with the implicit view on how science works. When you are a computationally bounded reasoner, you work with partial hypotheses, i.e. hypotheses that only make predictions on a small subset of possible questions, and just shrug at other questions. This is mostly what happens with the other theories:

  1. Difficulty of representation learning: Shrugs at our prediction about the relative efficiencies of the memorising and generalising circuits, anti-predicts ungrokking (since in that case the representation has already been learned), shrugs at semi-grokking.
  2. Scale of parameters at initialisation: Shrugs at all of our predictions. If you interpret it as making a strong claim that scale of parameters at initialisation is the crucial thing (i.e. other things mostly don't matter) then it anti-predicts semi-grokking.
  3. Spikes in loss / slingshots: Shrugs at all of our predictions.
  4. Random walks among optimal solutions: Shrugs at our prediction about the relative efficiencies of the memorising and generalising circuits. I'm not sure what this theory says about what happens after you hit the generalising solution -- can you then randomly walk away from the generalising solution? If yes, then it predicts that if you train for a long enough time without changing the dataset, a grokked network will ungrok (false in our experiments, and we often trained for much longer than time to grok); if no then it anti-predicts ungrokking and semi-grokking.
  5. Simplicity of the generalising solution: This is our explanation. Our paper is basically an elaboration, formalization, and confirmation of Nanda et al's theory, as we allude to in the next sentence after the one you quoted.

how does this theory explain other grokking related phenomena e.g. Omni-Grok?

My speculation for Omni-Grok in particular is that in settings like MNIST you already have two of the ingredients for grokking (that there are both memorising and generalising solutions, and that the generalising solution is more efficient), and then having large parameter norms at initialisation provides the third ingredient (generalising solutions are learned more slowly), for some reason I still don't know.

Happy to speculate on other grokking phenomena as well (though I don't think there are many others?)

And how do things change as you increase parameter count?

We haven't investigated this, but I'd pretty strongly predict that there mostly aren't major qualitative changes. (The one exception is semi-grokking; there's a theoretical reason to expect it may sometimes not occur, and also in practice it can be quite hard to elicit.)

I expect there would be quantitative changes (e.g. maybe the value of D_crit changes, maybe the time taken to learn the generalising circuit changes). Sufficiently big changes in D_crit might mean you don't see the phenomena on modular addition any more, but I'd still expect to see them in more complicated tasks that exhibit grokking.

I'd be interested in investigations that got into these quantitative questions (in addition to the above, there's also things like "quantitatively, how does the strength of weight decay affect the time for the generalising circuit to be learned?", and many more).

From page 6 of the paper:

Ungrokking can be seen as a special case of catastrophic forgetting (McCloskey and Cohen, 1989; Ratcliff, 1990), where we can make much more precise predictions. First, since ungrokking should only be expected once D < D_crit, if we vary D we predict that there will be a sharp transition from very strong to near-random test accuracy (around D_crit). Second, we predict that ungrokking would arise even if we only remove examples from the training dataset, whereas catastrophic forgetting typically involves training on new examples as well. Third, since D_crit does not depend on weight decay, we predict the amount of “forgetting” (i.e. the test accuracy at convergence) also does not depend on weight decay.

(All of these predictions are then confirmed in the experimental section.)
