This post is meant to explain the core ideas of a new preprint on the early detection of phase transitions in deep learning. The preprint could be cleaned up a bit, but I was excited enough about it that I decided to share it in its current state. Below I explain the core idea of the paper and why we figured this was an important direction.
Introduction
Nuclear Engineering
When I was in my final year of high school, I went through a short phase where I wanted to be a nuclear engineer. There were two main reasons for this: first, I thought it would be cool to work on developing fusion reactors, and second, even if I couldn't do that, I actually lived pretty close to a nuclear power plant, so at least I would have job opportunities! While my interest was rather short-lived, it did make me appreciate that a huge amount of nuclear engineering is about making nuclear power plants not fail catastrophically.
At the heart of a fission plant sits the reactor pressure vessel (RPV). This is where the nuclear reaction actually occurs, and it is thus the primary safety barrier. Furthermore, it can't really be replaced, so it is the component that defines the operational lifetime of the plant. The lifetime of an RPV is determined by embrittlement.
What is embrittlement? Inside the reactor, the RPV is constantly bombarded by fast neutrons, which knock atoms out of their lattice positions and create point defects; these in turn trigger microstructural changes that accumulate over years. Like most things, the properties of steel depend strongly on temperature. In the RPV, you want the walls to be ductile so that they deform under stress instead of cracking. This requires, however, that the steel sit above a certain temperature called the ductile–brittle transition temperature (DBTT), since above it the metal bends rather than fractures.
However, over many years these small microstructural changes raise the DBTT, so that at its operating temperature the RPV sits closer to brittle behaviour than to ductile behaviour. The RPV experiences significant thermal and mechanical stresses, and the more brittle it becomes, the more likely it is to fail catastrophically (i.e., to develop large cracks).
So how do they keep this from becoming a problem? First, they predict the expected lifetime of the reactor from radiation exposure models. It would be rather irresponsible, however, to just do this computation and let it ride. Instead, they continuously monitor the reactor to validate their models and detect unexpected changes in the behaviour of the system.
How do they do this? A modern approach is acoustic emission monitoring. Consider the many tiny microscopic events in a material: a crack tip advances by a tiny increment, a dislocation slips, a grain boundary shears, a new phase nucleates. Each of these events involves a sudden, localized release of stored elastic energy, which produces transient elastic waves (stress waves) that propagate outward through the material at the speed of sound in that medium. These waves can be picked up by sensors placed strategically across the structure. With multiple sensors, you can locate the source by measuring the difference in arrival times: if a wave arrives at sensor A 20 microseconds before sensor B and you know the wave speed, you can work out where along the line between the two sensors the source lies. With enough sensors you get 2D or 3D localization. This is how they create "maps" of active damage in a pressure vessel.
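The arrival-time arithmetic is simple enough to sketch. Below is a minimal one-dimensional version, assuming the source lies on the straight segment between two sensors and that the wave travels at a single known speed. The function name and the steel wave speed of roughly 5900 m/s are illustrative assumptions, not values taken from any monitoring standard.

```python
def locate_source_1d(distance_ab, wave_speed, delta_t):
    """Return the source position measured from sensor A (hypothetical helper).

    distance_ab: separation between sensors A and B (m)
    wave_speed:  assumed wave speed in the material (m/s)
    delta_t:     arrival time at A minus arrival time at B (s);
                 negative means the wave reached A first.
    Assumes the source lies on the straight segment between A and B.
    """
    # t_A = x / v and t_B = (D - x) / v, so t_A - t_B = (2x - D) / v.
    x = (distance_ab + wave_speed * delta_t) / 2.0
    if not 0.0 <= x <= distance_ab:
        raise ValueError("time difference inconsistent with a source between the sensors")
    return x


# Sensors 1 m apart, assumed wave speed ~5900 m/s, wave reaches A 20 microseconds before B.
x = locate_source_1d(1.0, 5900.0, -20e-6)
print(f"source is {x:.3f} m from sensor A")
```

With more sensors, each pair constrains the source to a curve (or surface), and intersecting those constraints gives the 2D or 3D location.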
Now, we all know nuclear disasters have happened. However, they were essentially all caused by unanticipated system-level events such as operator error, flawed reactor design, loss of cooling, or external events (e.g., earthquakes and tsunamis); none of them came from embrittlement. Some vessels came close, prompting new legislation, but none have failed. This is largely because these systems are designed and operated with extremely conservative safety margins. Engineers anticipate embrittlement from the outset and account for it by selecting robust materials, limiting neutron exposure where possible, and continuously monitoring changes in fracture toughness through surveillance programs inside the reactor. Operational procedures are also tightly controlled to avoid dangerous conditions like rapid cooling under high pressure (pressurized thermal shock), and regulatory frameworks enforce strict limits on allowable embrittlement over a reactor’s lifetime. When vessels approach these limits, operators can reduce power, modify operating conditions, perform thermal annealing to partially reverse damage, or retire the plant altogether. I think we can learn from this. To motivate thinking about these sorts of early warning systems in AI, we give two examples below: one very near term and one further in the future.
Training a Legal Assistant
Consider a large language model being trained as a legal assistant. The training corpus is carefully assembled, containing case law, legal commentary, judicial opinions, and law review articles. It's been filtered for overtly offensive and overtly racist content. The trainers have done their due diligence. But the corpus faithfully represents the actual history of the legal system, and that history is saturated with structural biases that don't announce themselves as bias. Sentencing opinions from the 20th century don't contain slurs, but they systematically use different affect language when describing defendants of different races, such as "showed remorse" versus "appeared unaffected," or "troubled upbringing" versus "criminal environment." Housing law cases don't explicitly endorse redlining, but their reasoning patterns treat certain neighborhoods as inherently higher-risk in ways that perfectly correlate with racial composition.
Early in training, the model learns the broad structure of legal reasoning: argument forms, precedent hierarchies, the grammar of judicial opinions. This is the coarse-grained, high-value signal in the data, and it's largely benign. But as training continues and the model begins fitting finer-grained patterns, it starts picking up subtler statistical regularities, such as the correlations between affect language and defendant demographics, the way "sound legal reasoning" and "emotional appeal" map onto different communities, and which framings are treated as authoritative in which contexts. These are the higher-frequency patterns in the data, and they carry the historical prejudice not as explicit content but as texture. There is a transition point where the model is no longer learning general legal arguments but is instead learning higher-frequency features, which can include how legal arguments have historically worked for and against particular groups.
The thing is, if one is simply looking at benchmark scores, it can be a long time before this behaviour starts to show up. If you only check at the end of training and find that your model is now racially biased, that is a lot of training money down the drain, plus the money spent on data collection, since you now have to design a new dataset that does not produce this sort of prejudice, which is highly non-trivial. It would be far better to detect such behaviours as they start to form and be able to intervene on the model. More generally, it's not just undesirable behaviours like these that we care about, but also emergent capabilities, goals, etc., since having a map of what the model developed during training makes it much easier to assess the risk it poses in deployment.
A Continual Learning Agent
For a plausible near-future example, consider an AI system that learns continuously during deployment. Imagine an AI medical triage assistant deployed in an emergency department, continuously fine-tuning on the cases it encounters, and for simplicity assume we have done good theoretical work so that it starts out well-calibrated and aligned. But over months, the distribution of cases shifts: flu season arrives, a new respiratory virus circulates, the local demographics change as a new elderly care facility opens nearby, etc. The agent keeps updating throughout. At first this is beneficial: it becomes better at recognizing the patterns it's seeing more frequently. But gradually, two things happen in parallel.
First, it becomes subtly worse at recognizing rare presentations it hasn't seen in a while (the analog of catastrophic forgetting, but partial and hard to detect). Second, and more perniciously, the sheer volume of correlated cases during flu season warps its internal representations so that it starts over-attributing respiratory symptoms even in cases where they're incidental. Its "worldview" has silently narrowed.
The key dramatic element is that nothing looks wrong from the outside for a long time. The aggregate accuracy metrics remain high because most cases are respiratory during flu season. The degradation is hidden in the rare cardiac case that gets deprioritized, or in the unusual neurological presentation that gets misread through a respiratory lens. By the time someone notices a pattern of missed diagnoses, the internal reorganization is already deep. Obviously, one could construct more catastrophic scenarios than this, but it illustrates the point. One might say "we can guarantee that this won't happen by designing the agent correctly in the first place!", which might in principle be true, but it feels risky to just trust that everything is working the way we designed it to. No one wants a nuclear power plant without continuous monitoring.
Working On Early Warning Signs
One thing I found rather surprising is that there has been almost no work on the early detection of phase transitions in deep learning, despite early detection being a core component of many safety-critical domains. Climate science has decades of work on detecting tipping points via critical slowing down, ecology has operational frameworks for detecting regime shifts in ecosystems, structural engineering has entire industries built around structural health monitoring, etc. In some ways it makes sense that this hasn't happened yet in deep learning, as the field is younger and the very concept of "phase transitions during training" as something worth detecting is relatively recent. The natural next step is to try to detect these phase transitions as early as possible.
Phase Transitions and Early Warning Signals
The first research objective to try and tackle for developing early warning methods is "How do we detect emergent behaviours/capabilities in training before they actually solidify?" This question seems straightforward, but there are some important subtleties that need to be addressed. Namely, what kinds of phase transitions we actually see in training, and what it means to detect them early.
When is Early Detection Possible?
Types of Phases
In physics, when discussing phase transitions we tend to talk primarily about first order and second order phase transitions (there are other types, but they are far less common). The order of a phase transition is determined by the behaviour of the free energy $F(\lambda) = E(\lambda) - T\,S(\lambda)$, where $E$ is the energy, $S$ is the entropy, $\lambda$ is some free parameter, and $T$ is a temperature parameter.
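A first order phase transition (in $\lambda$) occurs when $F$ itself is continuous, but its first derivative $\partial F/\partial\lambda$ has a jump discontinuity at some critical value $\lambda_c$. That is:
$$\lim_{\lambda\to\lambda_c^-}\frac{\partial F}{\partial\lambda} \neq \lim_{\lambda\to\lambda_c^+}\frac{\partial F}{\partial\lambda}.$$
The free energy has a kink (corner) at $\lambda_c$.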
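A second order phase transition is one where $F$ and $\partial F/\partial\lambda$ are both continuous, but the second derivative $\partial^2 F/\partial\lambda^2$ has a jump discontinuity at $\lambda_c$. That is:
$$\lim_{\lambda\to\lambda_c^-}\frac{\partial^2 F}{\partial\lambda^2} \neq \lim_{\lambda\to\lambda_c^+}\frac{\partial^2 F}{\partial\lambda^2}.$$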
The free energy is smooth through the transition but its curvature changes discontinuously. This corresponds to a discontinuity in quantities like heat capacity or susceptibility, which are second-derivative observables.
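A textbook way to see a second order transition numerically is Landau's mean-field model (our choice of illustration, not something from the paper). Take $F(m;\lambda) = \lambda m^2/2 + m^4/4$ and minimize over the order parameter $m$: for $\lambda \ge 0$ the minimum is $m^* = 0$ with $F = 0$, and for $\lambda < 0$ it is $m^{*2} = -\lambda$ with $F = -\lambda^2/4$. The sketch below estimates the $\lambda$-derivatives of the equilibrium free energy by finite differences just below and just above $\lambda_c = 0$.

```python
def landau_free_energy(lam):
    """Equilibrium free energy min_m F(m; lam) for F(m; lam) = lam*m**2/2 + m**4/4.

    lam >= 0: minimum at m* = 0, so F = 0.
    lam <  0: minimum at m*^2 = -lam, so F = -lam**2 / 4.
    """
    return 0.0 if lam >= 0 else -lam**2 / 4.0


def derivatives(f, lam, h=1e-4):
    """Central finite-difference estimates of f'(lam) and f''(lam)."""
    d1 = (f(lam + h) - f(lam - h)) / (2 * h)
    d2 = (f(lam + h) - 2 * f(lam) + f(lam - h)) / h**2
    return d1, d2


results = {}
for lam in (-0.01, 0.01):  # just below and just above lam_c = 0
    results[lam] = derivatives(landau_free_energy, lam)
    d1, d2 = results[lam]
    print(f"lam = {lam:+.2f}: dF/dlam = {d1:+.4f}, d2F/dlam2 = {d2:+.4f}")
```

The first derivative goes to zero from both sides (it equals $-\lambda/2$ below $\lambda_c$), while the second derivative sits at $-1/2$ below and $0$ above: the jump in curvature that defines a second order transition.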
Early Warning Signals
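Consider simple gradient dynamics where an order parameter $x$ evolves to minimize $F(x;\lambda)$:
$$\frac{dx}{dt} = -\frac{\partial F}{\partial x}.$$
Near a stable fixed point $x^*$ (where $\partial F/\partial x\,|_{x^*} = 0$), linearize with $\delta x = x - x^*$:
$$\frac{d\,\delta x}{dt} = -F''(x^*)\,\delta x, \qquad F''(x^*) \equiv \frac{\partial^2 F}{\partial x^2}\Big|_{x^*}.$$
The solution is exponential relaxation, $\delta x(t) = \delta x(0)\,e^{-t/\tau}$, with timescale:
$$\tau = \frac{1}{F''(x^*)}.$$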
So the relaxation time is set directly by the curvature of the free energy at the fixed point.
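A quick numerical check of $\tau = 1/F''(x^*)$: the sketch below integrates the gradient flow for a purely quadratic free energy $F(x) = k x^2/2$ with forward Euler and records how long it takes $x$ to decay by a factor of $e$. The quadratic form, step size, and helper name are our assumptions for illustration.

```python
import math


def relaxation_time(curvature, dt=1e-3, t_max=50.0, x0=1.0):
    """Integrate dx/dt = -F'(x) for F(x) = curvature * x**2 / 2 with forward Euler
    and return the first time at which x(t) drops below x0 / e."""
    x, t = x0, 0.0
    while x > x0 / math.e and t < t_max:
        x -= dt * curvature * x  # gradient step: F'(x) = curvature * x
        t += dt
    return t


for k in (2.0, 1.0, 0.5, 0.1):
    print(f"F''(x*) = {k:>4}: measured tau = {relaxation_time(k):.3f}, 1/F'' = {1 / k:.3f}")
```

Halving the curvature doubles the time to relax, which is exactly the effect that becomes a divergence when the curvature goes to zero.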
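At a second-order transition, the curvature $F''(x^*)$ is continuous in $\lambda$ but vanishes at $\lambda_c$. As $\lambda \to \lambda_c$, the minimum of $F$ becomes progressively flatter and the restoring force weakens. Therefore:
$$\tau = \frac{1}{F''(x^*)} \to \infty \quad \text{as} \quad \lambda \to \lambda_c.$$
This divergence is continuous and gradual. The system telegraphs the approaching transition through observable precursors long before $\lambda_c$ is reached: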
Relaxation time grows (critical slowing down proper)
Variance of fluctuations grows (by the fluctuation-dissipation theorem, $\langle \delta x^2 \rangle \approx T / F''(x^*)$)
Autocorrelation time grows
All of these are measurable and diverge as a power law in $|\lambda - \lambda_c|$. This is why second-order transitions admit early warning signals: the free energy landscape continuously prepares the transition by flattening out, and the dynamics faithfully reflect this.
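These precursors are easy to reproduce in a toy simulation. The sketch below runs overdamped Langevin dynamics, $dx = -F'(x)\,dt + \sqrt{2T\,dt}\,\xi$, for many independent chains in the well $F(x) = a x^2/2 + x^4/4$, and measures the stationary variance and the autocorrelation at a lag of one time unit as the curvature $a$ at the minimum shrinks toward zero. The potential, temperature, and all numerical settings are assumptions chosen for illustration.

```python
import numpy as np


def simulate_well(a, temperature=1e-3, n_chains=2000, dt=0.01,
                  burn_in=50.0, lag=1.0, seed=0):
    """Euler-Maruyama simulation of dx = -F'(x) dt + sqrt(2 T dt) * noise
    for F(x) = a*x**2/2 + x**4/4, with many independent chains.

    Returns the stationary variance of x and its autocorrelation at `lag`.
    """
    rng = np.random.default_rng(seed)
    x = np.zeros(n_chains)

    def step(x):
        grad = a * x + x**3  # F'(x)
        noise = np.sqrt(2 * temperature * dt) * rng.standard_normal(n_chains)
        return x - grad * dt + noise

    for _ in range(round(burn_in / dt)):  # let the chains reach equilibrium
        x = step(x)
    x_start = x.copy()
    for _ in range(round(lag / dt)):  # evolve one lag further
        x = step(x)

    variance = x_start.var()
    autocorr = np.corrcoef(x_start, x)[0, 1]
    return variance, autocorr


stats = {a: simulate_well(a) for a in (1.0, 0.3, 0.1)}
for a, (var, ac) in stats.items():
    print(f"a = {a:.1f}: variance = {var:.2e}, lag-1 autocorrelation = {ac:.2f}")
```

As $a$ decreases, both the variance (roughly $T/a$ while the quartic term is negligible) and the autocorrelation (roughly $e^{-a \cdot \text{lag}}$) grow, which is the critical slowing down signature described above.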
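At a first-order transition, $\partial F/\partial\lambda$ jumps discontinuously. Geometrically, the picture is fundamentally different: there are two distinct local minima $x_1^*$ and $x_2^*$, and the transition occurs when they exchange global stability (i.e. when $F(x_1^*) = F(x_2^*)$).
The crucial point is that neither minimum needs to flatten out at the transition. Each basin can retain healthy positive curvature right up to $\lambda_c$, so:
$$F''(x_i^*) > 0 \quad\Longrightarrow\quad \tau_i = \frac{1}{F''(x_i^*)} < \infty \quad \text{at } \lambda = \lambda_c, \qquad i = 1, 2.$$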
The system simply jumps from one basin to the other with no dynamical precursor. The relaxation time does not diverge, variance does not blow up, and the landscape gives no warning.
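A tilted double well makes this concrete. In the sketch below (our own toy example, not from the paper), $F(x) = x^4/4 - x^2/2 - hx$ with tilt $h$ as the control parameter. As $h$ crosses zero the global minimum switches from the left basin to the right one, yet the curvature at both minima stays well away from zero throughout, so there is nothing for a relaxation-time or variance signal to pick up.

```python
import numpy as np


def minima_and_curvatures(h):
    """Local minima of F(x) = x**4/4 - x**2/2 - h*x, the curvature F''(x) there,
    and the value of F at each minimum."""
    # Stationary points solve F'(x) = x**3 - x - h = 0.
    roots = np.roots([1.0, 0.0, -1.0, -h])
    real = roots[np.abs(roots.imag) < 1e-9].real
    keep = 3 * real**2 - 1 > 0  # positive curvature means a local minimum
    xs = np.sort(real[keep])
    curv = 3 * xs**2 - 1
    F = xs**4 / 4 - xs**2 / 2 - h * xs
    return xs, curv, F


results = {h: minima_and_curvatures(h) for h in (-0.1, 0.0, 0.1)}
for h, (xs, curv, F) in results.items():
    global_min = xs[np.argmin(F)]
    print(f"h = {h:+.1f}: minima {np.round(xs, 3)}, curvatures {np.round(curv, 3)}, "
          f"global minimum at x = {global_min:+.3f}")
```

At $h = 0$ both minima sit at $x = \pm 1$ with curvature $2$; the basins are equally deep and equally stiff, and the system's choice between them is decided by depth, not by any softening.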
So, in general it is not possible to perform early detection for first order transitions. However, for one type of first order transition you can get close: nucleation-type transitions. Nucleation is the process by which a new phase forms locally within a metastable parent phase through the creation of an object called a critical nucleus. The critical nucleus is not itself favourable in the free energy landscape (it sits at the top of the barrier, effectively a maximum), but through some mechanism (usually thermal fluctuations) the system can climb the free energy barrier to reach it and then descend to a stable state on the other side. The critical nucleus represents a rare, highly ordered configuration of the system, and the entropy term strongly disfavors it, in part because there are vastly more microscopic states without a coherent nucleus than with one.
While this is very difficult in practice for physical systems, one can in principle detect when the critical nucleus forms. That is technically not an early warning in the critical slowing down sense, but it is conceptually very similar.
Types of Phase Transitions in Deep Learning
Now, there is a subtle point worth mentioning here: deep learning is not a thermodynamic system but a dynamical one. It does not really have thermodynamic phase transitions, only transitions from one metastable state to another, so the classical thermodynamic free energy technically isn't defined between these states in most cases. However, the dynamical systems literature still tends to talk about first and second order phase transitions, as there are reasonable ways to define a free energy for these systems, so we will keep using the thermodynamic language.
So the first question is how to define the free energy for deep learning. There is the SLT approach to deriving a thermodynamic free energy, which is nice and rigorous but overkill for the purposes of this post, so we are going to go for something much more handwavy. Instead, let's assign a parameter choice $\theta$ the free energy $F(\theta) = L(\theta) - T\,S(\theta)$, where $L$ is the loss, $S$ is the configurational entropy, and $T$ plays the role of a temperature. What do I mean by configurational entropy? For simplicity, we can take it to be determined by the volume of points near $\theta$ with loss close to $L(\theta)$. So the configurational entropy roughly measures how many parameter choices make the model behave similarly to $\theta$: the larger that volume, the less information it takes to specify such a parameter.
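To make "the volume of points near $\theta$ with loss close to $L(\theta)$" concrete, here is a crude Monte Carlo proxy: sample points uniformly in a ball around $\theta$ and take the log of the fraction whose loss is within a tolerance of $L(\theta)$. The ball radius, the tolerance, and using a log-fraction as the entropy are all our assumptions; this is not the SLT quantity or the paper's definition. The two quadratic losses are hypothetical stand-ins for a flat and a sharp minimum.

```python
import numpy as np


def config_entropy(loss_fn, theta, radius=1.0, eps=0.05, n_samples=100_000, seed=0):
    """Crude proxy for configurational entropy at `theta`: log of the fraction of
    points in a ball of `radius` around `theta` whose loss is within `eps` of
    loss_fn(theta). loss_fn maps an (n, d) array of parameters to n losses."""
    rng = np.random.default_rng(seed)
    d = theta.shape[0]
    # Uniform samples in a d-ball: random direction times radius * U**(1/d).
    directions = rng.standard_normal((n_samples, d))
    directions /= np.linalg.norm(directions, axis=1, keepdims=True)
    radii = radius * rng.random(n_samples) ** (1.0 / d)
    points = theta + directions * radii[:, None]
    close = np.abs(loss_fn(points) - loss_fn(theta[None, :])) <= eps
    frac = close.mean()
    return np.log(frac) if frac > 0 else -np.inf


def flat_loss(p):   # hypothetical wide minimum
    return 0.5 * 0.1 * np.sum(p**2, axis=1)


def sharp_loss(p):  # hypothetical narrow minimum
    return 0.5 * 10.0 * np.sum(p**2, axis=1)


theta0 = np.zeros(2)
S_flat = config_entropy(flat_loss, theta0)
S_sharp = config_entropy(sharp_loss, theta0)
print(f"flat: {S_flat:.2f}, sharp: {S_sharp:.2f}")
```

With the same loss, the flat minimum has the higher entropy proxy, so in $F = L - TS$ it is favoured at nonzero temperature.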
Deep learning can have both first and second order phase transitions. We go over mechanisms for both below.
The Second Order Transition
Imagine the model sits in a broad, largely flat region of the loss landscape. That is, this area has high configurational entropy (the level sets are broad, many configurations give similar loss) and the system bumbles about this area due to fluctuations from SGD noise, weight decay, etc. So the system is exploring along these flat directions. Eventually, some subset of these exploratory movements starts to reinforce itself as small modifications are made to an existing structure which slowly increase the curvature of the region, locking it in more. This is second order as the high-dimensional flat region gradually narrows as the system commits to specific directions.
The First Order Transition
The first order transition is interesting, as it seems to be why learning rate warmup matters, and it is mathematically well understood in deep linear networks as the silent alignment effect. It occurs in regions of the loss landscape that look like river valleys: walls swooping upward, with a meandering river sloping downward towards something like a basin. The reason this is first order is that the transition must overcome the entropic barrier of finding the valley, which makes up a small area of parameter space. Once the valley is found, the model locks on to these directions and amplifies them downwards towards a basin.
Intuitively this makes sense when compared with how we tend to think about feature formation. The optimizer finds some feature that decreases the loss, then amplifies that feature within the network.
While it is not clear that these are the only mechanisms by which these kinds of phase transitions happen, they provide a good basis for thinking about early detection.
Early Warning Signals for Deep Learning
To motivate how we do this in the paper, we can sort of reason through the thought process together. Imagine watching a crowd of people milling around a plaza from above. Most of the time, their movements are essentially independent. But if something interesting happens (say a street performer starts doing something), a subset of people start moving in the same direction at the same time. That is, there are hints that the system is reorganizing as the movement between people starts to become coordinated.
Now, let's move this reasoning over to the deep learning case. If you ask "what's the cheapest thing I could compute that would tell me a model is learning something structurally new?" you'd reason as follows: a new capability or circuit, by definition, affects multiple data points in a systematic way as that's what makes it a capability rather than memorization of a single example. So the emergence of a new capability should manifest as a new shared factor explaining the loss variation across samples. This then can be thought of geometrically. When samples start to fluctuate together the loss landscape develops (or the optimizer discovers) directions that coherently affect groups of samples. To see that these fluctuations are in some sense the earliest thing we can detect, consider that the transition is the optimizer moving a long way along a coupled direction. But the coupling in the landscape has to exist before the optimizer can exploit it. So (under some very mild assumptions that you essentially don't make your optimizer too cursed) there's a period where the optimizer discovers and starts moving along the coupled direction (producing correlated loss fluctuations), but hasn't yet moved far enough for the capability to actually work. Here you will get correlated loss fluctuations but not necessarily any detectable loss decrease.
The next important thing to realize is that the loss landscape carries the "memories" of what has been learned. That is, suppose at some point your model learns some independent skill, like basic addition. The model parameters associated with basic addition are largely locked in, so the optimizer never really explores that direction in gradient space. Since, by definition, no learning is taking place there, there is no reason to track that area anymore. If it then needs to be tracked again because learning is happening there, well then by definition the optimizer knows about it. This might seem rather trivial, but this tells us that for the purpose of tracking phase transitions the optimizer's trajectory is a natural importance sampler over the relevant directions in parameter space. This isn't to say that there isn't value in looking at the local shape of the loss landscape more broadly, but simply that for detecting emergent capabilities the optimization directions contain a sufficient amount of information.
This is similar to the idea of reaction coordinates in chemistry. Most degrees of freedom are spectators: solvent molecules jiggling around, bond vibrations that aren't involved in the reaction. If you want to study the progress of a reaction, you look at the few coordinates along which the reaction actually proceeds, such as the length of the bond being broken.
The 2-Datapoint Reduced Density Matrix
I feel like saying we "introduce" this object is kind of weird: it's actually a very natural object if you've studied physics, and it is really a generalization of the loss kernel. First, we pick a set of $n$ datapoints $z_1, \dots, z_n$ called the probe set, and denote the loss of a model parametrized by $\theta$ on sample $z_i$ as $\ell_i(\theta)$. Let $\rho$ be some distribution over a local region of parameter space. The 2RDM (under $\rho$) is then defined as the matrix $C$ with entries given by:
$$C_{ij} = \mathbb{E}_{\theta\sim\rho}\big[\ell_i(\theta)\,\ell_j(\theta)\big] - \mathbb{E}_{\theta\sim\rho}\big[\ell_i(\theta)\big]\,\mathbb{E}_{\theta\sim\rho}\big[\ell_j(\theta)\big].$$
Now, in the paper we primarily work with two specific choices of $\rho$. The first is the Gaussian 2RDM, where $\rho$ is a narrow Gaussian centered at some parameter $\theta_0$. This version of the 2RDM is useful if one needs to probe very early in training (within the first few steps).
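A Monte Carlo estimate of the Gaussian 2RDM is only a few lines: draw parameters from a narrow Gaussian around $\theta_0$, evaluate every probe sample's loss at each draw, and take the covariance across draws. This is a minimal sketch under our own assumptions (isotropic noise scale `sigma`, number of draws, and a toy linear-regression model with five probe samples), not the paper's implementation.

```python
import numpy as np


def gaussian_2rdm(per_sample_loss, theta0, sigma=1e-2, n_draws=512, seed=0):
    """Monte Carlo estimate of the Gaussian 2RDM: the covariance of per-sample
    losses l_i(theta) under theta ~ N(theta0, sigma^2 I).

    per_sample_loss(theta) must return a 1-D array with one loss per probe sample.
    """
    rng = np.random.default_rng(seed)
    thetas = theta0 + sigma * rng.standard_normal((n_draws, theta0.shape[0]))
    losses = np.stack([per_sample_loss(t) for t in thetas])  # (n_draws, n_probe)
    return np.cov(losses, rowvar=False)                       # (n_probe, n_probe)


# Hypothetical toy setup: linear regression with squared error on 5 probe samples.
rng = np.random.default_rng(1)
X = rng.standard_normal((5, 3))
w_true = np.array([1.0, -2.0, 0.5])
y = X @ w_true


def per_sample_loss(w):
    return 0.5 * (X @ w - y) ** 2


C = gaussian_2rdm(per_sample_loss, theta0=np.zeros(3))
print(C.shape)
```

The cost is `n_draws` full evaluations of the probe set per measurement, which is why it is mainly worth paying very early in training.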
The other one we introduce, and the one we mostly focus on, is the dynamical 2RDM. This version takes advantage of the natural importance sampling of the optimizer and is dirt cheap to compute. Effectively, one takes $\rho$ to be a uniform distribution over the parameters visited during a window of SGD steps. In practice, at set intervals we compute the loss on each sample in the probe set and keep a sliding window of these per-sample losses, from which we compute the covariance matrix. This does not require any expensive sampling of model parameters. It avoids massive computational overhead (it simply costs $n$ forward passes per recorded step, plus something like $O(n^2 w)$ to compute the covariance, with $n$ the number of probe samples and $w$ the size of the window) and largely avoids the curse of dimensionality, since it naturally samples the geometry that is relevant to learning.
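The bookkeeping for the dynamical 2RDM can be sketched as a small class holding a fixed-length window of per-sample loss vectors. The class and method names are hypothetical, and the synthetic "losses" in the usage example are a stand-in for real forward passes over the probe set; in a real training loop you would record the probe losses every few optimizer steps.

```python
from collections import deque

import numpy as np


class DynamicalTwoRDM:
    """Sliding-window covariance of per-sample probe losses along the training
    trajectory (a sketch of the dynamical 2RDM; names are hypothetical).

    Call `record` every few optimizer steps with the current per-sample losses
    on the probe set; `matrix` returns the covariance over the last `window` records.
    """

    def __init__(self, window):
        self.history = deque(maxlen=window)  # oldest records fall off automatically

    def record(self, probe_losses):
        self.history.append(np.asarray(probe_losses, dtype=float))

    def ready(self):
        return len(self.history) == self.history.maxlen

    def matrix(self):
        L = np.stack(list(self.history))          # (window, n_probe)
        L = L - L.mean(axis=0, keepdims=True)     # center each sample's loss
        return L.T @ L / (len(self.history) - 1)  # (n_probe, n_probe), O(n^2 w)


# Usage with synthetic per-sample losses standing in for probe-set forward passes.
rng = np.random.default_rng(0)
tracker = DynamicalTwoRDM(window=50)
for step in range(200):
    fake_losses = 1.0 / (1 + 0.01 * step) + 0.01 * rng.standard_normal(8)
    tracker.record(fake_losses)
C = tracker.matrix()
print(C.shape)
```

Keeping raw loss vectors rather than running sums makes the window exact and keeps the code obvious; for large probe sets an incremental update would be a natural optimization.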
Properties of the 2RDM
In the paper, we present some properties of the 2RDM and then derive a bunch of formal results proving they do the thing we care about. They rely on things like the breaking of linear approximations, majorization, and a bunch of other stuff, but the intuition of what is going on is much more straightforward than these results make it seem.
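The first thing we introduce is a quantity called the spectral heat capacity (SHC), which is the variance of the eigenvalues $\mu_1, \dots, \mu_n$ of the covariance matrix $C$. So
$$\mathrm{SHC} = \frac{1}{n}\sum_{k=1}^{n} \left(\mu_k - \bar{\mu}\right)^2, \qquad \bar{\mu} = \frac{1}{n}\sum_{k=1}^{n}\mu_k.$$
We also make use of a quantity called the participation ratio:
$$\mathrm{PR} = \frac{\left(\sum_{k=1}^{n}\mu_k\right)^2}{\sum_{k=1}^{n}\mu_k^2}.$$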
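Both quantities are direct to compute from a covariance matrix. The sketch below does it with an eigendecomposition and, as a design alternative of our own (not necessarily what the paper does), without one: for symmetric $C$, $\sum_k \mu_k = \operatorname{tr} C$ and $\sum_k \mu_k^2 = \operatorname{tr}(C^2) = \|C\|_F^2$, so both the SHC and the PR need only $O(n^2)$ work.

```python
import numpy as np


def spectral_heat_capacity(C):
    """SHC: (population) variance of the eigenvalues of the symmetric matrix C."""
    mu = np.linalg.eigvalsh(C)
    return float(np.var(mu))


def participation_ratio(C):
    """PR = (sum_k mu_k)^2 / sum_k mu_k^2: about 1 if one mode dominates, n if all are equal."""
    mu = np.linalg.eigvalsh(C)
    return float(mu.sum() ** 2 / np.sum(mu ** 2))


def shc_pr_no_eig(C):
    """Same two quantities without an eigendecomposition, using
    sum_k mu_k = tr(C) and sum_k mu_k^2 = tr(C^2) = ||C||_F^2 for symmetric C."""
    n = C.shape[0]
    tr, fro2 = np.trace(C), np.sum(C * C)
    return float(fro2 / n - (tr / n) ** 2), float(tr ** 2 / fro2)


# Equal eigenvalues: SHC is 0 and PR equals n. One dominant mode: PR is 1.
print(spectral_heat_capacity(np.eye(4)), participation_ratio(np.eye(4)))
D = np.diag([4.0, 0.0, 0.0, 0.0])
print(spectral_heat_capacity(D), participation_ratio(D))
```
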
The main claim of the paper can be boiled down to saying that the spectral heat capacity acts as an early warning signal for phase transitions and that the participation ratio provides information about the dynamics of the transition. The paper spends a lot of time formally proving the claim about the SHC (as the participation ratio is almost definitional). However, at least for the river valley first order transition, we can give rather intuitive geometric reasoning why this works, which we give below.
The eigenvectors of $C$ effectively tell us which samples are moving together, and the eigenvalue of an eigenvector tells you the magnitude of their collective fluctuation. If all the eigenvalues of that covariance are roughly equal, losses are fluctuating in many independent directions, so nothing coordinated is happening. But when one eigenvalue starts dwarfing the rest, it means many examples' losses are suddenly dancing in lockstep along a single direction.
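A synthetic comparison shows what "dancing in lockstep" does to the spectrum. Below, each of 20 probe losses first jitters independently; then we add one shared random factor that moves 15 of them together, a toy stand-in (our assumption) for the optimizer moving along a coupled direction. The noise scales and loadings are arbitrary illustrative choices, not data from the paper.

```python
import numpy as np


def shc_and_pr(C):
    """Spectral heat capacity and participation ratio of a covariance matrix."""
    mu = np.linalg.eigvalsh(C)
    return float(np.var(mu)), float(mu.sum() ** 2 / np.sum(mu ** 2))


rng = np.random.default_rng(0)
n_probe, window = 20, 200

# Independent fluctuations: every probe sample's loss jitters on its own.
independent = 0.1 * rng.standard_normal((window, n_probe))

# Coordinated fluctuations: the same jitter plus one shared factor that moves
# 15 of the 20 losses together.
shared = rng.standard_normal((window, 1))
loadings = np.zeros((1, n_probe))
loadings[0, :15] = 0.3
coordinated = independent + shared @ loadings

results = {}
for name, losses in [("independent", independent), ("coordinated", coordinated)]:
    results[name] = shc_and_pr(np.cov(losses, rowvar=False))
    shc, pr = results[name]
    print(f"{name:>12}: SHC = {shc:.5f}, PR = {pr:.1f}")
```

The shared factor produces one eigenvalue far above the rest: the SHC jumps by orders of magnitude while the PR collapses toward 1, even though the average loss has not moved at all.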
In the case of the river valley, the transition happens as the network rotates to move onto a new direction, which is effectively a sudden first order phase transition. The losses co-fluctuate here and are picked up by the SHC before they amplify onto the newly aligned direction. In the paper, we observe this directly in deep linear networks.
For second order transitions, what we can say is that the landscape curvature is softening: the model is losing its "grip" along certain directions, and the SHC detects that weakening grip through the growing variance it produces. An exact phenomenological picture is hard to paint intuitively here, though.
The participation ratio does shed some light on this in practice, as the PR can increase either before or after the SHC does. The SHC is a rigorous precursor signal for a transition. The participation ratio isn't, but it does tell you how many modes are fluctuating together, which helps determine how localized the transition is: a small PR means the transition is localized along one mode, whereas a high PR means many modes are involved at once. Furthermore, the temporal order of the SHC and PR changes tells us about the transition. If the SHC spikes before the PR, the landscape is funneling into a single narrow bottleneck where one direction dominates first; after passing through it, additional directions activate. If the PR increases first, multiple directions are softening simultaneously: the model is sitting on a broad plateau where many things are in play at once. Eventually the symmetry between those directions breaks, one pulls ahead, and the SHC fires.
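Reading off the ordering needs some rule for when each signal "moves". The sketch below uses a simple hypothetical threshold rule of our own: a signal has risen once it exceeds its early-baseline mean by `k` baseline standard deviations. The baseline length, `k`, and the noise-free toy traces are assumptions for illustration; a real run would need a rule tuned to its noise level.

```python
import numpy as np


def first_rise(series, baseline=20, k=4.0):
    """Index of the first point after the baseline window that exceeds the
    baseline mean by more than k baseline standard deviations, or None."""
    series = np.asarray(series, dtype=float)
    mu = series[:baseline].mean()
    sd = series[:baseline].std() + 1e-12  # guard against a perfectly flat baseline
    hits = np.nonzero(series[baseline:] > mu + k * sd)[0]
    return int(hits[0]) + baseline if hits.size else None


def transition_ordering(shc_series, pr_series, **kwargs):
    """Compare when the SHC and the PR first rise above their baselines."""
    t_shc = first_rise(shc_series, **kwargs)
    t_pr = first_rise(pr_series, **kwargs)
    if t_shc is None or t_pr is None:
        label = "no clear rise in one or both signals"
    elif t_shc < t_pr:
        label = "SHC first: one direction dominates, then more modes activate"
    elif t_pr < t_shc:
        label = "PR first: broad softening, then symmetry breaking"
    else:
        label = "simultaneous"
    return label, t_shc, t_pr


# Toy, noise-free traces: SHC jumps at step 50, PR rises at step 70.
shc_trace = np.concatenate([np.ones(50), 5.0 * np.ones(50)])
pr_trace = np.concatenate([2.0 * np.ones(70), 8.0 * np.ones(30)])
print(transition_ordering(shc_trace, pr_trace))
```
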
Experimental Results, Probe Selection, and Scaling
For a full report on the experimental results, I suggest checking the paper, as there is a lot to cover and I don't want to make this blog post any longer than it has to be. In short, we found experimentally that the 2RDM behaved basically exactly as predicted and worked reliably as an early detection method in all the settings we tested. These consisted almost entirely of known phase transitions during training that we could replicate and detect directly, so we could check where the SHC spiked relative to the known location of the transition. In basically all of these settings it worked with zero tuning, straight out of the box. This worked for:
Deep linear networks
Grokking
Induction head formation
Emergent misalignment
This is good, as it means the method does what we expect. However, I try to make it clear in the paper that there is a subtle issue here: the way you select your probe set matters, and this isn't unique to the 2RDM but applies to any interpretable phase transition detection method. There are formal ways to show this, but the point is far simpler than that. One simply needs to ask "how can I tell the model's behaviour changed if I cannot observe it?"
In the examples most common in the study of phase transitions in deep learning, this isn't a problem because the data is seemingly simple enough that a relatively small random sample captures most of the variance. This however is not the case as we scale up, so one needs to intelligently select what they care about observing and what samples provide sufficient signal to actually capture the change we care about. This sort of basis problem is common in various areas of science. In the paper we discuss it in terms of quantum chemistry where one must select finite basis elements to study what is effectively an infinite dimensional Hilbert space. During bond breaking events, using the wrong basis set can cause you to totally miss the fact that anything happened at all. They have methods for dealing with this sort of problem there so hopefully we can do something similar.
Conclusion
Part of the reason for sharing this is that the dynamical 2RDM is very easy to run experiments on, and I figured those who are interested could probably do some interesting things with it, find shortcomings, etc. Maybe it's actually completely useless, but I thought it was pretty neat. Anyway, questions and comments welcome.
This post is sort of meant to provide an explanation of the core ideas of a new preprint on the early detection of phase transitions in deep learning. The preprint could be cleaned up a bit, but I was very excited to share it so decided to share it in its current state. This post explains the core idea of the paper and why we figured this was an important direction.
Introduction
Nuclear Engineering
When I was in my final year of high school, I went through a short phase where I wanted to be a nuclear engineer. There were two main reasons for this; first being that I thought it would be cool to work on developing fusion reactors, and second being that even if I couldn't do that, I did live actually pretty close to one so at least I would have job opportunities! While my interest in this was rather short-lived, it did make me appreciate that a huge amount of nuclear engineering is making nuclear power plants not fail catastrophically.
At the heart of a fission plant sits the reactor pressure vessel. This is where the nuclear reaction actually occurs, and is thus the primary safety barrier. Furthermore, they can't really be replaced and thus are the object that defines the operational lifetime of the plant. The lifetime of an RPV is determined by embrittlement.
What is embrittlement? Inside the reactor, the RPV is constantly bombarded by fast neutrons which can knock atoms out of their lattice positions which create point defects, which then trigger microstructural changes that accumulate over years. Like most things, the properties of steels depends rather highly on their temperature. In the RPV, you want the walls to be ductile so they deform under stress instead of cracking. This requires however that the system sit above a certain temperature called the ductile–brittle transition temperature (DBTT), as above there the metal is more amenable to bending.
However, over many years the small microstructural changes raise this temperature, meaning that at the effective operating temperature the RPV keeps it closer to the brittle behaviour than the ductile behaviour. The RPV experiences significant thermal and mechanical stresses, and the more brittle it becomes, the more likely it is to experience a catastrophic failure (large cracks developing).
So how do they avoid this becoming a problem? Well, first they predict the expected lifetime of the reactor from from radiation exposure models. It would be rather irresponsible however to just do this computation let it ride. No, instead what they do is continuously monitor the reactor to validate their models and detect unexpected changes in the behaviour of the system.
How do they do this? A modern way is using acoustic emission monitoring. Consider many tiny microscopic events in a material: a crack tip advances by a tiny increment, a dislocation slips, a grain boundary shears, a new phase nucleates. Each of these events involves a sudden, localized release of stored elastic energy. These events produce transient elastic waves (stress waves) that propagate outward through the material at the speed of sound in that medium, which can be picked up by probing sensors placed strategically throughout the material. If you have multiple sensors, you can locate the source by measuring the difference in arrival times at different sensors. If a wave arrives at sensor A 20 microseconds before sensor B, and you know the wave speed, you can calculate the source position. With enough sensors you get 2D or 3D localization. This is how they create "maps" of active damage in a pressure vessel.
Now, we all know nuclear disasters have happened. However they were essentially all caused by unanticipated system level events like operator error, flawed reactor design, loss of cooling systems, or external events (e.g., earthquakes and tsunamis), but none of them have come from embrittlement. Some came close, prompting new legislation, but none have failed. This is largely because these systems are designed and operated with extremely conservative safety margins. Engineers anticipate embrittlement from the outset and account for it by selecting robust materials, limiting neutron exposure where possible, and continuously monitoring changes in fracture toughness through surveillance programs inside the reactor. Operational procedures are also tightly controlled to avoid dangerous conditions like rapid cooling under high pressure (pressurized thermal shock), and regulatory frameworks enforce strict limits on allowable embrittlement over a reactor’s lifetime. When vessels approach these limits, operators can reduce power, modify operating conditions, perform thermal annealing to partially reverse damage, or retire the plant altogether. I think we can learn from this. To motivate thinking about these sorts of early warning systems in AI, we will give two examples below, one very near term and one that is more near future.
Training a Legal Assistant
Consider a large language model being trained as a legal assistant. The training corpus is carefully assembled, containing case law, legal commentary, judicial opinions, and law review articles. It has been filtered for overtly offensive and overtly racist content. The trainers have done their due diligence. But the corpus faithfully represents the actual history of the legal system, and that history is saturated with structural biases that don't announce themselves as bias. Sentencing opinions from the 20th century don't contain slurs, but they systematically use different affect language when describing defendants of different races, such as "showed remorse" versus "appeared unaffected," "troubled upbringing" versus "criminal environment." Housing law cases don't explicitly endorse redlining, but their reasoning patterns treat certain neighborhoods as inherently higher-risk in ways that perfectly correlate with racial composition.
Early in training, the model learns the broad structure of legal reasoning: argument forms, precedent hierarchies, the grammar of judicial opinions. This is the coarse-grained, high-value signal in the data, and it's largely benign. But as training continues and the model begins fitting finer-grained patterns, it starts picking up subtler statistical regularities, such as the correlations between affect language and defendant demographics, the way "sound legal reasoning" and "emotional appeal" map onto different communities, and which framings are treated as authoritative in which contexts. These are the higher-frequency patterns in the data, and they carry the historical prejudice not as explicit content but as texture. There is a transition point where the model stops mainly learning general legal arguments and starts learning these higher-frequency features, which can include how legal arguments have historically worked for and against particular groups.
The thing is, if one is simply looking at benchmark scores, it can be a long time before this behaviour shows up. If you only check at the end of training and find that your model is now racially biased, that is a lot of money down the drain on training, plus money wasted on data collection, since you now have to find some way to design a new dataset that does not produce this sort of prejudice, which is highly non-trivial. It would be better to detect such behaviours as they start to form and be able to intervene on the model. More generally, it's not just these sorts of undesirable behaviours we care about, but also emergent capabilities, goals, etc., since having a map of what the model has developed during training makes it much easier to assess the risk the model poses in deployment.
A Continual Learning Agent
For a plausible near-future example, consider an AI system that learns continuously during deployment. Imagine an AI medical triage assistant deployed in an emergency department, continuously fine-tuning on the cases it encounters, and assume for simplicity that we have done some good theoretical work so that it starts out well-calibrated and aligned. But over months, the distribution of cases shifts as flu season arrives, a new respiratory virus circulates, the local demographics change as a new elderly care facility opens nearby, etc. The agent keeps updating during all of this. At first this is beneficial: it becomes better at recognizing the patterns it's seeing more frequently. But gradually, two things happen in parallel.
First, it becomes subtly worse at recognizing rare presentations it hasn't seen in a while (the analog of catastrophic forgetting, but partial and hard to detect). Second, and more perniciously, the sheer volume of correlated cases during flu season warps its internal representations so that it starts over-attributing respiratory symptoms even in cases where they're incidental. Its "worldview" has silently narrowed.
The key dramatic element is that nothing looks wrong from the outside for a long time. The aggregate accuracy metrics remain high because most cases are respiratory during flu season. The degradation is hidden in the rare cardiac case that gets deprioritized, in the unusual neurological presentation that gets misread through a respiratory lens. By the time someone notices a pattern of missed diagnoses, the internal reorganization is already deep. Obviously, one could construct more catastrophic scenarios than this but it illustrates the point. One might say "we can guarantee that this won't happen by designing the agent correctly in the first place!" which might in principle be true but it feels risky to just trust that everything is working the way we had designed it to. No one wants a nuclear power plant without continuous monitoring.
Working On Early Warning Signs
One thing I found rather surprising is that there has been almost no work on early detection of phase transitions in deep learning, despite this being a core component of many safety-critical domains. Climate science has decades of work on detecting tipping points via critical slowing down, ecology has operational frameworks for detecting regime shifts in ecosystems, structural engineering has entire industries built around structural health monitoring, and so on. In some ways it makes sense that this hasn't happened yet in deep learning, as the field is younger and the very concept of "phase transitions during training" as something worth detecting is relatively recent. Trying to detect these phase transitions as early as possible is then a natural next step.
Phase Transitions and Early Warning Signals
The first research objective to try and tackle for developing early warning methods is "How do we detect emergent behaviours/capabilities in training before they actually solidify?" This question seems straightforward, but there are some important subtleties that need to be addressed. Namely, what kinds of phase transitions we actually see in training, and what it means to detect them early.
When is Early Detection Possible?
Types of Phases
In physics, when discussing phase transitions we tend to talk primarily about first order and second order phase transitions (there are other types, but they are far less common). The order of a phase transition is determined by the behaviour of the free energy $F(\lambda) = E(\lambda) - T\,S(\lambda)$, where $E$ is the energy, $S$ is the entropy, $\lambda$ is some free parameter, and $T$ is a temperature parameter.
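A first order phase transition (in $\lambda$) occurs at $\lambda_c$ when $F$ itself is continuous, but its first derivative $\partial F/\partial\lambda$ has a jump discontinuity:
$$F(\lambda_c^-) = F(\lambda_c^+), \qquad \left.\frac{\partial F}{\partial \lambda}\right|_{\lambda_c^-} \neq \left.\frac{\partial F}{\partial \lambda}\right|_{\lambda_c^+}.$$
That is, the free energy has a kink (corner) at $\lambda_c$.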
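A second order phase transition is one where $F$ and $\partial F/\partial\lambda$ are both continuous, but the second derivative $\partial^2 F/\partial\lambda^2$ has a jump discontinuity:
$$\left.\frac{\partial F}{\partial \lambda}\right|_{\lambda_c^-} = \left.\frac{\partial F}{\partial \lambda}\right|_{\lambda_c^+}, \qquad \left.\frac{\partial^2 F}{\partial \lambda^2}\right|_{\lambda_c^-} \neq \left.\frac{\partial^2 F}{\partial \lambda^2}\right|_{\lambda_c^+}.$$
That is, the free energy has no kink at the transition, but its curvature changes discontinuously. This corresponds to a discontinuity in quantities like heat capacity or susceptibility, which are second-derivative observables.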
Early Warning Signals
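Consider simple gradient dynamics where the system state $x$ (an order parameter) evolves to minimize a landscape $F(x;\lambda)$, so that the equilibrium free energy above is $F(\lambda) = \min_x F(x;\lambda)$:
$$\dot{x} = -\frac{\partial F}{\partial x}.$$
Near a stable fixed point $x^*$ (where $\partial F/\partial x|_{x^*} = 0$), linearize with $\delta x = x - x^*$:
$$\dot{\delta x} = -F''(x^*)\,\delta x, \qquad F''(x^*) \equiv \left.\frac{\partial^2 F}{\partial x^2}\right|_{x^*} > 0.$$
The solution is exponential relaxation, $\delta x(t) = \delta x(0)\,e^{-t/\tau}$, with timescale:
$$\tau = \frac{1}{F''(x^*)}.$$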
So the relaxation time is set directly by the curvature of the free energy at the fixed point.
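At a second-order transition, the curvature $F''(x^*)$ at the minimum changes continuously but vanishes at $\lambda_c$. As $\lambda \to \lambda_c$, the minimum of $F$ becomes progressively flatter and the restoring force weakens. Therefore:
$$\tau = \frac{1}{F''(x^*)} \to \infty \quad \text{as } \lambda \to \lambda_c.$$
This divergence is continuous and gradual. The system telegraphs the approaching transition through observable precursors long before $\lambda_c$ is reached: the relaxation time $\tau$ grows, and, for noise of fixed strength, so does the variance of the fluctuations around the minimum, $\mathrm{Var}(\delta x) \propto 1/F''(x^*)$.
These precursors are measurable and diverge as a power law in $|\lambda - \lambda_c|$. This is why second-order transitions admit early warning signals. The free energy landscape is continuously preparing the transition by flattening out, and the dynamics faithfully reflect this.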
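Critical slowing down is easy to see numerically. The sketch below is a toy, not from the paper: it integrates noisy gradient dynamics in the assumed Landau-style potential $F(x) = \tfrac{a}{2}x^2 + \tfrac{b}{4}x^4$, whose minimum at $x^* = 0$ has curvature $a$, and measures the variance and autocorrelation of the fluctuations as $a$ shrinks toward zero. The noise level, step size and lag are arbitrary choices.

```python
import numpy as np


def simulate_langevin(a, b=1.0, noise=0.05, dt=0.01, n_steps=100_000, seed=0):
    """Euler-Maruyama integration of dx = -F'(x) dt + noise dW.

    F(x) = a x^2 / 2 + b x^4 / 4 (an illustrative toy potential). For a > 0 the
    minimum is x* = 0 with curvature F''(0) = a, so sending a -> 0 mimics
    approaching a second-order transition.
    """
    rng = np.random.default_rng(seed)
    kicks = (noise * np.sqrt(dt) * rng.standard_normal(n_steps)).tolist()
    x = 0.0
    path = [x]
    for t in range(1, n_steps):
        x += -(a * x + b * x**3) * dt + kicks[t]
        path.append(x)
    return np.array(path)


def early_warning_stats(path, lag=100):
    """Variance and lag-`lag` autocorrelation after dropping a 10% burn-in."""
    x = path[len(path) // 10:]
    x = x - x.mean()
    var = x.var()
    acf = np.dot(x[:-lag], x[lag:]) / ((len(x) - lag) * var)
    return var, acf


# Same noise realization for each curvature, so only a changes between runs.
results = {a: early_warning_stats(simulate_langevin(a)) for a in [1.0, 0.3, 0.1]}
for a, (var, acf) in results.items():
    print(f"curvature a={a:.1f}  variance={var:.5f}  autocorrelation at lag 1.0={acf:.2f}")
```

In the linearized (Ornstein-Uhlenbeck) limit the stationary variance is $\text{noise}^2/(2a)$ and the autocorrelation at time lag $s$ is $e^{-as}$, so both climb as the curvature $a$ heads to zero.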
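At a first-order transition, $\partial F/\partial\lambda$ jumps discontinuously. Geometrically, the picture is fundamentally different: there are two distinct local minima $x_1$ and $x_2$, and the transition occurs when they exchange global stability (i.e. when $F(x_1;\lambda_c) = F(x_2;\lambda_c)$).
The crucial point is that neither minimum needs to flatten out at the transition. Each basin can retain healthy positive curvature right up to $\lambda_c$, so:
$$\tau_i = \frac{1}{F''(x_i)} < \infty \quad \text{for } i = 1, 2 \text{ at } \lambda = \lambda_c.$$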
The system simply jumps from one basin to the other with no dynamical precursor. The relaxation time does not diverge, variance does not blow up, and the landscape gives no warning.
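For contrast, here is a toy first-order case: a tilted double well $F(x;\lambda) = \tfrac14(x^2-1)^2 - \lambda x$ (an assumed example, not from the paper), where the two minima exchange global stability at $\lambda_c = 0$. The sketch finds the global minimum on either side of $\lambda_c$ together with its curvature.

```python
import numpy as np


def free_energy(x, lam):
    # Tilted double well: F(x; lam) = (x^2 - 1)^2 / 4 - lam * x, with lambda_c = 0
    return (x**2 - 1) ** 2 / 4 - lam * x


def curvature(x):
    # F''(x) = 3 x^2 - 1; the tilt lam does not enter the second derivative
    return 3 * x**2 - 1


def global_minimum(lam):
    """Return (x*, F''(x*)) for the global minimum of F(.; lam)."""
    # Stationary points solve F'(x) = x^3 - x - lam = 0
    roots = np.roots([1.0, 0.0, -1.0, -lam])
    real = roots[np.abs(roots.imag) < 1e-9].real
    minima = real[curvature(real) > 0]  # keep only local minima
    x_star = minima[np.argmin(free_energy(minima, lam))]
    return x_star, curvature(x_star)


minima_by_lambda = {lam: global_minimum(lam) for lam in [-0.05, -0.01, 0.01, 0.05]}
for lam, (x_star, k) in minima_by_lambda.items():
    print(f"lambda={lam:+.2f}  global min at x={x_star:+.3f}  curvature={k:.3f}")
```

The global minimum jumps from near $x=-1$ to near $x=+1$ as $\lambda$ crosses zero, while its curvature stays close to 2 throughout, so nothing in the local relaxation time hints at the switch.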
So, for first order transitions early detection is in general not possible. However, for a certain type of first order transition you can get close: nucleation-type transitions. Nucleation is the process by which a new phase forms locally within a metastable parent phase through the creation of an object called a critical nucleus. The critical nucleus is itself an unfavourable state in the free energy landscape (it sits at the top of the barrier, effectively a maximum along the transition path), but through some mechanism, usually thermal fluctuations, the system can climb the free energy barrier to reach it and then descend to a stable state on the other side. The critical nucleus represents a rare, highly ordered configuration of the system, and the entropy term strongly disfavors it, in part because there are vastly more microscopic states without a coherent nucleus than with one.
While this is very difficult in practice for physical systems, one can in principle detect when the critical nucleus forms. That is technically not an early warning in the critical slowing down sense, but it is conceptually very similar.
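The textbook quantitative version of this picture is classical nucleation theory, included here purely as an illustration. For a spherical nucleus of radius $r$, a surface cost $\gamma$ per unit area competes with a bulk gain $\Delta g$ per unit volume, $\Delta G(r) = 4\pi r^2\gamma - \tfrac43\pi r^3 \Delta g$. The critical nucleus sits at the maximum, $r^* = 2\gamma/\Delta g$, with barrier height $\Delta G^* = 16\pi\gamma^3/(3\Delta g^2)$. The parameter values below are arbitrary.

```python
import math


def nucleation_barrier(r, gamma, delta_g):
    """Free energy of a spherical nucleus of radius r (classical nucleation theory).

    gamma: surface free energy per unit area (a cost).
    delta_g: bulk free energy gained per unit volume by the new phase.
    """
    return 4 * math.pi * r**2 * gamma - (4 / 3) * math.pi * r**3 * delta_g


def critical_nucleus(gamma, delta_g):
    # dG/dr = 8 pi r gamma - 4 pi r^2 delta_g = 0  ->  r* = 2 gamma / delta_g
    r_star = 2 * gamma / delta_g
    return r_star, nucleation_barrier(r_star, gamma, delta_g)


r_star, barrier = critical_nucleus(gamma=1.0, delta_g=0.5)
print(f"critical radius r* = {r_star:.2f}, barrier height = {barrier:.3f}")
```

Nuclei smaller than $r^*$ tend to shrink back and larger ones tend to grow, which is why the formation of a critical nucleus is the natural thing to try to detect.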
Types of Phase Transitions in Deep Learning
Now, there is a subtle point worth mentioning here: deep learning is not a thermodynamic system but a dynamical one. It does not really have thermodynamic phase transitions; its transitions are closer to jumps from one metastable state to another, so the classical thermodynamic free energy technically isn't defined between these states in most cases. However, the dynamical systems literature still tends to talk about first and second order phase transitions, as there are reasonable ways to define a free energy for these systems, so we are just going to keep using the thermodynamic language.
So the first question is how one defines the free energy for deep learning. There is the SLT (singular learning theory) method for deriving a thermodynamic free energy, which is nice and rigorous but overkill for the purposes of this post, so we are going to go for something much more handwavy. Instead, let's take a parameter choice $w$ as having a free energy of $F(w) = L(w) - T\,S(w)$, where $L$ is the loss (playing the role of the energy), $S$ is the configurational entropy, and $T$ is a temperature parameter as before. What do I mean by configurational entropy? For simplicity, we can consider it to be a value determined by the volume of points near $w$ with loss close to $L(w)$. So the configurational entropy basically measures how easy it is to find a parameter that makes the model behave similarly to $w$: a large volume means many such parameters, and fewer bits are needed to specify one.
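One crude way to make this notion of configurational entropy concrete, under illustrative assumptions that are not the paper's definition: perturb $w$ with Gaussian noise and take the log of the fraction of perturbations whose loss stays within $\epsilon$ of $L(w)$. The toy quadratic losses, $\epsilon$ and the noise scale are all hypothetical.

```python
import numpy as np


def config_entropy(loss_fn, w, eps=0.1, sigma=1.0, n_samples=100_000, seed=0):
    """Crude Monte Carlo proxy for the configurational entropy S(w).

    Hypothetical estimator: log of the fraction of perturbations w + delta,
    delta ~ N(0, sigma^2 I), whose loss stays within eps of L(w). A larger
    fraction means a larger volume of near-equivalent parameters.
    `loss_fn` must accept a batch of parameter vectors of shape (n, dim).
    """
    rng = np.random.default_rng(seed)
    perturbed = w + sigma * rng.standard_normal((n_samples, w.size))
    base = loss_fn(w[None, :])[0]
    frac = np.mean(np.abs(loss_fn(perturbed) - base) <= eps)
    return np.log(frac) if frac > 0 else -np.inf


def flat(W):
    return 0.1 * np.sum(W**2, axis=-1)   # broad basin


def sharp(W):
    return 10.0 * np.sum(W**2, axis=-1)  # narrow basin


w = np.zeros(2)
print(f"S(flat)  ~ {config_entropy(flat, w):.2f}")
print(f"S(sharp) ~ {config_entropy(sharp, w):.2f}")
```

The broad basin keeps about 39% of perturbations within $\epsilon$ and the narrow one about 0.5%, so the flat point has much higher configurational entropy. In high dimensions this naive estimator degrades quickly, since most random perturbations leave the level set.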
Deep learning can have both first and second order phase transitions. We go over mechanisms for both below.
The Second Order Transition
Imagine the model sits in a broad, largely flat region of the loss landscape. That is, this area has high configurational entropy (the level sets are broad, and many configurations give similar loss), and the system bumbles around it, driven by fluctuations from SGD noise, weight decay, etc. So the system is exploring along these flat directions. Eventually, some subset of these exploratory movements starts to reinforce itself: small modifications to an existing structure slowly increase the curvature of the region, locking it in more. This is second order because the high-dimensional flat region narrows gradually as the system commits to specific directions.
The First Order Transition
The first order transition is interesting, as it seems to be why learning rate warmup matters, and it is actually mathematically well understood in deep linear networks as the silent alignment effect. It occurs in areas of the loss landscape that look like river valleys: walls swooping upward, with a meandering river sloping downward towards something like a basin. The reason this is first order is that the transition must overcome the entropic barrier of finding the valley, which makes up a small volume of parameter space. Once the valley is found, the model locks onto these directions and amplifies them, moving downstream towards a basin.
Intuitively this makes sense when compared with how we tend to think about feature formation. The optimizer finds some feature that decreases the loss, then amplifies that feature within the network.
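To see the "silent" part in the simplest setting, here is a toy sketch (not the paper's setup): a two-layer linear network $W_2 W_1$ trained by gradient descent from small initialization on the population loss $\tfrac12\lVert W_2W_1 - \beta\rVert^2$, assuming whitened inputs. The dimensions, learning rate, init scale and the 0.99 and one-half thresholds are arbitrary choices. The direction of the end-to-end map locks onto $\beta$ while the loss has barely moved, and only later does the loss drop.

```python
import numpy as np


def train_two_layer_linear(d=5, h=5, lr=0.05, steps=1000, init_scale=1e-3, seed=0):
    """Gradient descent on L = 0.5 * ||W2 @ W1 - beta||^2 (whitened inputs assumed).

    Returns the per-step loss and the cosine between the end-to-end map W2 @ W1
    and the target map beta.
    """
    rng = np.random.default_rng(seed)
    beta = rng.standard_normal((1, d))
    beta /= np.linalg.norm(beta)                   # unit-norm target map
    W1 = init_scale * rng.standard_normal((h, d))
    W2 = init_scale * rng.standard_normal((1, h))
    losses, cosines = [], []
    for _ in range(steps):
        p = W2 @ W1                                 # end-to-end linear map, shape (1, d)
        err = p - beta
        losses.append(0.5 * float(np.sum(err**2)))
        cosines.append((p @ beta.T).item() / (np.linalg.norm(p) * np.linalg.norm(beta)))
        grad_W1 = W2.T @ err                        # dL/dW1, shape (h, d)
        grad_W2 = err @ W1.T                        # dL/dW2, shape (1, h)
        W1 -= lr * grad_W1
        W2 -= lr * grad_W2
    return np.array(losses), np.array(cosines)


losses, cosines = train_two_layer_linear()
align_step = int(np.argmax(cosines > 0.99))            # first step the direction is learned
drop_step = int(np.argmax(losses < 0.5 * losses[0]))   # first step the loss has halved
print(f"aligned at step {align_step}, loss halved at step {drop_step}, "
      f"loss at alignment = {losses[align_step]:.4f} (initial {losses[0]:.4f})")
```

With these settings the cosine crosses 0.99 well before the loss halves, and the loss at the moment of alignment is still essentially its initial value.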
While it is not clear that these are the only mechanisms by which these kinds of phase transitions happen, they provide a good basis for thinking about early detection.
Early Warning Signals for Deep Learning
To motivate how we do this in the paper, we can sort of reason through the thought process together. Imagine watching a crowd of people milling around a plaza from above. Most of the time, their movements are essentially independent. But if something interesting happens (say a street performer starts doing something), a subset of people start moving in the same direction at the same time. That is, there are hints that the system is reorganizing as the movement between people starts to become coordinated.
Now, let's move this reasoning over to the deep learning case. If you ask "what's the cheapest thing I could compute that would tell me a model is learning something structurally new?" you'd reason as follows: a new capability or circuit, by definition, affects multiple data points in a systematic way as that's what makes it a capability rather than memorization of a single example. So the emergence of a new capability should manifest as a new shared factor explaining the loss variation across samples. This then can be thought of geometrically. When samples start to fluctuate together the loss landscape develops (or the optimizer discovers) directions that coherently affect groups of samples. To see that these fluctuations are in some sense the earliest thing we can detect, consider that the transition is the optimizer moving a long way along a coupled direction. But the coupling in the landscape has to exist before the optimizer can exploit it. So (under some very mild assumptions that you essentially don't make your optimizer too cursed) there's a period where the optimizer discovers and starts moving along the coupled direction (producing correlated loss fluctuations), but hasn't yet moved far enough for the capability to actually work. Here you will get correlated loss fluctuations but not necessarily any detectable loss decrease.
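A minimal synthetic illustration of that argument, using hypothetical data with no model involved: per-sample loss traces over a window of steps are built from independent noise plus one shared fluctuation whose strength `coupling` is dialed up, while the mean loss stays fixed. The top eigenvalue of the loss covariance picks up the shared factor even though the mean loss does not move.

```python
import numpy as np


def simulate_window(coupling, n_probe=32, window=64, seed=0):
    """Hypothetical per-sample loss traces over a window of training steps.

    Each probe sample's loss = 1.0 (fixed mean) + independent noise
    + coupling * (one shared fluctuation scaled by a per-sample loading).
    The shared term stands in for the optimizer moving along a direction
    that affects many samples at once. Returns shape (window, n_probe).
    """
    rng = np.random.default_rng(seed)
    loadings = rng.uniform(0.5, 1.5, size=n_probe)  # how strongly each sample couples
    shared = rng.standard_normal(window)             # one collective fluctuation per step
    independent = rng.standard_normal((window, n_probe))
    return 1.0 + 0.1 * independent + 0.1 * coupling * np.outer(shared, loadings)


def top_mode_share(losses):
    """Fraction of total loss variance carried by the largest covariance eigenvalue."""
    eig = np.linalg.eigvalsh(np.cov(losses, rowvar=False))  # ascending order
    return eig[-1] / eig.sum()


shares = {}
for c in [0.0, 0.5, 1.0, 2.0]:
    L = simulate_window(c)
    shares[c] = top_mode_share(L)
    print(f"coupling={c:.1f}  mean loss={L.mean():.3f}  top-mode share={shares[c]:.2f}")
```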
The next important thing to realize is that the loss landscape carries the "memories" of what has been learned. That is, suppose at some point your model learns some independent skill, like basic addition. The model parameters associated with basic addition are largely locked in, so the optimizer never really explores that direction in gradient space. Since, by definition, no learning is taking place there, there is no reason to track that area anymore. If it then needs to be tracked again because learning is happening there, well then by definition the optimizer knows about it. This might seem rather trivial, but this tells us that for the purpose of tracking phase transitions the optimizer's trajectory is a natural importance sampler over the relevant directions in parameter space. This isn't to say that there isn't value in looking at the local shape of the loss landscape more broadly, but simply that for detecting emergent capabilities the optimization directions contain a sufficient amount of information.
This is similar to the idea of reaction coordinates in chemistry. Most degrees of freedom in a reacting system are spectators: solvent molecules jiggling around, bond vibrations that aren't involved in the reaction. If you want to study the progress of a reaction, you track the few coordinates along which the reaction actually proceeds (for example, the length of the bond being broken) and ignore the rest.
The 2-Datapoint Reduced Density Matrix
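I feel like saying we "introduce" this object is kind of weird: it's actually a very natural object if you have studied physics, and it is really a generalization of the loss kernel. First, we pick a set of datapoints $\{x_1, \dots, x_n\}$ called the probe set, and denote the loss on sample $x_i$ of a model parametrized by $w$ as $\ell_i(w)$. Let $q$ be some distribution over a local area of parameter space. The 2RDM (under $q$) is then defined as the matrix $\rho$ with entries given by:
$$\rho_{ij} = \mathbb{E}_{w\sim q}\big[\ell_i(w)\,\ell_j(w)\big] - \mathbb{E}_{w\sim q}\big[\ell_i(w)\big]\,\mathbb{E}_{w\sim q}\big[\ell_j(w)\big] = \mathrm{Cov}_{w\sim q}\big(\ell_i(w), \ell_j(w)\big).$$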
Now, in the paper we primarily work with 2 specific ways of defining $q$. The first is the Gaussian 2RDM, where $q$ is a narrow Gaussian centered on some parameter $w_0$. This version of the 2RDM is useful if one needs to probe very early in training (within the first few steps).
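Here is a minimal sketch of the Gaussian 2RDM for a toy linear regression model, estimated by Monte Carlo: draw parameters from a narrow isotropic Gaussian around $w_0$ (an assumed choice of covariance), evaluate every probe loss, and take the sample covariance. The model, probe set and $\sigma$ are hypothetical. As a sanity check it compares against the first-order prediction: for small $\sigma$, $\ell_i(w_0+\delta) \approx \ell_i(w_0) + g_i\cdot\delta$, so $\rho \approx \sigma^2 G G^\top$ with $G$ the matrix of per-sample gradients at $w_0$.

```python
import numpy as np


def per_sample_losses(W, X, y):
    """Squared-error losses l_i(w) of a linear model for a batch of parameter draws.

    W: (n_draws, d), X: (n_probe, d), y: (n_probe,) -> returns (n_draws, n_probe).
    """
    return 0.5 * (W @ X.T - y) ** 2


def gaussian_2rdm(w0, X, y, sigma=1e-2, n_draws=4000, seed=0):
    """Monte Carlo estimate of rho_ij = Cov_{w ~ N(w0, sigma^2 I)}(l_i(w), l_j(w))."""
    rng = np.random.default_rng(seed)
    W = w0 + sigma * rng.standard_normal((n_draws, w0.size))
    return np.cov(per_sample_losses(W, X, y), rowvar=False)


rng = np.random.default_rng(1)
X = rng.standard_normal((8, 3))            # probe set: 8 samples, 3 features
y = X @ rng.standard_normal(3)
w0 = np.zeros(3)
sigma = 1e-2
rho = gaussian_2rdm(w0, X, y, sigma=sigma)

# First-order check: per-sample gradients of 0.5 * (w . x_i - y_i)^2 at w0
G = (X @ w0 - y)[:, None] * X              # shape (8, 3)
predicted = sigma**2 * G @ G.T
rel_err = np.linalg.norm(rho - predicted) / np.linalg.norm(predicted)
print(rho.shape, f"relative deviation from first-order prediction: {rel_err:.3f}")
```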
The other one, which we introduce and mostly focus on, is the dynamical 2RDM. This version takes advantage of the natural importance sampling of the optimizer and is dirt cheap to compute. Effectively, one takes $q$ to be a uniform distribution over the parameters visited in a window of SGD steps. In practice, at set intervals we compute the loss on each sample in the probe set and keep a sliding window of these per-sample losses, from which we compute the covariance matrix. This does not require any expensive sampling of model parameters. It avoids massive computational overhead (each measurement costs one forward pass over the $n$ probe samples, plus something like $O(n^2 W)$ to compute the covariance, with $n$ the number of probe samples and $W$ the size of the window) and largely avoids the curse of dimensionality, as it naturally samples the geometry relevant to learning.
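A minimal sketch of the bookkeeping, assuming per-sample probe losses are already available as NumPy arrays. The class name `DynamicalTwoRDM` and its methods are hypothetical, and the toy training loop (SGD on a linear regression whose training samples double as the probe set) is only there to feed it.

```python
from collections import deque

import numpy as np


class DynamicalTwoRDM:
    """Sliding-window covariance of probe-set losses: a sketch of the dynamical 2RDM.

    Hypothetical helper. Call `record` every few optimizer steps with the current
    per-sample probe losses, shape (n_probe,). `matrix` returns the
    (n_probe, n_probe) covariance over the last `window` recordings, i.e. q is
    taken as uniform over the parameters visited in that window.
    """

    def __init__(self, window):
        self.buffer = deque(maxlen=window)

    def record(self, per_sample_losses):
        self.buffer.append(np.asarray(per_sample_losses, dtype=float))

    def ready(self):
        return len(self.buffer) == self.buffer.maxlen

    def matrix(self):
        L = np.stack(list(self.buffer))              # (window, n_probe)
        L = L - L.mean(axis=0, keepdims=True)        # center each sample's loss trace
        return L.T @ L / (len(self.buffer) - 1)      # O(n_probe^2 * window) work


# Toy usage: SGD on linear regression, recording probe losses after every step.
rng = np.random.default_rng(0)
X_probe = rng.standard_normal((16, 4))
y_probe = X_probe @ rng.standard_normal(4)
w = np.zeros(4)
rdm = DynamicalTwoRDM(window=32)
for step in range(200):
    i = rng.integers(len(X_probe))                   # toy: train on the probe samples
    w -= 0.05 * (X_probe[i] @ w - y_probe[i]) * X_probe[i]
    rdm.record(0.5 * (X_probe @ w - y_probe) ** 2)   # one forward pass over the probe set
rho = rdm.matrix()
print(rho.shape)
```

A deque with `maxlen` keeps the window bounded in memory and drops the oldest recording automatically, which is all the sliding window needs.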
Properties of the 2RDM
In the paper, we present some properties of the 2RDM and then derive a bunch of formal results proving they do the thing we care about. They rely on things like the breaking of linear approximations, majorization, and a bunch of other stuff, but the intuition of what is going on is much more straightforward than these results make it seem.
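The first thing we introduce is a quantity called the spectral heat capacity (SHC), which is the variance of the eigenvalues $\mu_1, \dots, \mu_n$ of the covariance matrix $\rho$. So
$$\mathrm{SHC}(\rho) = \frac{1}{n}\sum_{k=1}^{n}\left(\mu_k - \bar\mu\right)^2, \qquad \bar\mu = \frac{1}{n}\sum_{k=1}^{n}\mu_k.$$
We also make use of a quantity called the participation ratio
$$\mathrm{PR}(\rho) = \frac{\left(\sum_k \mu_k\right)^2}{\sum_k \mu_k^2} = \frac{(\operatorname{tr}\rho)^2}{\operatorname{tr}(\rho^2)}.$$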
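A minimal sketch of both quantities, assuming the SHC uses the $1/n$ (population) variance of the eigenvalues and the PR takes the standard form $(\sum_k \mu_k)^2/\sum_k \mu_k^2$. The two test matrices are synthetic: one where every probe loss fluctuates independently with equal variance, and one with a single collective mode on top.

```python
import numpy as np


def spectral_heat_capacity(rho):
    """Variance (1/n normalization) of the eigenvalues of a symmetric 2RDM."""
    mu = np.linalg.eigvalsh(rho)
    return float(np.var(mu))


def participation_ratio(rho):
    """(sum mu)^2 / sum mu^2 = tr(rho)^2 / tr(rho^2); between 1 and n for a nonzero PSD rho."""
    mu = np.linalg.eigvalsh(rho)
    return float(mu.sum() ** 2 / np.sum(mu**2))


n = 20
v = np.ones(n) / np.sqrt(n)
isotropic = np.eye(n)                          # every sample fluctuating independently
one_mode = np.eye(n) + 10.0 * np.outer(v, v)   # plus one collective mode shared by all
for name, rho in [("isotropic", isotropic), ("one dominant mode", one_mode)]:
    print(f"{name:>18}: SHC={spectral_heat_capacity(rho):.3f}  PR={participation_ratio(rho):.2f}")
```

For the isotropic case the SHC is 0 and the PR is 20 (all modes equal). Adding the collective mode gives eigenvalues 11 (once) and 1 (nineteen times), so the SHC jumps to 4.75 and the PR falls to 900/140, about 6.43.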
The main claim of the paper can be boiled down to saying that the spectral heat capacity acts as an early warning signal for phase transitions and that the participation ratio provides information about the dynamics of the transition. The paper spends a lot of time formally proving the claim about the SHC (as the participation ratio is almost definitional). However, at least for the river valley first order transition, we can give rather intuitive geometric reasoning why this works, which we give below.
The eigenvectors of effectively tell us what samples are moving together. The eigenvalue of an eigenvector tells you the magnitude of their collective fluctuation. If all the eigenvalues of that covariance are roughly equal, losses are fluctuating in many independent directions so nothing coordinated is happening. But when one eigenvalue starts dwarfing the rest, it means many examples' losses are suddenly dancing in lockstep along a single direction.
In the case of the river valley, the transition happens as the network rotates to move onto a new direction, which is effectively a sudden first order phase transition. The losses co-fluctuate here and are picked up by the SHC before they amplify onto the newly aligned direction. In the paper, we observe this directly in deep linear networks.
For second order transitions, what we can say is that the landscape curvature is softening, so the model is losing its "grip" along certain directions, and the SHC detects that weakening grip through the growing variance it produces. An exact phenomenological picture is hard to paint intuitively here, though.
The participation ratio does shed some light on this in practice, as the PR can increase either before or after the SHC. The SHC is a rigorous precursor signal for a transition; the participation ratio isn't, but it does tell you how many modes are fluctuating together, which helps determine how localized the transition is. A small PR means the transition is localized along one mode, while a high PR means many modes are involved at once. Furthermore, the temporal order of the SHC and PR changes tells us about the transition. If the SHC spikes before the PR rises, the landscape is funneling into a single narrow bottleneck where one direction dominates first; after passing through, additional directions activate. If the PR increases first, multiple directions are softening simultaneously: the model is sitting on a broad plateau where many things are in play at once. Eventually the symmetry between those directions breaks, one pulls ahead, and the SHC fires.
Experimental Results, Probe Selection, and Scaling
For a full report on the experimental results, I suggest checking the paper, as there is a lot to cover and I don't want to make this post any longer than it has to be. In short, we found experimentally that the 2RDM behaved basically exactly as predicted and worked reliably as an early detection method in all the settings we tested. These consisted almost entirely of known phase transitions during training that we could replicate and detect directly, and then check where the SHC spiked relative to the known location of the transition. In basically all of these settings it worked with zero tuning, straight out of the box. This worked for:
This is good, as it means the method does what we expect. However, I try to make it clear in the paper that there is a subtle issue: the way you select your probe set matters, and this isn't unique to the 2RDM but applies to any interpretable phase transition detection method. There are formal ways to show this, but the point is far simpler than that. One simply needs to ask: "how can I tell the model's behaviour changed if I cannot observe it?"
In the examples most common in the study of phase transitions in deep learning, this isn't a problem, because the data is seemingly simple enough that a relatively small random sample captures most of the variance. This is not the case as we scale up, however, so one needs to select intelligently what to observe and which samples provide enough signal to actually capture the change we care about. This sort of basis problem is common in various areas of science. In the paper we discuss it in terms of quantum chemistry, where one must select a finite set of basis functions to study what is effectively an infinite dimensional Hilbert space. During bond breaking events, using the wrong basis set can cause you to totally miss the fact that anything happened at all. Quantum chemists have methods for dealing with this sort of problem, so hopefully we can do something similar.
Conclusion
Part of the reason for sharing this is that the dynamical 2RDM is very easy to run experiments on, and I figured those who are interested could probably do some interesting things with it, find shortcomings, etc. Maybe it's actually completely useless, but I thought it was pretty neat. Anyway, questions and comments welcome.
Thanks! Max