When I consider the trajectory of my life, it has been determined in large part by my love of science of all kinds and my rather intense ADHD. When I was younger I had the thought that "most fields of science I care about are largely mathematical, so I should study mathematics!", which led me to doing my undergrad in math, where I fell in love with AI. The thing that particularly excited me was representation learning, and my opinion at the time was that to get good representations, one needed to design the system with strong priors. So, I did my M.Sc. in mathematics as well, studying group equivariant convolutional neural networks. Then, through a series of events, I ended up doing my PhD in an interdisciplinary studies program where I am part of a quantum chemistry/computing lab.
Why is this relevant at all for a post about feature learning? While I am definitely not a chemist, being around people whose entire research careers are dedicated to predicting how complex systems interact to create interesting stuff makes you realize that many of the problems we consider in theoretical deep learning have been studied for years in some area of the physical sciences. Many of the weird, particular questions we think are specific to deep learning show up in some random journal from the 90s, except the paper is about polymerization or something instead of deep learning. One of the effects this has had on me is that certain things in deep learning that I once thought were very mysterious can seemingly be explained rather simply. Now, translating things rigorously is still non-trivial in many cases, which isn't surprising, as much of mathematics is simply writing the same thing slightly differently, but physical systems can provide a good starting point. In this post, we are going to look at a somewhat non-rigorous example of this: we will argue that the way structures form in deep learning is a rather straightforward consequence of how deep learning works, by showing that it shares the defining characteristics of a class of physical systems that exhibit the same behaviour.
Author Note: I doubt I am the first person to have these ideas, as I think they're somewhat of a natural perspective. Also, there's a fair bit of handwaving and oversimplification here to keep the post digestible. I suggest for those interested to check out some of the references at the bottom!
One of the difficulties in discussing structure in neural networks is defining exactly what we mean by structure. We tend to frame things in terms of features and circuits, which are these nebulous concepts which don't yet seem to have well agreed upon definitions, and in a sense are sort of "you know it when you see it" concepts. However, there are working definitions which we will adopt for the sake of this post.
One way to frame this is that features are what is represented, and circuits are how representations are used. However, throughout this post we play a bit fast and loose with the terms. During training, models are optimized to implement computations (circuits); “features” seemingly emerge as reusable subcomponents of those computations, rather than being primary training targets themselves. So, when discussing structure formation we are really talking about circuits but sometimes it's conceptually easier to consider individual features, as the broad ideas still hold.
Unsurprisingly, there has been a fair bit of research surrounding how circuits form in neural networks during training. It has been observed that circuits form according to a sort of 2+1 phase system (that is, two phases seem to be universal, with a third phase overwhelmingly common but not universal): seed -> amplification -> cleanup (sometimes worded as memorization -> circuit formation -> cleanup in the grokking literature). Before describing these phases, we note that this framing originates from some early work on grokking; however, the core ideas are older, dating back to work on the information bottleneck in 2017, which showed that learning seemingly happens in an information acquisition phase followed by an information compression phase. There has been debate around the compression phase, but most modern research seems to indicate that, while not universal, it is a common component of the learning process. Now, with that out of the way, we can give a description of the various phases.
Near initialization, there is something like a symmetry breaking event. That is, when the network is initialized (or warmed up near initialization), one can have many little "proto-features" that carry some small amount of structured signal from the input. Early proto-features can have weak links between one another that form proto-circuits. This formation phase happens rather quickly relative to the total training time, meaning the core features/circuits are seeded early in training. Furthermore, there tend to be many roughly equivalent proto-features/circuits at this stage.
Put a bit differently, at this stage the gradients are noisy and weakly correlated, with small, accidental alignments appearing between weights and recurring patterns in the data. These correspond to proto-features: weak detectors that respond slightly more to some structure than others. These have weak links that form proto-circuits. This stage is fragile, easily reversible, and highly dependent on the data and the initialization.
Proto-features, once formed, have a rich-get-richer dynamic. Namely, if two proto-features both detect a similar thing, but one is a stronger detector than the other and reduces the loss more, backprop will reinforce the stronger feature, making it even stronger and more consistent and forcing any surrounding circuitry to use that version over the weaker one. In this stage, frequently useful features grow fast, weak ones get starved, and the gradient flow adjusts and structures itself around these emerging representations. Interestingly, if one looks at the weight displacement early in training, one tends to see super-diffusive behaviour where the weights move very rapidly, which is likely indicative of the strong gradient signals present during this regime. Eventually this phase relaxes as the circuits reach an effectively "formed" state.
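As a toy illustration of this rich-get-richer dynamic (a sketch I am adding for intuition, not an experiment from the literature), consider two redundant two-layer linear "pathways" trained to jointly fit the same scalar target. The gradient on each weight is proportional to the other weight in its pathway, so whichever pathway starts slightly larger grows faster and absorbs the target, while the other one's growth is starved:

```python
import numpy as np

# Two redundant two-layer linear "pathways" competing to fit the same scalar target.
# Pathway i computes a_i * b_i; the network output is their sum.
# Because the gradient on each weight is proportional to the other weight in its
# pathway, the pathway that starts slightly larger grows faster (rich-get-richer).

rng = np.random.default_rng(0)
target = 1.0                      # we want a1*b1 + a2*b2 -> 1
a = rng.normal(0.0, 0.1, size=2)  # first-layer weights of the two pathways
b = rng.normal(0.0, 0.1, size=2)  # second-layer weights of the two pathways
lr = 0.05

for step in range(2000):
    out = np.sum(a * b)
    residual = target - out
    grad_a = -residual * b        # dL/da_i for L = 0.5 * (target - out)^2
    grad_b = -residual * a
    a, b = a - lr * grad_a, b - lr * grad_b
    if step % 500 == 0:
        print(f"step {step:4d}  pathway strengths (a_i * b_i): {np.round(a * b, 3)}")

print("final pathway strengths:", np.round(a * b, 3))
# Typically one pathway ends up carrying almost all of the target while the other
# stays near its tiny initial value: its growth was starved once the residual vanished.
```

The growth rate of each pathway is proportional to its own current size, which is the mechanism behind the winner-take-all behaviour described above.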
The cleanup phase is common but not universal. After the model undergoes amplification of features/circuits, there tend to be many spurious or redundant features remaining whose growth was starved out in the amplification phase. After the primary features have formed, regularization effects tend to drive these spurious features away. This is typically driven by things like weight decay (or other regularizers), which end up overpowering what little predictive power these remaining proto-features might have, pushing them towards inactivity or integrating them into similar pre-existing circuits. One might then consider later stages where different circuits combine to form more complex circuits. This tends to have a similar 2+1 phase structure where the originally formed circuits are the seeds, which combine and amplify, then undergo cleanup in a similar way.
Now, the main claim of this post: this type of structure formation is exactly what one should expect. That is, the way features form aligns exactly with the behaviour of a particular class of physical systems, meaning that deep learning systems seemingly belong to something like a universality class (I hesitate to call it a true universality class, since that term has some particular implications) which gives accurate macroscopic descriptions of a host of systems. In the coming sections, we are going to examine some of these physical systems and their shared properties, and then provide evidence that the training of deep learning models falls within this same class.
Okay, it might seem weird to bring up precipitation reactions in a post about training dynamics, but trust me, it's relevant. For those unfamiliar (or those who have forgotten since high school chemistry), a precipitation reaction is a chemical reaction in which two substances dissolved in a liquid (usually water) react to form a solid product that is insoluble in that liquid. This solid is called a precipitate. At the molecular level, ions are freely moving in solution at first, but when the right combination meets, they lock together into a solid phase and separate from the liquid.
Precipitation reactions happen in 3 stages: nucleation -> growth -> rearrangement. Let's consider the example of mixing silver nitrate (AgNO3) and sodium chloride (NaCl) solutions. Individually, both dissolve completely in water. But when mixed, silver ions (Ag+) and chloride ions (Cl-) attract each other strongly and form silver chloride (AgCl). First, nucleation occurs when a few Ag+ and Cl- ions come together to form tiny, stable clusters of solid AgCl that are large enough to survive the thermal fluctuations of the system. Next, growth happens as more ions from the solution attach to these nuclei, making the particles larger, becoming visible as a cloudy white precipitate. Finally, rearrangement occurs as the solid particles slowly reorganize internally. Here, ions shift into lower-energy positions, crystals become more ordered, and small nuclei may dissolve and reattach to larger ones.
This section is somewhat optional, but I think it provides some helpful insight as it turns out one can understand these sorts of reactions from an information theoretic perspective.
Nucleation involves crossing a free-energy barrier, which corresponds to suppressing fluctuations that don’t lead to stable structure and amplifying fluctuations that do. Small ion clusters appear constantly, producing noise, but only those that reach a critical size survive and produce signal. This is mathematically analogous to rare-event selection or rate-distortion tradeoffs. Here, the system is in some sense "deciding" which fluctuations to keep based on whether they exceed some signal threshold. This behaviour is more pronounced in complex reactions where many different arrangements of unstable molecules form, but only a couple can actually propagate a signal strongly enough to sustain the full reaction.
Once a nucleus exists, growth is effectively information copying. The crystal lattice defines a template; incoming ions lose degrees of freedom and conform to that template, so the mutual information between ions increases and long-range order propagates. This looks like an information bottleneck, with the system trying to maximize the mutual information between the reactants.
Rearrangement processes (Ostwald ripening, defect annealing, recrystallization) correspond to removing redundant or unstable representations, eliminating defects (inconsistent local encodings), merging small structures into fewer, more stable ones. Information theoretically, representing a solid as millions of tiny, separate crystals is inefficient. It requires a vast amount of information to describe all those separate surface interfaces. The system "compresses" the data by merging them into fewer, larger crystals. This minimizes the surface area (interface information), effectively finding the most concise description of the solid phase.
Many systems initially crystallize into a metastable form (unstable polymorph) and later rearrange into a stable form. The metastable form is like "quick-and-dirty" code. It is easy to write (kinetically accessible) but inefficient (higher energy). The system finds this solution first because it requires less complex information processing to assemble. The transition to the stable polymorph is "refactoring." The system rearranges the atoms into a more complex, denser packing that is more robust (lower energy). This version takes longer to write (high activation energy) but results in cleaner code.
This framework explains why "rearrangement" is common: the first solution the system finds (nucleation) is rarely the optimal one. The system must iteratively "process" the matter (grow, dissolve, regrow, rearrange) to compute the global optimum (the perfect crystal). There are systems where the dynamics effectively stop at nucleation -> growth, with no meaningful rearrangement phase afterward. This happens when the structure formed during growth is already kinetically locked or thermodynamically final.
Precipitation reactions are not the only sort of system with this nucleation -> growth -> rearrangement structure. In fact, it is very common. Some examples are:
Biological systems actually use this rather extensively. Some particular examples there are:
There is then a simple question: what properties do these sorts of systems share that make them behave like this? It turns out, there is an answer to this.
This unifying structure comes from the interaction of potential energy landscapes, constraints, and timescale separation rather than from chemistry-specific details.
In general, we consider these systems as having some state variable $x$ and some potential energy $U(x)$, and we assume they are in some environment with some thermal noise $\xi(t)$. The system then evolves in its state space according to the equation $\dot{x} = -M \nabla U(x) + \xi(t)$, with $M$ the mobility operator. It turns out that the potential energies of all the systems discussed have the following properties:
One might refer to these dynamics as non-convex thermally activated dynamics, and systems which evolve according to these dynamics almost always display the three phase trend as discussed.
Why is this? Well, it's related to the free energy $F$. For simplicity, in order to avoid an overly long discussion, we are simply going to write the free energy as $F = U - TS$, where $S$ is the entropy. Broadly speaking, we can just consider $S(x)$ as the amount of nearby configurations of $x$ that have roughly the same energy.
In large deviation theory, one can show that in a gradient driven system the probability of transitioning from some state $a$ to some other state $b$ is $P(a \to b) \sim \exp(-\Delta F / T)$, where $\Delta F$ is the difference in the free energy between the point $a$ and the saddle between $a$ and $b$ (which is known as the Arrhenius–Eyring–Kramers law), and $T$ is proportional to the thermodynamic noise. Here, nucleation corresponds to saddle points of $F$ which are accessed by rare, localized fluctuations that create the necessary critical nuclei.
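To make this concrete, here is a minimal sketch (my own toy, with an arbitrary double-well potential rather than any particular reaction) of exactly this kind of thermally activated dynamics: overdamped Langevin motion starting in one basin, where escaping requires a rare, noise-driven crossing of the saddle and the mean escape time grows roughly like $\exp(\Delta F / T)$:

```python
import numpy as np

# Overdamped Langevin dynamics  x_dot = -U'(x) + sqrt(2*T)*noise  in the double-well
# potential U(x) = (x^2 - 1)^2.  Starting at the bottom of the left well (x = -1),
# we measure how long the noise takes to push the system over the saddle at x = 0.
# The mean escape time grows roughly like exp(dU / T) (Arrhenius / Kramers scaling).

rng = np.random.default_rng(0)

def grad_U(x):
    return 4.0 * x * (x**2 - 1.0)      # derivative of (x^2 - 1)^2

def escape_time(T, dt=1e-3, max_steps=2_000_000):
    x = -1.0
    for step in range(max_steps):
        x += -grad_U(x) * dt + np.sqrt(2.0 * T * dt) * rng.normal()
        if x > 0.0:                     # crossed the barrier (saddle) at x = 0
            return step * dt
    return np.inf

for T in (0.5, 0.35, 0.25):
    times = [escape_time(T) for _ in range(10)]
    print(f"T = {T:.2f}: mean escape time ~ {np.mean(times):.1f}")
```

Lowering the noise strength makes barrier crossings exponentially rarer, which is the same exponential sensitivity that makes nucleation the slow, rate-limiting step.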
In some sense, systems only explicitly care about minimizing the potential energy. The reason the system maximizes the entropy is not because it has some internal drive to do so; it's simply that highly entropic states are more stable, so the system will, over time, decay into the more entropic states and stay there much longer. Another important point is that nucleation can be slow, so in most real engineering applications things are done to speed up this process, either via favourable initial conditions or some other catalyst.
Growth after nucleation occurs relatively fast. This is because once the saddle point of $F$ is crossed, you almost certainly have a very strong signal from the gradient of the potential driving you down the other side towards a stable structure. This drive is much stronger than the thermal fluctuations, so the system tends to ride the gradient down to the metastable basin and is unlikely to be pushed off the path by thermal fluctuations. That is, during this phase, bifurcations are rare.
After growth, the system is almost certainly not in a global minimum and contains defects, strain, domain walls, etc. Mathematically, the system enters a flat, rough region of the energy landscape where gradients are small, so motion requires collective rearrangements of system components. In this regime the dynamics become activated again, driven by small thermal fluctuations but with many small barriers (either energetic or entropic), and exhibit logarithmic or power-law relaxation (subdiffusion) in the state space (it has been observed in multiple works that deep learning exhibits subdiffusive behaviour late in training). This portion of the process is driven by noise and is effectively slower because the values of $\nabla F$ are simply much smaller, meaning productive transitions are generally rarer.
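If you want to check for this kind of behaviour in a training run, a minimal sketch (assuming you have saved flattened parameter snapshots at regular intervals; the function and inputs here are hypothetical, not from any of the cited works) is to fit the exponent of the mean-squared displacement as a function of checkpoint lag:

```python
import numpy as np

# Sketch: given a list of flattened parameter snapshots taken at regular intervals
# during training, estimate the diffusion exponent alpha from  MSD(lag) ~ lag^alpha.
# alpha > 1 suggests super-diffusive motion (early training / growth),
# alpha < 1 suggests sub-diffusive motion (late training / rearrangement).

def diffusion_exponent(snapshots):
    snapshots = np.asarray(snapshots)          # shape (num_checkpoints, num_params)
    lags = np.arange(1, len(snapshots) // 2)
    msd = np.array([
        np.mean(np.sum((snapshots[lag:] - snapshots[:-lag]) ** 2, axis=1))
        for lag in lags
    ])
    # Fit log(MSD) = alpha * log(lag) + const.
    alpha, _ = np.polyfit(np.log(lags), np.log(msd), 1)
    return alpha

# Sanity check on synthetic data: an ordinary random walk should give alpha ~ 1.
rng = np.random.default_rng(0)
random_walk = np.cumsum(rng.normal(size=(200, 50)), axis=0)
print("estimated alpha for a plain random walk:", round(diffusion_exponent(random_walk), 2))
```

Applied to early versus late checkpoints of a real run, one would expect the fitted exponent to drop from above one to below one, matching the super-diffusive-then-sub-diffusive picture described above.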
Combining the above, there is a somewhat unsurprising consequence: Any system evolving under noisy gradient descent on a nonconvex functional with local interactions will generically show nucleation, growth, and rearrangement. That is, the three-stage structure appears when local order is energetically favorable but globally difficult, and noise is just strong enough to escape metastable states in the free energy landscape but too weak to equilibrate quickly.
Interestingly, this three phase structure can be seen as a consequence of degeneracy in the potential energy landscape. Consider some critical nucleus formed during a process. This nucleus is not unique; there are many potential nuclei which could have formed. A nucleus picks one of these degenerate minima, and the choice is local and random. This choice then expands outward quickly because all minima are equally good locally (that is, attaching to the nucleus is roughly equally energetically favourable anywhere). However, more than one nucleus may form, so different regions choose different minima. Interfaces cost energy when they don't "match up".
However, if there were no degeneracy, there would be only one minimum, so the growth stage would immediately produce the final structure and there would be no long rearrangement stage. Generally, complex, persistent, heterogeneous structure cannot form without degeneracy. Why?
First, consider a system where there is exactly one free energy minimum, like a simple precipitation reaction where the system forms a single crystal orientation. During the minimization, structure can form, but it's boring and short-lived.
Adding degeneracy introduces choice without preference. When a system orders locally it must choose one of many equivalent states, but different regions can choose independently, and mismatches are unavoidable. Those mismatches are the structure. Or, put differently, degeneracy is the mathematical reason structure can form locally faster than it can be made globally consistent and decay away. This is the relationship between entropy and degeneracy. If every component of a system is totally coupled with every other component, the state of a single component determines all others perfectly, so there is really only one universal structure, and in some sense this structure is trivial. Going the other way, if no components correlate, there is no structure at all, which is again uninteresting. Interesting structure lives between these two extremes.
This is why, in pattern-forming systems, the observable structure is not the ordered state itself, but the places where locally chosen orders fail to agree globally. This is almost a tautology: if it were not true, you would not see a pattern at all; you would see either uniform structure or uniform randomness. When looking at such a system you see scars, clusters, stripes, etc. These are exactly the interfaces where incompatible local choices meet. One can see this, for instance, in how cosmological structures form. The space between them is relatively uninteresting, but the galaxies that scar the cosmological landscape are hubs of structure.
As you no doubt have noticed, none of this is at all about deep learning. So now, let's shift gears back that way for the (probably very obvious) takeaway.
Here we will discuss how SGD fits into this picture, and then discuss how experimental results from SLT agree effectively exactly with this.
Modulo some details, there is a generally accepted framing of SGD as being described by a stochastic differential equation like $\dot{\theta}_t = -\eta \nabla L(\theta_t) + \xi_t$, with $L$ the population loss on the entire data distribution, where $\eta$ is the learning rate (or some adaptive preconditioner) and $\xi_t$ is a noise term which in general behaves like an anisotropic Gaussian.
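For concreteness, here is a rough sketch of the objects in that equation on a toy linear-regression problem (the model, numbers, and single "SDE-style" update are placeholders of my own, not anyone's method): the drift is the full-data gradient, and the noise is an anisotropic Gaussian whose covariance is the per-example gradient covariance divided by the batch size:

```python
import numpy as np

# Sketch of the SDE view of SGD: a minibatch gradient is (approximately) the
# full-data gradient plus anisotropic Gaussian noise with covariance
# Cov(per-example gradients) / B, so one update can be written as
#   theta <- theta - lr * (grad L(theta) + noise).

rng = np.random.default_rng(0)
n, d, B = 5000, 4, 32
X = rng.normal(size=(n, d)) * np.array([3.0, 1.0, 0.3, 0.1])   # anisotropic inputs
y = X @ np.array([1.0, -2.0, 0.5, 0.0]) + rng.normal(size=n)
theta = np.zeros(d)

per_example_grads = (X @ theta - y)[:, None] * X   # gradients of 0.5*(x.theta - y)^2
full_grad = per_example_grads.mean(axis=0)          # drift term: population-ish gradient
noise_cov = np.cov(per_example_grads.T) / B         # covariance of a size-B minibatch gradient

lr = 1e-3
kick = rng.multivariate_normal(np.zeros(d), noise_cov)   # anisotropic Gaussian noise
theta = theta - lr * (full_grad + kick)                   # one discretized SDE step

print("drift term    :", np.round(full_grad, 3))
print("noise std/dim :", np.round(np.sqrt(np.diag(noise_cov)), 3))
```

Note the 1/B scaling of the noise covariance: shrinking the batch (or raising the learning rate) makes the stochastic kicks larger relative to the drift, which is exactly the "temperature" knob discussed below.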
For almost all data distributions we care about, the loss landscape is highly non-convex. It is well known that the loss landscape is dominated by degenerate saddle points, meaning that a poor initialization can force the model to overcome what are effectively high entropic barriers.
Furthermore, the existence of deep minima is effectively the cause of the effects observed in works on critical periods of learning. Critical periods are common in non-convex thermally activated systems (with deep minima). This is because early on, the system selects what is effectively a "parent phase" which must be stable against thermodynamic fluctuations long enough for the system to exploit the formed structure (the nucleus). By definition, the system will essentially descend deeper and deeper into this parent phase as new structures form around the initial stable structure. As more stable, partially interdependent structures are added, the size of the thermal fluctuation needed to exit the parent phase becomes so large that it would almost certainly never be observed in practice. This creates an effective critical period in these systems during which a thermal fluctuation can actually undo a structure and move the system into a different parent phase, after which point it becomes very energetically costly to exit the phase.
So, we can see that SGD in general behaves according to non-convex thermally activated dynamics. Now, the claim here is that the seed -> amplification -> cleanup is exactly the same as nucleation -> growth -> rearrangement. This might be obvious to some, but we go through some evidence that this really is the case below.
When training large scale modern deep learning models, one almost never just sets the learning rate and lets it run. Almost all models require an initial warmup phase where the learning rate is set very low and gradually increased to some value. Warmup prevents unstable or destructive updates early in training, when the network and optimizer statistics are poorly calibrated, allowing initial representations to become stable. In practice, without warmup, training large scale models is effectively impossible.
This behaviour is exactly what one would suspect in a system which relies on nucleation for structure formation. Within deep learning systems, it is common to define the temperature of training as $T = \eta / B$, with $\eta$ the learning rate and $B$ the batch size. Assuming a fixed batch size then, one can adjust the temperature by adjusting the learning rate.
Imagine now some precipitation reaction like we discussed before. One particular complexity of carrying out these sorts of reactions is that if the temperature of the system is slightly too high, the thermal fluctuations render most nuclei unstable (that is, the size of the critical nucleus needed to resist breaking apart increases with temperature). Interestingly, at high temperature many nuclei will form, but then break apart almost instantly. However, too low of a temperature means the critical nucleus will form but the actual bulk of the reaction takes place very slowly which is also not desirable. So, a common practice is to start the reaction at a lower temperature until a stable nucleus forms, then increase the temperature slowly. This effectively allows the system to discover a stable parent phase, which is deep enough to not be destroyed by thermal fluctuations. Then, one increases the temperature to speed up the reaction.
In fact, across materials science there is a common practice of having a temperature process like cool -> hot -> cool. For instance, in polymer processing, this is done to nucleate crystallites, then reheating to induce a process called "lamellar thickening" (crystal layers grow thicker), then controlled cooling to fix morphology.
This largely mirrors what is done in deep learning. First, one starts with a low learning rate and gradually increases it to allow initial representations to form and not be washed out by noise. Then, over the course of training the learning rate is adjusted so that initial, robust features can amplify quickly, while later stages have a lower learning rate to allow for less stable parameters to be adjusted.
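As a concrete (and deliberately generic) sketch of what such a schedule looks like in code, here is a linear-warmup-plus-cosine-decay learning rate function; the constants are placeholders, not recommendations:

```python
import math

# A common warmup-then-decay learning-rate schedule: linear warmup for the first
# `warmup_steps` so that early proto-features can nucleate without being washed
# out, then cosine decay so late training runs "cooler" and can clean up and
# rearrange less stable parameters.

def lr_at(step, total_steps, base_lr=3e-4, warmup_steps=1000, min_lr=3e-5):
    if step < warmup_steps:
        return base_lr * (step + 1) / warmup_steps            # low temperature: nucleation
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    cosine = 0.5 * (1.0 + math.cos(math.pi * progress))       # anneal the temperature back down
    return min_lr + (base_lr - min_lr) * cosine

total = 10_000
for step in (0, 500, 1000, 5000, 9999):
    print(step, round(lr_at(step, total), 6))
```

Read through the temperature analogy, this is the cool -> hot -> cool protocol: start cold enough for stable seeds, heat up to grow them quickly, then cool to fix the final morphology.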
In physical systems, the nucleation and seeding phase is usually associated with a decrease in (configurational) entropy in favour of a decrease in the potential energy. This decrease in entropy tends to be somewhat shallow initially, as the growth of the nucleus is at first limited by its size (in most systems).
While in deep learning there are many ways the learning rate can be adjusted over time, it is interesting to note that almost all of them have some analogue in materials science used for similar reasons as their deep learning counterpart. For example, closed-loop laser thermal processing is like an adaptive optimizer: different areas of the material are heated differently using a laser according to sensor readings.
For the amplification phase, the correspondence is rather direct: once a feature is stable enough to produce a meaningful signal that reduces the loss, the strength of that feature will gradually be increased. The rich-get-richer dynamics seen in circuit formation are mirrored by the fact that, if we consider multiple nuclei which need a particular atom type to grow, the largest nucleus is simply going to encounter more material and get larger. Here, the (configurational) entropy tends to decrease rather rapidly as degrees of freedom get locked up when atoms get locked into nuclei.
For cleanup, the analogy is again straightforward. Once a feature/circuit dominates other features/circuits which do the same or a similar thing, the weaker internal structure provides very little value for decreasing the loss. So, there is little to no gradient signal keeping it locked in place, and the gradual effect of noise will either dissipate the structure or force it to merge into the larger one. This process can be rather slow in many physical systems, as the effect of noise (or, in some cases, a very weak signal from the potential) takes a long time to accumulate. One can occasionally have runaway rearrangements where the formation of, or change in, some other structure forces a more global reorganization, similar to grokking. This phase is usually associated with an increase in entropy as the thermal noise perturbs the unstable parts of the system, meaning that the long lived configurations will be the ones that are more stable and thus have higher entropy.
While this isn't meant to be a post that discusses SLT in depth, it is effectively the field which studies these sorts of things, so unsurprisingly it shows up here. Normally one would introduce the "local learning coefficient" and the SLT free energy, but I am going to make a massive oversimplification here and say that the negative of the learning coefficient is the entropy, so $S \approx -\lambda$. While not strictly correct, it is "correct enough" for the purposes here. For those unfamiliar with SLT, I strongly suggest this sequence of blog posts.
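For those who want to see what is being compressed by that oversimplification: the central asymptotic result of SLT (Watanabe's free energy formula) says that the Bayesian free energy at sample size $n$ expands, roughly, as

$$ F_n \approx n L_n(w^*) + \lambda \log n, $$

where $L_n(w^*)$ is the loss at an optimal parameter and $\lambda$ is the (local) learning coefficient. Squinting at this next to $F = U - TS$ from earlier, the $n L_n$ term plays the role of the energy and $\lambda \log n$ plays the role of $-TS$, which is the sense in which treating $-\lambda$ as an entropy is "correct enough" here.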
One of my favorite papers on structure formation in neural networks is the refined local learning coefficient paper. In this work they show that when training a small language model across various datasets, there is a pattern of an increase in the local learning coefficient, followed by a decrease near the end of training. They even show in almost all the LLC plots that there is generally a three phase structure of a slow increase, a fast increase, followed by a decrease.
This is exactly what one would expect from the nucleation -> growth -> rearrangement picture. That is, we should see the entropy decrease slowly as the initial structure begins to stabilize. This should then be followed by a quicker entropy decrease as the structure eats up degrees of freedom in the system. Once the structures are largely formed, the less stable components are slowly broken apart and reassimilated into the system.
I am not going to go into all of the formal mechanisms here (which are investigated in a forthcoming paper), but the features and circuits are related to degeneracies and how they evolve over the course of training.
In some sense the main geometric object of SLT is the Fisher information matrix $I(\theta)$, which is the negative expected Hessian of the log-likelihood (that is, the Hessian of the population loss). In physics, the FIM is strongly related to phase transitions through things called "order parameters". Order parameters are quantities that describe the collective state of a system, vanishing in a disordered phase but becoming non-zero in an ordered phase. The FIM is related to these because its diagonal elements are effectively the variances of the order parameters, and its off-diagonal elements are their covariances. Directions with large FIM eigenvalues then correspond to relevant directions (order parameters), while small eigenvalues correspond to irrelevant / sloppily constrained directions. That is, large eigenvalues are directions where changing the values impacts macroscopic properties of the system, while moving along small-eigenvalue directions effectively doesn't change the system. Importantly, the more sloppy directions there are, the higher the configurational entropy of the system, as the system can move freely in those directions without changing the energy.
This can be seen since, if $v$ is an eigenvector of $I(\theta)$ with eigenvalue $0$, then infinitesimal changes in that parameter combination do not change the distribution in any statistically distinguishable way. As a system evolves, stable substructures show up as degenerate directions in the FIM. A stable subsystem is one whose behavior depends on only a few parameter combinations; the Fisher information matrix reveals this as degeneracy.
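To make the stiff/sloppy picture concrete, here is a small self-contained sketch (my own toy, not from the SLT literature) that estimates the empirical Fisher information of a logistic-regression model with deliberately redundant inputs and counts the near-zero (degenerate) eigenvalues:

```python
import numpy as np

# Empirical Fisher information of a small logistic-regression model, estimated as
# the average outer product of per-example score vectors.  Large eigenvalues are
# "stiff", order-parameter-like directions; near-zero eigenvalues are sloppy,
# degenerate directions the parameters can move along without changing the model.

rng = np.random.default_rng(0)
n, d = 2000, 10
X = rng.normal(size=(n, d))
X[:, 5:] = X[:, :5]                      # duplicate half the inputs -> built-in degeneracy
w = rng.normal(size=d)

p = 1.0 / (1.0 + np.exp(-X @ w))         # model probabilities
y = (rng.random(n) < p).astype(float)    # labels sampled from the model itself

# Per-example score (gradient of the log-likelihood w.r.t. w) for logistic regression.
scores = (y - p)[:, None] * X            # shape (n, d)
fisher = scores.T @ scores / n           # empirical Fisher information matrix

eigvals = np.linalg.eigvalsh(fisher)[::-1]
print("FIM eigenvalues:", np.round(eigvals, 4))
print("near-zero (sloppy) directions:", int(np.sum(eigvals < 1e-6)))
```

Because five of the ten input columns are exact copies, the scores live in a five-dimensional subspace and the FIM has five exactly degenerate directions, which is the kind of signature the static analysis above is picking up.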
What does this mean for features/circuits? Well, a thing to note is that the FIM structure cannot tell you what the structures are or how things link together in most cases (this is the classic basis problem). In physics, this is because at some static equilibrium state, two components of a physical system might not seem related under normal probing strategies: their static statistics factorize (or nearly do), their linear responses are orthogonal, their instantaneous correlations are small or zero, or they correspond to distinct normal modes or order parameters. However, structures can still be deeply correlated for many different reasons. They may share hidden constraints, couple nonlinearly, interact through slow variables, or only exchange information dynamically. These types of correlations cannot be revealed by standard static analysis.
However, even these hidden correlations can be detected through dynamics. That is, if perturbing one parameter during training changes the evolution of another, they share some substructure. One can then consider, for instance, how various components couple and decouple their dynamics over time. Since we can't really see the nucleus forming, this is effectively the next best thing, as it is in some sense the way to identify independent substructures under the data distribution. This gives us a reasonable principle: true independence must hold under the evolution, not just at a static snapshot. If two components are truly independent, perturbations do not propagate between them, their joint dynamics factorizes, and no shared slow modes exist. This means that a good "interpretability solution" would likely benefit from dynamic aspects, not just static ones. This is partially why I am a big fan of things like the loss kernel, as this form of work seems like a good step towards this type of solution.
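As a sketch of what such a dynamical probe could look like (my own construction, not a method from the references, with an arbitrary tiny network and perturbation), one can train the same model twice from the same seed, nudge one block of parameters early in the second run, and check whether the trajectories of other blocks diverge:

```python
import numpy as np

# Dynamical-independence probe: train the same tiny network twice from the same
# initialization, but nudge one block of first-layer weights early in the second
# run.  If another block's trajectory is unaffected, the blocks are (for this data
# distribution) dynamically independent; if the nudge propagates, they share
# structure even when their static statistics look uncorrelated.

rng = np.random.default_rng(0)
X = rng.normal(size=(256, 8))
y = np.sin(X @ rng.normal(size=8))             # arbitrary nonlinear target

def train(perturb=False, steps=500, lr=0.1, hidden=16):
    rng_w = np.random.default_rng(1)           # identical init for both runs
    W1 = rng_w.normal(0, 0.5, size=(8, hidden))
    W2 = rng_w.normal(0, 0.5, size=(hidden, 1))
    for step in range(steps):
        h = np.tanh(X @ W1)
        err = h @ W2 - y[:, None]
        grad_W2 = h.T @ err / len(X)
        grad_W1 = X.T @ ((err @ W2.T) * (1 - h**2)) / len(X)
        W1 -= lr * grad_W1
        W2 -= lr * grad_W2
        if perturb and step == 50:
            W1[:, :4] += 0.01                  # nudge only the first block of hidden units
    return W1, W2

W1_a, W2_a = train(perturb=False)
W1_b, W2_b = train(perturb=True)
print("drift in perturbed block of W1 :", np.abs(W1_a[:, :4] - W1_b[:, :4]).mean())
print("drift in untouched block of W1 :", np.abs(W1_a[:, 4:] - W1_b[:, 4:]).mean())
print("drift in W2                    :", np.abs(W2_a - W2_b).mean())
# Nonzero drift in the untouched blocks means the perturbation propagated through
# the shared training dynamics, i.e. the blocks are not independent substructures.
```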
There are a few things one can take away from this, but I think the most important is that structure development is not some big mysterious thing, and that there is already a host of research on similar systems that we can draw on to better understand structure formation in deep learning. It also highlights that singular learning theory likely plays an important role in such a theory, as it provides the mathematical language for understanding degenerate points of the loss landscape, which are exactly the things that quantify interesting structure.
A lot of information in this post (in particular, the stuff about chemistry) can be found in almost any physical chemistry text. However, here are some particular references that I think contain a fair bit of useful information.