Crossposted from the AI Alignment Forum. May contain more technical jargon than usual.
This is a linkpost for https://arxiv.org/abs/2309.02390

This is a linkpost for our paper Explaining grokking through circuit efficiency, which provides a general theory explaining when and why grokking (aka delayed generalisation) occurs, and makes several interesting and novel predictions which we experimentally confirm (introduction copied below). You might also enjoy our explainer on X/Twitter.

Abstract

One of the most surprising puzzles in neural network generalisation is grokking: a network with perfect training accuracy but poor generalisation will, upon further training, transition to perfect generalisation. We propose that grokking occurs when the task admits a generalising solution and a memorising solution, where the generalising solution is slower to learn but more efficient, producing larger logits with the same parameter norm. We hypothesise that memorising circuits become more inefficient with larger training datasets while generalising circuits do not, suggesting there is a critical dataset size at which memorisation and generalisation are equally efficient. We make and confirm four novel predictions about grokking, providing significant evidence in favour of our explanation. Most strikingly, we demonstrate two novel and surprising behaviours: ungrokking, in which a network regresses from perfect to low test accuracy, and semi-grokking, in which a network shows delayed generalisation to partial rather than perfect test accuracy.

Introduction

When training a neural network, we expect that once training loss converges to a low value, the network will no longer change much. Power et al. (2021) discovered a phenomenon dubbed grokking that drastically violates this expectation. The network first "memorises" the data, achieving low and stable training loss with poor generalisation, but with further training transitions to perfect generalisation. We are left with the question: why does the network's test performance improve dramatically upon continued training, having already achieved nearly perfect training performance?

Recent answers to this question vary widely, including the difficulty of representation learning (Liu et al., 2022), the scale of parameters at initialisation (Liu et al., 2023), spikes in loss ("slingshots") (Thilak et al., 2022), random walks among optimal solutions (Millidge et al., 2022), and the simplicity of the generalising solution (Nanda et al., 2023, Appendix E). In this paper, we argue that the last explanation is correct, by stating a specific theory in this genre, deriving novel predictions from the theory, and confirming the predictions empirically.

We analyse the interplay between the internal mechanisms that the neural network uses to calculate the outputs, which we loosely call "circuits" (Olah et al., 2020). We hypothesise that there are two families of circuits that both achieve good training performance: one which generalises well (C_gen) and one which memorises the training dataset (C_mem). The key insight is that when there are multiple circuits that achieve strong training performance, weight decay prefers circuits with high "efficiency", that is, circuits that require less parameter norm to produce a given logit value.

Efficiency answers our question above: if C_gen is more efficient than C_mem, gradient descent can reduce nearly perfect training loss even further by strengthening C_gen while weakening C_mem, which then leads to a transition in test performance. With this understanding, we demonstrate in Section 3 that three key properties are sufficient for grokking: (1) C_gen generalises well while C_mem does not, (2) C_gen is more efficient than C_mem, and (3) C_gen is learned more slowly than C_mem.
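The three ingredients can be illustrated with a toy gradient-descent simulation. This is a minimal sketch and not the construction from Section 3 of the paper: it assumes each circuit is a single scalar weight whose logit is linear in that weight, that the test margin comes only from the generalising circuit, and that slow learning of the generalising circuit can be imitated by giving it a much smaller learning rate. The function name `simulate` and all numeric values (`eff_gen`, `eff_mem`, `lr_gen`, `lr_mem`, `weight_decay`) are hypothetical choices made for illustration.

```python
import math

def simulate(steps=200_000, weight_decay=0.01,
             eff_gen=1.0, eff_mem=0.5,    # logit per unit of weight (assumed values)
             lr_gen=0.001, lr_mem=1.0):   # generalising circuit learns far more slowly (assumed)
    """Toy model: train margin = logit_gen + logit_mem, test margin = logit_gen only."""
    w_gen, w_mem = 0.0, 0.0
    history = []  # tuples of (step, logit_gen, logit_mem, w_mem)
    for step in range(steps):
        logit_gen = eff_gen * w_gen
        logit_mem = eff_mem * w_mem
        if step % 1000 == 0:
            history.append((step, logit_gen, logit_mem, w_mem))
        # Magnitude of d/d(logit) of the logistic loss log(1 + exp(-logit))
        slope = 1.0 / (1.0 + math.exp(logit_gen + logit_mem))
        # Gradient descent on logistic loss plus L2 weight decay
        w_gen += lr_gen * (eff_gen * slope - weight_decay * w_gen)
        w_mem += lr_mem * (eff_mem * slope - weight_decay * w_mem)
    history.append((steps, eff_gen * w_gen, eff_mem * w_mem, w_mem))
    return history

history = simulate()
for step, logit_gen, logit_mem, w_mem in history[::40]:
    print(f"step {step:>6}: test margin {logit_gen:.2f}, memorised logit {logit_mem:.2f}")
```

Under these assumptions the memorising logit dominates early in training, and later the generalising logit overtakes it while the memorising weight shrinks, which is the qualitative pattern described above. Because the logits here are linear in the weights, the memorising circuit is weakened but not removed entirely.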

Since C_gen generalises well, it automatically works for any new data points that are added to the training dataset, and so its efficiency should be independent of the size of the training dataset. In contrast, C_mem must memorise any additional data points added to the training dataset, and so its efficiency should decrease as training dataset size increases. We validate these predictions by quantifying efficiencies for various dataset sizes for both C_mem and C_gen.

This suggests that there exists a crossover point at which C_gen becomes more efficient than C_mem, which we call the critical dataset size D_crit. By analysing dynamics at D_crit, we predict and demonstrate two new behaviours (Figure 1). In ungrokking, a model that has successfully grokked returns to poor test accuracy when further trained on a dataset much smaller than D_crit. In semi-grokking, we choose a dataset size where C_gen and C_mem are similarly efficient, leading to a phase transition but only to middling test accuracy.
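The crossover idea can be sketched numerically. The example below assumes two made-up cost curves, giving the parameter norm each circuit family needs to reach the same fixed logit: a constant for the generalising circuit and a square-root growth for the memorising circuit. Neither curve comes from the paper; `norm_gen`, `norm_mem`, and `estimate_critical_size` are hypothetical names, and the only point is how a critical dataset size falls out once such curves have been measured.

```python
import math

def norm_gen(dataset_size):
    # Assumption: the generalising circuit's cost does not depend on dataset size.
    return 25.0

def norm_mem(dataset_size):
    # Assumption: memorising more points needs more norm; the sqrt shape is invented.
    return math.sqrt(dataset_size)

def estimate_critical_size(sizes):
    """Return the first size at which memorising needs more norm than generalising."""
    for d in sorted(sizes):
        if norm_mem(d) > norm_gen(d):
            return d
    return None  # no crossover within the sizes examined

sizes = range(100, 2001, 100)
d_crit = estimate_critical_size(sizes)
print("estimated critical dataset size on this grid:", d_crit)
```

Below the estimated size the memorising circuit is the cheaper one (the regime where ungrokking would be expected), and near it the two are similarly efficient (the regime where semi-grokking would be expected). A finer grid of sizes gives a tighter estimate.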

We make the following contributions:

  1. We demonstrate the sufficiency of three ingredients for grokking through a constructed simulation (Section 3).
  2. By analysing dynamics at the "critical dataset size" implied by our theory, we predict two novel behaviours: semi-grokking and ungrokking (Section 4).
  3. We confirm our predictions through careful experiments, including demonstrating semi-grokking and ungrokking in practice (Section 5).
Algon:

Which of these theories:

Recent answers to this question vary widely, including the difficulty of representation learning (Liu et al., 2022), the scale of parameters at initialisation (Liu et al., 2023), spikes in loss ("slingshots") (Thilak et al., 2022), random walks among optimal solutions (Millidge et al., 2022), and the simplicity of the generalising solution (Nanda et al., 2023, Appendix E).

can predict the same "four novel predictions about grokking" yours did? The relative likelihoods are what matters for updates after all.

Also, how does this theory explain other grokking-related phenomena, e.g. Omni-Grok? And how do things change as you increase parameter count? Scale matters, and I am not sure whether things like ungrokking would vanish with scale as catastrophic forgetting did. Or those various inverse scaling phenomena.

Rohin Shah:

Which of these theories [...] can predict the same "four novel predictions about grokking" yours did? The relative likelihoods are what matters for updates after all.

I disagree with the implicit view on how science works. When you are a computationally bounded reasoner, you work with partial hypotheses, i.e. hypotheses that only make predictions on a small subset of possible questions, and just shrug at other questions. This is mostly what happens with the other theories:

  1. Difficulty of representation learning: Shrugs at our prediction about C_gen / C_mem efficiencies, anti-predicts ungrokking (since in that case the representation has already been learned), shrugs at semi-grokking.
  2. Scale of parameters at initialisation: Shrugs at all of our predictions. If you interpret it as making a strong claim that scale of parameters at initialisation is the crucial thing (i.e. other things mostly don't matter) then it anti-predicts semi-grokking.
  3. Spikes in loss / slingshots: Shrugs at all of our predictions.
  4. Random walks among optimal solutions: Shrugs at our prediction about C_gen / C_mem efficiencies. I'm not sure what this theory says about what happens after you hit the generalising solution -- can you then randomly walk away from the generalising solution? If yes, then it predicts that if you train for a long enough time without changing the dataset, a grokked network will ungrok (false in our experiments, and we often trained for much longer than time to grok); if no then it anti-predicts ungrokking and semi-grokking.
  5. Simplicity of the generalising solution: This is our explanation. Our paper is basically an elaboration, formalization, and confirmation of Nanda et al.'s theory, as we allude to in the next sentence after the one you quoted.

how does this theory explain other grokking-related phenomena, e.g. Omni-Grok?

My speculation for Omni-Grok in particular is that in settings like MNIST you already have two of the ingredients for grokking (that there are both memorising and generalising solutions, and that the generalising solution is more efficient), and then having large parameter norms at initialisation provides the third ingredient (generalising solutions are learned more slowly), for some reason I still don't know.

Happy to speculate on other grokking phenomena as well (though I don't think there are many others?)

And how do things change as you increase parameter count?

We haven't investigated this, but I'd pretty strongly predict that there mostly aren't major qualitative changes. (The one exception is semi-grokking; there's a theoretical reason to expect it may sometimes not occur, and also in practice it can be quite hard to elicit.)

I expect there would be quantitative changes (e.g. maybe the value of D_crit changes, maybe the time taken to learn C_gen changes). Sufficiently big changes in D_crit might mean you don't see the phenomena on modular addition any more, but I'd still expect to see them in more complicated tasks that exhibit grokking.

I'd be interested in investigations that got into these quantitative questions (in addition to the above, there's also things like "quantitatively, how does the strength of weight decay affect the time for C_gen to be learned?", and many more).

Algon:

I disagree with the implicit view on how science works. When you are a computationally bounded reasoner, you work with partial hypotheses, i.e. hypotheses that only make predictions on a small subset of possible questions, and just shrug at other questions. 

Implicitly, I thought that if you have a partial hypothesis of grokking and it shrugs at a grokking-related phenomenon, it should be penalized. Unless by "shrugs" you mean the details of what the partial hypothesis says in this particular case are still being worked out. But in that case, confirming that the partial hypothesis doesn't say anything yet about some phenomena is still useful info. I'm fairly sure this belief was what generated my question.

This is mostly what happens with the other theories

Thank you for going through the theories and checking what they have to say. That was helpful to me. 

I'd be interested in investigations that got into these quantitative questions

Do you have any plans to do this? How much time do you think it would take? And do you have any predictions for what should happen in these cases?

Rohin Shah:

Unless by "shrugs" you mean the details of what the partial hypothesis says in this particular case are still being worked out.

Yes, that's what I mean.

I do agree that it's useful to know whether a partial hypothesis says anything or not; overall I think this is good info to know / ask for. I think I came off as disagreeing more strongly than I actually did, sorry about that.

Do you have any plans to do this?

No, we're moving on to other work: this took longer than we expected, and was less useful for alignment than we hoped (though that part wasn't that unexpected, from the start we expected "science of deep learning" to be more hits-based, or to require significant progress before it actually became useful for practical proposals).

How much time do you think it would take?

Actually running the experiments should be pretty straightforward, I'd expect we could do them in a week given our codebase, possibly even a day. Others might take some time to set up a good codebase but I'd still be surprised if it took a strong engineer longer than two weeks to get some initial results. This gets you observations like "under the particular settings we chose, D_crit tends to increase / decrease as the number of layers increases".

The hard part is then interpreting those results and turning them into something more generalizable -- including handling confounds. For example, maybe for some reason the principled thing to do is to reduce the learning rate as you increase layers, and once you do that your observation reverses -- this is a totally made up example but illustrates the kind of annoying things that come up when doing this sort of research, that prevent you from saying anything general. I don't know how long it would take if you want to include that; it could be quite a while (e.g. months or years).

And do you have any predictions for what should happen in these cases?

Not really. I've learned from experience not to try to make quantitative predictions yet. We tried to make some theory-inspired quantitative predictions in the settings we studied, and they fell pretty flat.

For example, in our minimal model in Section 3 we have a hyperparameter that determines how param norm and logits scale together -- initially, that was our guess of what would happen in practice (i.e. we expected circuit param norm <> circuit logits to obey a power law relationship in actual grokking settings). But basically every piece of evidence we got seemed to falsify that hypothesis (e.g. Figure 3 in the paper).

(I say "seemed to falsify" because it's still possible that we're just failing to deal with confounders in some way, or measuring something that isn't exactly what we want to measure. For example, Figure 3 logits are not of the Mem circuit in actual grokking setups, but rather the logits produced by networks trained on random labels -- maybe there's a relevant difference between these.)
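For readers who want to test a power-law hypothesis like this on their own measurements, one common approach is a straight-line fit in log-log space. The sketch below is not the analysis from the paper: `fit_power_law` is a hypothetical helper, the data are synthetic, and it assumes strictly positive norms and logits.

```python
import math

def fit_power_law(norms, logits):
    """Least-squares fit of log(logit) = log(c) + k * log(norm); returns (k, c)."""
    xs = [math.log(n) for n in norms]
    ys = [math.log(o) for o in logits]
    count = len(xs)
    mean_x = sum(xs) / count
    mean_y = sum(ys) / count
    sxx = sum((x - mean_x) ** 2 for x in xs)
    sxy = sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, ys))
    k = sxy / sxx                      # fitted exponent
    c = math.exp(mean_y - k * mean_x)  # fitted prefactor
    return k, c

# Synthetic measurements generated from logit = 2 * norm**3 (made-up data)
norms = [1.0, 2.0, 4.0, 8.0]
logits = [2.0 * n ** 3 for n in norms]
k, c = fit_power_law(norms, logits)
print(f"fitted exponent {k:.3f}, prefactor {c:.3f}")
```

A power law predicts a straight line in log-log space, so a real check would also inspect the residuals of this fit rather than only the fitted exponent.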

LawrenceC:

My speculation for Omni-Grok in particular is that in settings like MNIST you already have two of the ingredients for grokking (that there are both memorising and generalising solutions, and that the generalising solution is more efficient), and then having large parameter norms at initialisation provides the third ingredient (generalising solutions are learned more slowly), for some reason I still don't know.

Higher weight norm means lower effective learning rate with Adam, no? In that paper they used a constant learning rate across weight norms, but Adam tries to normalize the gradients to be of size 1 per parameter, regardless of the size of the weights. So the weights change more slowly with larger initializations (especially since they constrain the weights to be of fixed norm by projecting after the Adam step).
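The normalisation argument can be checked on a single scalar weight. The sketch below applies one standard bias-corrected Adam update starting from zero moment estimates; it leaves out the norm projection and any weight decay, and the weights and gradients are arbitrary illustrative values. `adam_first_step` and `relative_change` are hypothetical helper names.

```python
import math

def adam_first_step(weight, grad, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    """Apply one bias-corrected Adam update to a scalar weight (moments start at zero)."""
    m = (1 - beta1) * grad
    v = (1 - beta2) * grad ** 2
    m_hat = m / (1 - beta1)  # bias correction at step t = 1
    v_hat = v / (1 - beta2)
    return weight - lr * m_hat / (math.sqrt(v_hat) + eps)

def relative_change(weight, grad):
    """Fraction by which the weight moves in one Adam step."""
    new_weight = adam_first_step(weight, grad)
    return abs(new_weight - weight) / abs(weight)

# The step size is about lr whatever the gradient scale, so a 10x larger weight
# moves about 10x less in relative terms.
small = relative_change(weight=1.0, grad=0.3)
large = relative_change(weight=10.0, grad=3.0)
print(f"relative change: small weight {small:.6f}, large weight {large:.6f}")
```

In this simplified setting the absolute step is close to `lr` in both cases, so the relative change scales inversely with the weight magnitude, which is the sense in which larger initialisations give a lower effective learning rate.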

Sounds plausible, but why does this differentially impact the generalizing algorithm over the memorizing algorithm?

Perhaps under normal circumstances both are learned so fast that you just don't notice that one is slower than the other, and this slows both of them down enough that you can see the difference?

Algon:

ungrokking, in which a network regresses from perfect to low test accuracy

Is this the same thing as catastrophic forgetting?

Rohin Shah:

From page 6 of the paper:

Ungrokking can be seen as a special case of catastrophic forgetting (McCloskey and Cohen, 1989; Ratcliff, 1990), where we can make much more precise predictions. First, since ungrokking should only be expected once the training dataset size falls below D_crit, if we vary the dataset size we predict that there will be a sharp transition from very strong to near-random test accuracy (around D_crit). Second, we predict that ungrokking would arise even if we only remove examples from the training dataset, whereas catastrophic forgetting typically involves training on new examples as well. Third, since D_crit does not depend on weight decay, we predict the amount of “forgetting” (i.e. the test accuracy at convergence) also does not depend on weight decay.

(All of these predictions are then confirmed in the experimental section.)

Rohin Shah:

I think that post has a lot of good ideas, e.g. the idea that generalizing circuits get reinforced by SGD more than memorizing circuits at least rhymes with what we claim is actually going on (that generalizing circuits are more efficient at producing strong logits with small param norm). We probably should have cited it, I forgot that it existed.

But it is ultimately a different take and one that I think ends up being wrong (e.g. I think it would struggle to explain semi-grokking).

I also think my early explanation, which that post compares to, is basically as good or better in hindsight, e.g.:

  1. My early explanation says that memorization produces less confident probabilities than generalization. Quintin's post explicitly calls this out as a difference, I continued to endorse my position in the comments. In hindsight my explanation was right and Quintin's was wrong, at least if you believe our new paper.
  2. My early explanation relies on the assumption that there are no "intermediate" circuits, only a pure memorization and pure generalization circuit that you can interpolate between. Again, this is called out as a difference by Quintin's post, and I continued to endorse my position in the comments. Again, I think in hindsight my explanation was right and Quintin's was wrong (though this is less clearly implied by our paper, and I could still imagine future evidence overturning that conclusion, though I really don't expect that to happen).
  3. On the other hand, my early explanation involves a random walk to the generalizing circuit, whereas in reality it develops smoothly over time. In hindsight, my explanation was wrong, and Quintin's is correct.
