Is it accurate to summarize the headline result as follows?
(I don't know what Computational Mechanics or MSPs are so this could be totally off.)
EDIT: Looks like yes. From this post:
Part of what this all illustrates is that the fractal shape is kinda… baked into any Bayesian-ish system tracking the hidden state of the Markov model. So in some sense, it’s not very surprising to find it linearly embedded in activations of a residual stream; all that really means is that the probabilities for each hidden state are linearly represented in the residual stream.
"The structure of synchronization is, in general, richer than the world model itself. In this sense, LLMs learn more than a world model" given that I expect this is the statement that will catch a lot of people's attention.
Just in case this claim caught anyone else's attention, what they mean by this is that it contains:
• A model of the world
• A model of the agent's process for updating its belief about which state the world is in
[EDIT: I no longer endorse this response, see thread.]
(This comment is mainly for people other than the authors.)
If your reaction to this post is "hot damn, look at that graph", then I think you should probably dial back your excitement somewhat. IIUC the fractal structure is largely an artifact of how the data is visualized, which means the results visually look more striking than they really are.
It is still a cool piece of work, and the visuals are beautiful. The correct amount of excitement is greater than zero.
To me the consequences of this response were more valuable than the-post-without-this-response, since it led to the clarification by the post's author on a crucial point that wasn't clear in the post and reframed it substantially. And once that clarification arrived, this thread ceased being highly upvoted, which seems the opposite of the right thing to happen.
(So it's a case where value of content in hindsight disagrees with value of the consequences of its existence. Doesn't even imply there was originally an error, without the benefit of hindsight.)
Can you elaborate on how the fractal is an artifact of how the data is visualized?
From my perspective, the fractal is there because we chose this data generating structure precisely because it has this fractal pattern as its Mixed-State Presentation (i.e. we chose it because then the ground truth would be a fractal, which felt like highly nontrivial structure to us, and thus a good falsifiable test that this framework is at all relevant for transformers. Also, yes, it is pretty :) ). The fractal is a natural consequence of that choice of data generating structure - it is what Computational Mechanics says is the geometric structure of synchronization for the HMM. That there is a linear 2d plane in the residual stream that when you project onto it you get that same fractal seems highly non-artifactual, and is what we were testing.
Though it should be said that an HMM with a fractal MSP is quite a generic choice. It's remarkably easy to get such fractal structures. If you randomly choose an HMM from the space of HMMs for a given number of states and vocab size, you will often get synchronization structures with infinitely many transient states and fractals.
This isn't a proof of that previous claim, but here are some examples of fractal MSPs from https://arxiv.org/abs/2102.10487:
Responding in reverse order:
If there's literally a linear projection of the residual stream into two dimensions which directly produces that fractal, with no further processing/transformation in between "linear projection" and "fractal", then I would change my mind about the fractal structure being mostly an artifact of the visualization method.
There is literally a linear projection (well, we allow a constant offset actually, so affine) of the residual stream into two dimensions which directly produces that fractal. There are no distributions in the middle or anything. I suspect the offset is not necessary but I haven't checked ::adding to to-do list::
edit: the offset isn't necessary. There is literally a linear projection of the residual stream into 2D which directly produces the fractal.
But the "fractal-ness" is mostly an artifact of the MSP as a representation-method IIUC; the stochastic process itself is not especially "naturally fractal".
(As I said I don't know the details of the MSP very well; my intuition here is instead coming from some background knowledge of where fractals which look like those often come from, specifically chaos games.)
I'm not sure I'm following, but...
We're now working through understanding all the pieces of this, and we've calculated an MSP which doesn't quite look like the one in the post:
(Ignore the skew, David's still fiddling with the projection into 2D. The important noticeable part is the absence of "overlap" between the three copies of the main shape, compared to the fractal from the post.)
Specifically, each point in that visual corresponds to a distribution for some value of the observed symbols . The image itself is of the points on the probability simplex. From looking at a couple of Crutchfield papers, it sounds like that's what the MSP is supposed to be.
The update equations are:
with given by the transition probabilities, given by the observation probabilities, and a normalizer. We generate the image above by initializing some random distribution , then iterating the equations and plotting each point.
Off the top of your head, any idea what might account for the mismatch (other than a bug in our code, which we're alread...
Everything looks right to me! This is the annoying problem that people forget to write the actual parameters they used in their work (sorry).
Try x=0.05, alpha=0.85. I've edited the footnote with this info as well.
The figures remind me of figures 3 and 4 from Meta-learning of Sequential Strategies, Ortega et al 2019, which also study how autoregressive models (RNNs) infer underlying structure. Could be a good reference to check out!
This is very cool! I’m excited to see where it goes :)
A couple questions (mostly me grappling with what the implications of this work might be):
Thanks!
Promoted to curated: Formalizing what it means for transformers to learn "the underlying world model" when engaging in next-token prediction tasks seems pretty useful, in that it's an abstraction that I see used all the time when discussing risks from models where the vast majority of the compute was spent in pre-training, where the details usually get handwaved. It seems useful to understand what exactly we mean by that in more detail.
I have not done a thorough review of this kind of work, but it seems to me that others also thought the basic ideas in the work hold up, and I thought reading this post gave me crisper abstractions to talk about this kind of stuff in the future.
transformer is only trained explicitly on next token prediction!
I find myself understanding language/multimodal transformer capabilities better when I think about the whole document (up to context length) as a mini-batch for calculating the gradient in transformer (pre-)training, so I imagine it is minimizing the document-global prediction error; it wasn't trained to optimize for just single next-token accuracy...
There is evidence that transformers are not in fact even implicitly, internally, optimized for reducing global prediction error (except insofar as comp-mech says they must in order to do well on the task they are optimized for).
Do transformers "think ahead" during inference at a given position? It is known transformers prepare information in the hidden states of the forward pass at t that is then used in future forward passes t+τ. We posit two explanations for this phenomenon: pre-caching, in which off-diagonal gradient terms present in training result in the model computing features at t irrelevant to the present inference task but useful for the future, and breadcrumbs, in which features most relevant to time step t are already the same as those that would most benefit inference at time t+τ. We test these hypotheses by training language models without propagating gradients to past timesteps, a scheme we formalize as myopic training. In a synthetic data setting, we find clear evidence for pre-caching. In the autoregressive language modeling setting, our experiments are more suggestive of the breadcrumbs hypothesis.
I think that paper is some evidence that there's typically no huge effect from internal activations being optimized for predicting future tokens (on natural language). But I don't think it's much (if any) evidence that this doesn't happen to some small extent or that it couldn't be a huge effect on certain other natural language tasks.
(In fact, I think the myopia gap is probably the more relevant number than the local myopia bonus, in which case I'd argue the paper actually shows a pretty non-trivial effect, kind of contrary to how the authors interpret it. But I haven't read the paper super closely.)
Also, sounds like you're aware of this, but I'd want to highlight more that the paper does demonstrate internal activations being optimized for predicting future tokens on synthetic data where this is necessary. So, arguably, the main question is to what extent natural language data incentivizes this rather than being specifically about what transformers can/tend to do.
In that sense, thinking of transformer internals as "trying to" minimize the loss on an entire document might be exactly the right intuition empirically (and the question is mainly how different that is from being myopic on a given dataset). Given that the internal states are optimized for this, that would also make sense theoretically IMO.
I have maybe a naive question. What information is needed to find the MSP image within the neural network? Do we have to know the HMM to begin with? Or could it be feasible someday to inspect a neural network, find something that looks like an MSP image, and infer the HMM from it?
I really enjoyed reading this post! It's quite well-written. Thanks for writing it.
The only critique is that I would have appreciated more details on how the linear regression parameters are trained and what exactly the projection is doing. John's thread is a bit clarifying on this.
One question: If you optimize the representation in the residual stream such that it corresponds to a particular chosen belief state, does the transformer then predict the next token as if in that belief state? I.e., does the transformer use the belief state for making predictions?
I struggled with the notation on the figures; this comment tries to clarify a few points for anyone else who may be confused by it.
I'm curious how much space is left after learning the MSP in the network. Does representing the MSP take up the full bandwidth of the model (even if it is represented inefficiently)? Could you maintain performance of the model by subtracting out the contributions of anything else that isn't part of the MSP?
This is extremely cool! Can you go into more detail about the step used to project the 64-dimensional residual stream to 3-dimensional space? Did you do a linear fit over a few test points and then use it on all the others?
This is really cool work!!
...In other experiments we've run (not presented here), the MSP is not well-represented in the final layer but is instead spread out amongst earlier layers. We think this occurs because in general there are groups of belief states that are degenerate in the sense that they have the same next-token distribution. In that case, the formalism presented in this post says that even though the distinction between those states must be represented in the transformer's internals, the transformer is able to lose those distinctions for the purpose
We do this by performing standard linear regression from the residual stream activations (64 dimensional vectors) to the belief distributions (3 dimensional vectors) which associated with them in the MSP.
I don't understand how we go from this to the fractal. The linear probe gives us a single 2D point for every forward pass of the transformer, correct? How do we get the picture with many points in it? Is it by sampling from the transformer while reading the probe after every token and then putting all the points from that on one graph?
Is this result equiva...
We look in the final layer of the residual stream and find a linear 2D subspace where activations have a structure remarkably similar to that of our predicted fractal. We do this by performing standard linear regression from the residual stream activations (64 dimensional vectors) to the belief distributions (3 dimensional vectors) which associated with them in the MSP.
Naive technical question, but can I ask for a more detailed description of how you go from the activations in the residual stream to the map you have here? Or like, can someone point m...
I thought that the part about models needing to keep track of a more complicated mixed-state presentation as opposed to just the world model is one of those technical insights that's blindingly obvious once someone points it out to you (i.e., the best type of insight :)). I love how the post starts out by describing the simple Z1R example to help us get a sense of what these mixed-state presentations are like. Bravo!
Non exhaustive list of reasons one could be interested in computational mechanics: https://www.lesswrong.com/posts/GG2NFdgtxxjEssyiE/dalcy-s-shortform?commentId=DdnaLZmJwusPkGn96
Produced while being an affiliate at PIBBSS[1]. The work was done initially with funding from a Lightspeed Grant, and then continued while at PIBBSS. Work done in collaboration with @Paul Riechers, @Lucas Teixeira, @Alexander Gietelink Oldenziel, and Sarah Marzen. Paul was a MATS scholar during some portion of this work. Thanks to Paul, Lucas, Alexander, Sarah, and @Guillaume Corlouer for suggestions on this writeup.
Update May 24, 2024: See our manuscript based on this work
Introduction
What computational structure are we building into LLMs when we train them on next-token prediction? In this post we present evidence that this structure is given by the meta-dynamics of belief updating over hidden states of the data-generating process. We'll explain exactly what this means in the post. We are excited by these results because
Theoretical Framework
In this post we will operationalize training data as being generated by a Hidden Markov Model (HMM)[2]. An HMM has a set of hidden states and transitions between them. Each transition is labeled with a probability and the token it emits. Here are some example HMMs and data they generate.
Consider the relation a transformer has to an HMM that produced the data it was trained on. This is general - any dataset consisting of sequences of tokens can be represented as having been generated from an HMM. Throughout the discussion of the theoretical framework, let's assume a simple HMM with the following structure, which we will call the Z1R process[3] (for "zero one random").
The Z1R process has 3 hidden states, $S_0$, $S_1$, and $S_R$. Arrows of the form $S_x \xrightarrow{a:p\%} S_y$ denote $P(S_y, a \mid S_x) = p\%$, that the probability of moving to state $S_y$ and emitting the token $a$, given that the process is in state $S_x$, is $p\%$. In this way, taking transitions between the states stochastically generates binary strings of the form
...01R01R...
where $R$ is a random 50/50 sample from $\{0, 1\}$.
The HMM structure is not directly given by the data it produces. Think of the difference between the list of strings this HMM emits (along with their probabilities) and the hidden structure itself[4]. Since the transformer only has access to the strings of emissions from this HMM, and not any information about the hidden states directly, if the transformer learns anything to do with the hidden structure, then it has to do the work of inferring it from the training data.
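To make the Z1R definition concrete, here is a minimal sketch that samples from it. The dictionary encoding and the names `Z1R` and `sample` are our own illustrative choices, not code from the post. Starting in $S_0$, every third token is forced to `0`, the next is forced to `1`, and the third is the random `R`.

```python
import random

# Z1R as described above: S0 --0:100%--> S1, S1 --1:100%--> SR,
# SR --0:50%--> S0 and SR --1:50%--> S0.
# Each state maps to a list of (probability, emitted token, next state).
Z1R = {
    "S0": [(1.0, "0", "S1")],
    "S1": [(1.0, "1", "SR")],
    "SR": [(0.5, "0", "S0"), (0.5, "1", "S0")],
}


def sample(hmm, start, length, rng):
    """Walk the HMM for `length` steps and return the emitted tokens as a string."""
    state, tokens = start, []
    for _ in range(length):
        r, acc = rng.random(), 0.0
        # Pick the first transition whose cumulative probability exceeds r;
        # if rounding leaves r >= acc, the last transition is used.
        for p, token, nxt in hmm[state]:
            acc += p
            if r < acc:
                break
        tokens.append(token)
        state = nxt
    return "".join(tokens)


print(sample(Z1R, "S0", 12, random.Random(0)))
```

The observer only ever sees the printed string, never the state names. That gap is exactly what the following sections are about.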
What we will show is that when they predict the next token well, transformers are doing even more computational work than inferring the hidden data generating process!
Do Transformers Learn a Model of the World?
One natural intuition would be that the transformer must represent the hidden structure of the data-generating process (i.e. the "world"[2]). In this case, this would mean the three hidden states and the transition probabilities between them.
This intuition often comes up (and is argued about) in discussions about what LLMs "really understand." For instance, Ilya Sutskever has said:
This type of intuition is natural, but it is not formal. Computational Mechanics is a formalism that was developed in order to study the limits of prediction in chaotic and other hard-to-predict systems, and has since expanded to a deep and rigorous theory of computational structure for any process. One of its many contributions is in providing a rigorous answer to what structures are necessary to perform optimal prediction. Interestingly, Computational Mechanics shows that prediction is substantially more complicated than generation. What this means is that we should expect a transformer trained to predict the next token well to have more structure than the data-generating process!
The Structure of Belief State Updating
But what is that structure exactly?
Imagine you know, exactly, the structure of the HMM that produces ...01R... data. You go to sleep, you wake up, and you see that the HMM has emitted a 1. What state is the HMM in now? It is possible to generate a 1 both from taking the deterministic transition $S_1 \xrightarrow{1:100\%} S_R$ and from taking the stochastic transition $S_R \xrightarrow{1:50\%} S_0$. Since the deterministic transition is twice as likely as the 50% one, the best you can do is to have some belief distribution over the current states of the HMM, in this case $P([S_0, S_1, S_R]) = [\frac{1}{3}, 0, \frac{2}{3}]$[5].
If now you see another 1 emitted, so that in total you've seen 11, you can now use your previous belief about the HMM state (read: prior), and your knowledge of the HMM structure alongside the emission you just saw (read: likelihood), in order to generate a new belief state (read: posterior). An exercise for the reader: What is the equation for updating your belief state given a previous belief state, an observed token, and the transition matrix of the ground-truth HMM?[6] In this case, there is only one way for the HMM to generate 11, $S_1 \xrightarrow{1:100\%} S_R \xrightarrow{1:50\%} S_0$, so you know for certain that the HMM is now in state $S_0$. From now on, whenever you see a new symbol, you will know exactly what state the HMM is in, and we say that you have synchronized to the HMM.
In general, as you observe increasing amounts of data generated from the HMM, you can continually update your belief about the HMM state. Even in this simple example there is non-trivial structure in these belief updates. For instance, it is not always the case that seeing 2 emissions is enough to synchronize to the HMM. If instead of 11... you saw 10..., you still wouldn't be synchronized, since there are two different paths through the HMM that generate 10.
The structure of belief-state updating is given by the Mixed-State Presentation.
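(Spoiler for the exercise.) Here is a minimal sketch of the standard Bayesian filtering update, written as $\eta' = \eta T^{(x)} / (\eta T^{(x)} \mathbf{1})$, where $T^{(x)}_{ij} = P(S_j, x \mid S_i)$ and $\mathbf{1}$ is the all-ones vector. We believe this matches the form in the reference from footnote [6], but the encoding and the helper names `update` and `belief_after` are our own. Exact fractions reproduce the beliefs in the text, starting from the uniform prior mentioned in footnote [5].

```python
from fractions import Fraction as F

# States ordered (S0, S1, SR). T[x][i][j] = P(next state j, emit token x | state i).
T = {
    "0": [[0, 1, 0], [0, 0, 0], [F(1, 2), 0, 0]],  # S0 -0-> S1, SR -0 (50%)-> S0
    "1": [[0, 0, 0], [0, 0, 1], [F(1, 2), 0, 0]],  # S1 -1-> SR, SR -1 (50%)-> S0
}


def update(belief, token):
    """One filtering step: eta' = eta T^(x) / (eta T^(x) 1)."""
    unnorm = [sum(belief[i] * T[token][i][j] for i in range(3)) for j in range(3)]
    z = sum(unnorm)  # probability of observing `token` under the current belief
    if z == 0:
        raise ValueError(f"token {token!r} is impossible under this belief")
    return [u / z for u in unnorm]


def belief_after(tokens, prior=(F(1, 3),) * 3):
    """Fold `update` over a token string, starting from `prior`."""
    belief = list(prior)
    for t in tokens:
        belief = update(belief, t)
    return belief


for seq in ["1", "11", "10"]:
    print(seq, [str(p) for p in belief_after(seq)])
```

This prints `[1/3, 0, 2/3]` after `1`, the synchronized belief `[1, 0, 0]` after `11`, and the still-uncertain `[1/2, 1/2, 0]` after `10`.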
The Mixed-State Presentation
Notice that just as the data-generating structure is an HMM - at a given moment the process is in a hidden state, then, given an emission, the process moves to another hidden state - so too is your belief updating! You are in some belief state, then given an emission that you observe, you move to some other belief state.
The meta-dynamics of belief state updating are formally another HMM, where the hidden states are your belief states. This meta-structure is called the Mixed-State Presentation (MSP) in Computational Mechanics.
Note that the MSP has transitory states (in green above) that lead to a recurrent set of belief states that are isomorphic to the data-generating process - this always happens, though there might be infinitely many transitory states. Synchronization is the process of moving through the transitory states towards convergence to the data-generating process.
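Because Z1R's MSP is finite, one way to build it explicitly is to run a breadth-first search over reachable beliefs, applying the filtering update for every token. This is a sketch under our own encoding (`step` and `mixed_state_presentation` are hypothetical names). Z1R is small enough that the search terminates. For Z1R, starting from the uniform belief, it finds 7 belief states: 4 transitory ones and 3 recurrent ones, each of which puts all its mass on a single hidden state.

```python
from collections import deque
from fractions import Fraction as F

# Z1R labeled transition matrices, T[x][i][j] = P(j, x | i), states (S0, S1, SR).
T = {
    "0": [[0, 1, 0], [0, 0, 0], [F(1, 2), 0, 0]],
    "1": [[0, 0, 0], [0, 0, 1], [F(1, 2), 0, 0]],
}


def step(belief, token):
    """Return (P(token | belief), next belief), or (0, None) if the token is impossible."""
    unnorm = [sum(belief[i] * T[token][i][j] for i in range(3)) for j in range(3)]
    z = sum(unnorm)
    return (z, tuple(u / z for u in unnorm)) if z else (0, None)


def mixed_state_presentation(start):
    """Breadth-first search over reachable belief states.

    Returns {belief: {token: (probability, next_belief)}}. This is an HMM
    whose hidden states are beliefs, i.e. the MSP.
    """
    msp, queue = {}, deque([start])
    while queue:
        belief = queue.popleft()
        if belief in msp:
            continue
        msp[belief] = {}
        for token in T:
            p, nxt = step(belief, token)
            if p:
                msp[belief][token] = (p, nxt)
                queue.append(nxt)
    return msp


msp = mixed_state_presentation((F(1, 3),) * 3)
for belief, edges in msp.items():
    print([str(b) for b in belief], {t: str(p) for t, (p, _) in edges.items()})
```

For processes like Mess3, whose MSP has infinitely many transitory states, this search would not terminate. Sampling beliefs along long sequences is the usual workaround.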
A lesson from Computational Mechanics is that in order to perform optimal prediction of the next token based on observing a finite-length history of tokens, one must implement the Mixed-State Presentation (MSP). That is to say, to predict the next token well one should know what state the data-generating process is in as best as possible, and to know what state the data-generating process is in, implement the MSP.
The MSP has a geometry associated with it, given by plotting the belief-state values on a simplex. In general, if our data generating process has N states, then probability distributions over those states will have N−1 degrees of freedom, and since all probabilities must be between 0 and 1, all possible belief distributions lie on an (N−1)-simplex. In the case of Z1R, that means a 2-simplex (i.e. a triangle). We can plot each of our possible belief states in this 2-simplex, as shown on the right below.
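One common way to draw such a plot is to send each belief to the belief-weighted average of a triangle's corners. This is a sketch; the corner coordinates and the name `to_triangle` are our own choices, and the post does not say exactly how its figures were drawn. Pure states land on the corners, and the uniform belief lands on the centroid.

```python
import math

# One corner per hidden state. Any non-degenerate triangle would do;
# an equilateral one keeps the picture symmetric.
CORNERS = [(0.0, 0.0), (1.0, 0.0), (0.5, math.sqrt(3) / 2)]


def to_triangle(belief):
    """Map a probability vector over 3 states to a 2D point inside the 2-simplex."""
    if abs(sum(belief) - 1) > 1e-9 or min(belief) < 0:
        raise ValueError("belief must be a probability distribution")
    x = sum(p * cx for p, (cx, _) in zip(belief, CORNERS))
    y = sum(p * cy for p, (_, cy) in zip(belief, CORNERS))
    return x, y


print(to_triangle([1 / 3, 1 / 3, 1 / 3]))  # centroid of the triangle
print(to_triangle([1 / 3, 0, 2 / 3]))      # Z1R belief after seeing a 1
```

This map is affine and invertible on the simplex, so it changes only how the picture is drawn, not which belief states are distinct.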
What we show in this post is that when we train a transformer to do next-token prediction on data generated from the 3-state HMM, we can find a linear representation of the MSP geometry in the residual stream. This is surprising! Note that the points on the simplex, the belief states, are not the next-token probabilities. In fact, multiple points here have literally the same next-token predictions. In particular, in this case, $\eta_{10}$, $\eta_S$, and $\eta_{101}$ all have the same optimal next-token predictions.
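To check this numerically, note that the next-token distribution for a belief $\eta$ is $P(x \mid \eta) = \sum_i \eta_i \sum_j T^{(x)}_{ij}$. This sketch repeats the Z1R encoding so that it runs on its own. We read $\eta_S$ as the initial (uniform) belief and $\eta_{101}$ as the belief after observing 101, which is fully synchronized to $S_R$. That mapping of labels is our assumption. The belief after a single 1 is included as a contrasting case.

```python
from fractions import Fraction as F

# Z1R labeled transition matrices, T[x][i][j] = P(j, x | i), states (S0, S1, SR).
T = {
    "0": [[0, 1, 0], [0, 0, 0], [F(1, 2), 0, 0]],
    "1": [[0, 0, 0], [0, 0, 1], [F(1, 2), 0, 0]],
}


def next_token_distribution(belief):
    """P(x | belief) = sum_i belief_i * sum_j T[x][i][j]."""
    return {x: sum(belief[i] * sum(T[x][i]) for i in range(3)) for x in T}


beliefs = {
    "uniform start": (F(1, 3),) * 3,
    "after '10'": (F(1, 2), F(1, 2), 0),
    "synchronized to SR": (0, 0, 1),
    "after '1'": (F(1, 3), 0, F(2, 3)),
}
for name, b in beliefs.items():
    print(name, {x: str(p) for x, p in next_token_distribution(b).items()})
```

The first three beliefs are different points on the simplex, yet each predicts 0 and 1 with probability 1/2. Only the last one predicts differently (2/3 versus 1/3).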
Another way to think about this claim is that transformers keep track of distinctions in the anticipated distribution over the entire future, beyond distinctions in next-token predictions, even though the transformer is only trained explicitly on next-token prediction! That means the transformer is keeping track of more information than is necessary just for the local next-token prediction.
Another way to think about our claim is that transformers perform two types of inference: one to infer the structure of the data-generating process, and another meta-inference to update its internal beliefs over which state the data-generating process is in, given some history of finite data (i.e. the context window). This second type of inference can be thought of as the algorithmic or computational structure of synchronizing to the hidden structure of the data-generating process.
A final theoretical note about Computational Mechanics and the theory presented here: because Computational Mechanics is not contingent on the specifics of transformer architectures and is a well-developed first-principles framework, we can apply this framework to any optimal predictor, not just transformers[7].
Experiment and Results
Experimental Design
To repeat the question we are trying to answer:
To test our theoretical predictions, we designed an experiment with the following steps:
By controlling the structure of the training data using an HMM, we can make concrete, falsifiable predictions about the computational structure the transformer should implement during inference. Computational Mechanics, as presented in the "Theoretical Framework" section above, provides the framework for making these predictions based on the HMM's structure.
The specific HMM we chose has an MSP with an infinite fractal geometry, serving as a highly non-trivial prediction about what we should find in the transformer's residual stream activations if our theory is correct.
The Data-Generating Process and MSP
For this experiment we trained a transformer on data generated by a simple HMM, called the Mess3 Process, that has just 3 hidden states[8]. Moving between the 3 hidden states according to the emission probabilities on the edges generates strings over a 3-token vocabulary: {A, B, C}. The HMM for this data-generating process is given on the left of the figure below.
Our approach allows us to make rigorous and testable predictions about the internal structures of transformers. In the case of this HMM, the theory (outlined above) says that transformers trained on this data should instantiate the computational structure associated with the fractal geometry shown on the right of the figure above. Every colored point in the simplex on the above right panel is a distinct belief state.
We chose the Mess3 HMM because its MSP has an infinite fractal structure, and thus acts as a highly non-trivial prediction about what geometry we should find in the residual stream.
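Because the Mess3 MSP has infinitely many transitory states, it cannot be enumerated exhaustively. A practical alternative is to sample a long token sequence and record the Bayesian belief after every token. The parameters x=0.05 and alpha=0.85 come from footnote [8]. The matrix form below is an ASSUMPTION, namely our reading of the Mess3 definition; check it against Marzen and Crutchfield before relying on it. The helper names are hypothetical.

```python
import numpy as np


def mess3(x=0.05, alpha=0.85):
    """Labeled transition matrices T[k, i, j] = P(next state j, token k | state i).

    ASSUMPTION: T[k, i, j] = (alpha if j == k else beta) * (y if i == j else x),
    with beta = (1 - alpha) / 2 and y = 1 - 2x.
    """
    beta, y = (1 - alpha) / 2, 1 - 2 * x
    stay = np.full((3, 3), x)
    np.fill_diagonal(stay, y)       # y on the diagonal, x off it
    T = np.empty((3, 3, 3))
    for k in range(3):              # tokens k = 0, 1, 2 stand for A, B, C
        weight = np.full(3, beta)
        weight[k] = alpha           # token k is mostly emitted when moving into state k
        T[k] = stay * weight[None, :]
    return T


def belief_points(T, n_steps, rng):
    """Sample tokens and record the belief after each one.

    Drawing tokens from the current predictive distribution is equivalent to
    sampling the process, given a start at the stationary distribution
    (uniform for this parameterization, since the total matrix is symmetric
    and doubly stochastic).
    """
    eta = np.full(3, 1 / 3)
    points = [eta]
    for _ in range(n_steps):
        probs = np.einsum("i,kij->k", eta, T)  # P(next token = k | eta)
        k = rng.choice(3, p=probs)
        eta = eta @ T[k] / probs[k]            # Bayesian filtering update
        points.append(eta)
    return np.array(points)


T = mess3()
points = belief_points(T, 5000, np.random.default_rng(0))
print(points[:3])
```

Scattered on the 2-simplex, these points should resemble the predicted fractal, provided the assumed parameterization is right. We have not verified that visually here.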
The Results!
We train a transformer on data generated by the Mess3 HMM. We look in the final layer of the residual stream and find a linear 2D subspace where activations have a structure remarkably similar to that of our predicted fractal. We do this by performing standard linear regression from the residual stream activations (64-dimensional vectors) to the belief distributions (3-dimensional vectors) that are associated with them in the MSP.
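The regression step can be sketched with ordinary least squares. Here synthetic activations stand in for a real residual stream: the beliefs are embedded linearly in 64 dimensions and a little noise is added. The data and the names `fit_affine_probe` and `apply_probe` are hypothetical, and this is not the authors' code. Each input position contributes one activation vector and therefore one projected point. The plotted picture is the collection of those points.

```python
import numpy as np


def fit_affine_probe(acts, beliefs):
    """Least-squares affine map from activations (n, d) to beliefs (n, 3)."""
    X = np.hstack([acts, np.ones((len(acts), 1))])  # constant column = offset
    W, *_ = np.linalg.lstsq(X, beliefs, rcond=None)
    return W


def apply_probe(W, acts):
    """Project activations to predicted belief vectors."""
    return np.hstack([acts, np.ones((len(acts), 1))]) @ W


# Hypothetical stand-in data: 1000 beliefs, embedded linearly in 64 dims plus noise.
rng = np.random.default_rng(0)
beliefs = rng.dirichlet(np.ones(3), size=1000)
embed = rng.normal(size=(3, 64))
acts = beliefs @ embed + 0.01 * rng.normal(size=(1000, 64))

W = fit_affine_probe(acts, beliefs)
pred = apply_probe(W, acts)
print("max abs error:", np.abs(pred - beliefs).max())
```

The targets sum to 1 and the fit includes an offset, so the predicted 3-vectors also sum to 1 up to float error. They therefore lie in a 2D plane. This is why a 64-to-3 regression amounts to projecting onto a 2D (affine) subspace.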
We can also look at how this structure emerges over training, which shows (1) that the structure we find is not trivial[9] since it doesn’t exist in detail early in training, and (2) the step-wise refinement of the transformer's activations to the fractal structure we predict.
A movie of this process is shown below. Because we used Stochastic Gradient Descent for training, the 2D projection of the activations wiggles, even after training has converged. In this wiggling you can see that fractal structures remain intact.
Limitations and Next Steps
Limitations
Next Steps
PIBBSS is hiring! I wholeheartedly recommend them as an organization.
One way to conceptualize this is to think of "the world" as having some hidden structure (initially unknown to you), that emits observables. Our task is then to take sequences of observables and infer the hidden structure of the world - maybe in the service of optimal future prediction, but also maybe just because figuring out how the world works is inherently interesting. Inside of us, we have a "world model" that serves as the internal structure that lets us "understand" the hidden structure of the world. The term world model is contentious and nothing in this post depends on that concept much. However, one motivation for this work is to formalize and make concrete statements about people's intuitions and arguments regarding neural networks and world models - which are often handwavy and ill-defined.
Technically speaking, the term process refers to a probability distribution over infinite strings of tokens, while a presentation refers to a particular HMM that produces strings according to the probability distribution. A process has an infinite number of presentations.
Any HMM defines a probability distribution over infinite sequences of the emissions.
Our initial belief distribution, in this particular case, is the uniform distribution over the 3 states of the data generating process. However this is not always the case. In general the initial belief distribution is given by the stationary distribution of the data generating HMM.
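A small sketch of computing that stationary distribution. It is the left eigenvector, with eigenvalue 1, of the total transition matrix $T = \sum_x T^{(x)}$. We use NumPy's `np.linalg.eig`, and the helper name is hypothetical. For Z1R the total matrix is a 3-cycle, which is why the uniform distribution comes out. A made-up 2-state chain shows a non-uniform case.

```python
import numpy as np


def stationary_distribution(T_total):
    """Left eigenvector of a row-stochastic matrix for eigenvalue 1, normalized to sum to 1."""
    vals, vecs = np.linalg.eig(T_total.T)  # right eigenvectors of T^T = left eigenvectors of T
    v = np.real(vecs[:, np.argmin(np.abs(vals - 1))])
    return v / v.sum()                      # also fixes an arbitrary overall sign


# Z1R with tokens summed out, states ordered (S0, S1, SR): S0 -> S1 -> SR -> S0.
T_z1r = np.array([[0, 1, 0], [0, 0, 1], [1, 0, 0]], dtype=float)
print(stationary_distribution(T_z1r))  # approximately uniform

# Hypothetical 2-state chain where the stationary distribution is not uniform.
print(stationary_distribution(np.array([[0.9, 0.1], [0.5, 0.5]])))
```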
You can find the answer in section IV of this paper by @Paul Riechers.
There is work in Computational Mechanics that studies non-optimal or near-optimal prediction, and the tradeoffs one incurs when relaxing optimality. This is likely relevant to neural networks in practice. See Marzen and Crutchfield 2021 and Marzen and Crutchfield 2014.
This process is called the mess3 process, and was defined in a paper by Sarah Marzen and James Crutchfield. In the work presented we use x=0.05, alpha=0.85.
We've also run another control where we retain the ground truth fractal structure but shuffle which inputs correspond to which points in the simplex (you can think of this as shuffling the colors in the ground truth plot). In this case when we run our regression we get that every residual stream activation is mapped to the center point of the simplex, which is the center of mass of all the points.