Crossposted from my personal blog.

Epistemic status: Pretty uncertain. I don’t have an expert-level understanding of current views in the science of deep learning about why optimization works; I just read papers as an amateur. Some of the arguments I present here might already be known or disproven. If so, please let me know!

There are essentially two fundamental questions in the science of deep learning: 1.) Why are models trainable? And 2.) Why do models generalize? The answers to both questions relate to the nature and basic geometry of the loss landscape, which is ultimately determined by the computational architecture of the model. Here I present my personal, fairly idiosyncratic and speculative answers to both questions, which I think are fairly novel. Let’s get started.

First, let’s think about the first question: why are deep learning models trainable at all? A priori, they shouldn’t be. Deep neural networks are fundamentally solving incredibly high-dimensional problems with extremely non-convex loss landscapes. The entirety of optimization theory should tell you that, in general, this isn’t possible. You should get stuck in one of exponentially many local minima before getting close to any decent optimum, and you should probably also just explode, because you are using SGD, the stupidest possible optimizer.

So why does it, in fact, work? The fundamental reason is that we are very deliberately not in the general case. Much of the implicit work of deep learning is fiddling with architectures, hyperparameters, initializations, and so forth to ensure that, in the specific case, the loss landscape is benign. To even get SGD working in the first place, there are two fundamental problems which neural networks solve:

First, vanilla SGD assumes unit variance. This can be shown straightforwardly from a simple Bayesian interpretation of SGD, and it is also why second-order optimizers like natural gradient methods are useful. In effect, SGD fails whenever the variance and curvature of the loss landscape are either too high (SGD will explode) or too low (SGD will grind to a halt and take forever to traverse the landscape). This problem goes by many names, such as ‘vanishing/exploding gradients’, ‘poor conditioning’, and ‘poor initialization’. Many of the tricks that make neural networks work address it by initializing SGD in, and forcing it to stay in, the regime of unit variance. Examples of this include:

a.) Careful scaling of initializations by width to ensure that the post-activation distribution is as close to a unit variance Gaussian as possible.

b.) Careful scaling of initialization by depth to ensure that the network stays close to an identity function at initialization and variance is preserved (around 1). This is necessary to prevent either exploding activations with depth or vanishing activations (rank collapse).

c.) Explicit normalization at every layer (layer norm, batch norm) to force activations to zero mean and unit variance, which in turn shapes the gradient distribution.

d.) Gradient clipping to prevent very large updates, which temporarily cuts off the worst effects of bad conditioning.

e.) Weight decay, which keeps the weights close to a zero-mean, unit-variance Gaussian and hence keeps SGD in the expected range.

f.) Careful architecture design. We design neural networks to be as linear as possible while still having enough nonlinearities to represent useful functions. Innovations like residual connections allow for greater depth without exploding or vanishing as easily.

g.) The Adam optimizer performs a highly approximate (but cheap) rescaling of each parameter's update by a running estimate of the root-mean-square of its gradients, thus crudely approximating true second-order learning rate updates. This lets individual parameters get lower learning rates when passing through extremely high-curvature regions and higher learning rates when passing through low-variance regions (endless, nearly flat plains).
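A minimal sketch of this conditioning problem, assuming a simple quadratic loss whose Hessian eigenvalues span six orders of magnitude. Along an eigendirection with curvature $h$, one gradient step multiplies the distance to the minimum by $(1 - \eta h)$, so any single learning rate $\eta$ either diverges along the sharp directions ($\eta h > 2$) or barely moves along the flat ones ($\eta h \approx 0$). The `precondition=True` branch divides each coordinate's step by its exact curvature. This is an idealized diagonal second-order update, which Adam only crudely approximates. The function name and all numbers are hypothetical.

```python
import numpy as np

def gd_on_quadratic(curvatures, lr, steps, precondition=False):
    """Gradient descent on L(x) = 0.5 * sum_i h_i * x_i**2, starting from x = (1, ..., 1).

    Returns |x_i|, the remaining distance to the minimum along each eigendirection.
    With precondition=True the step along each direction is divided by its curvature
    h_i: an idealized diagonal second-order update, which Adam only roughly approximates.
    """
    h = np.asarray(curvatures, dtype=float)
    x = np.ones_like(h)
    for _ in range(steps):
        grad = h * x  # dL/dx_i = h_i * x_i
        step = grad / h if precondition else grad
        x = x - lr * step
    return np.abs(x)

curv = [1e-3, 1.0, 1e3]  # Hessian eigenvalues spanning six orders of magnitude
print(gd_on_quadratic(curv, lr=1.0, steps=50))                     # explodes along h = 1e3
print(gd_on_quadratic(curv, lr=1e-3, steps=50))                    # stable but stalls along h = 1e-3
print(gd_on_quadratic(curv, lr=0.5, steps=50, precondition=True))  # every direction shrinks by 0.5 per step
```

In the first call the distance grows without bound along the sharp direction. In the second, the flat direction is left almost untouched after 50 steps. The preconditioned run shrinks every coordinate by the same factor each step.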

Second, neural networks must be general enough to represent the kinds of nonlinear functions in the datasets, and ideally deep and wide enough to represent populations of parallel factored circuits for these functions. This property is crucial both to the modelling success of deep networks and ultimately to their generalization ability. Deep neural networks fundamentally encode an extremely powerful prior and computational primitive: the reusable, composable circuit with shared subcomputations. Essentially, deep neural networks can build up representations serially and share these representations with later layers to form very large superpositions of circuits with shared subcomponents. The classic way to think about this comes from CNNs, where this structure is very explicit. For instance, there are specific curve and line detectors, and these shared subcomponents get built up into more complex representations such as circle or body detectors and so on. If we trace back an output circuit, for instance one that detects a dog, we notice it shares large numbers of subcomponents with many other circuits, such as a cat detector, a truck detector, or whatever. This property is fundamental because it turns the otherwise exponential growth in the number of potential subcircuits needed to classify or model the data into polynomial growth, through what is essentially dynamic programming.
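The dynamic-programming point can be illustrated with a toy count. It rests on the simplifying assumption that a 'circuit' is just one input-to-output path through a fully connected stack of `depth` layers of `width` units each. Enumerating paths one at a time costs `width**depth` work. Storing at every unit the number of paths that reach it, and letting the next layer reuse those partial counts, costs only about `depth * width**2` additions. The helper names are hypothetical.

```python
import itertools

def count_paths_brute_force(width, depth):
    """Enumerate every path that picks one unit in each of `depth` layers.

    There are width**depth such paths and each is visited separately: exponential work.
    """
    return sum(1 for _ in itertools.product(range(width), repeat=depth))

def count_paths_dp(width, depth):
    """Same count by dynamic programming over layers.

    paths_to[j] holds the number of paths ending at unit j of the current layer; the
    next layer reuses these partial counts (the shared subcomputation) instead of
    re-enumerating the paths behind them.
    """
    paths_to = [1] * width  # one (length-1) path ends at each first-layer unit
    for _ in range(depth - 1):
        # fully connected: each next-layer unit sums over all current-layer units,
        # i.e. width * width additions per layer and O(depth * width**2) overall
        paths_to = [sum(paths_to) for _ in range(width)]
    return sum(paths_to)

print(count_paths_dp(4, 6), count_paths_brute_force(4, 6))
print(count_paths_dp(64, 48))
```

Both functions agree on the small case (4**6 = 4096 paths). Only the dynamic-programming version is practical for the second call, which counts 64**48 (about 5e86) paths with roughly 200,000 additions.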

Sharing subcomputations in this way requires both sufficient depth and sufficient width. If we have a shallow network, we cannot express serially dependent computations; the only option is to memorize a lookup table from input to output, which grows exponentially. Similarly, if we have only depth and no width, we can express serial computations but cannot share any subcomponents. We would have to learn a separate ‘program’ for every possible input-output pair, which is again exponential. If we have a network with depth $D$ and width $W$, then, heuristically, we can represent on the order of $W^D$ ‘circuits’ in parallel with only polynomial compute (about $DW^2$ multiply-adds for dense layers)! Moreover, during training a single gradient update implicitly optimizes over an exponentially large ‘circuit slice’ in parallel, rather than requiring some kind of ‘local’ program search, as we would need if we represented programs directly and used e.g. genetic algorithms. The big downside is interference between circuits: since circuits share subcomponents, if one circuit needs to change a subcomponent this can quickly prove detrimental to the function of other circuits in the model. Neural networks necessarily trade off some modularity and interpretability for the ability to convert an exponential problem into a polynomial one. However, this problem is increasingly ameliorated in high-dimensional spaces, where large numbers of almost-orthogonal vectors can be stored with low distortion and interference.
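The claim that high-dimensional spaces can hold many nearly orthogonal directions is easy to check numerically. The sketch below draws four times as many random unit vectors as there are dimensions, so they cannot all be exactly orthogonal, and measures their pairwise cosine similarities. Using cosine similarity as a stand-in for interference between circuits is an assumption, and so is the function name. For random directions in $d$ dimensions the root-mean-square cosine is $1/\sqrt{d}$, so interference falls as the dimension grows, even though the number of stored vectors grows too.

```python
import numpy as np

def cosine_stats(num_vectors, dim, seed=0):
    """RMS and max |cosine| over all distinct pairs of random unit vectors."""
    rng = np.random.default_rng(seed)
    v = rng.standard_normal((num_vectors, dim))
    v /= np.linalg.norm(v, axis=1, keepdims=True)  # project onto the unit sphere
    cos = v @ v.T
    off_diag = cos[~np.eye(num_vectors, dtype=bool)]  # drop self-similarities
    return float(np.sqrt(np.mean(off_diag ** 2))), float(np.abs(off_diag).max())

for dim in (25, 100, 400):
    rms, worst = cosine_stats(4 * dim, dim)  # 4x more vectors than dimensions
    print(f"dim={dim:4d}  rms |cos|={rms:.3f}  max |cos|={worst:.3f}  1/sqrt(dim)={dim ** -0.5:.3f}")
```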

Now that we understand the fundamental architectural constraints, we turn to optimization. Beyond just getting SGD to make progress and not explode, there remains the question of why neural networks are optimizable at all. Why do they not get stuck in the many extremely bad local minima that undoubtedly exist in the loss landscape? The answer is the blessings of dimensionality and noise.

Firstly, it appears that in large-scale neural network loss landscapes there are almost no local minima. In fact, almost all critical points are saddle points. This has been confirmed empirically in a number of different settings but, a priori, should not have been surprising at all. A local minimum requires all of the eigenvalues of the Hessian to be positive. As the number of parameters increases, the probability that every last eigenvalue is positive must decrease. Consider an extremely simple Bernoulli model in which each eigenvalue is independently positive with probability $p$ and negative with probability $1-p$. It predicts that the probability that a critical point is a local minimum rather than a saddle point is $p^d$, where $d$ is the number of parameters. This model predicts an exponential decrease in local minima with the number of parameters. In reality, this model is extremely oversimplified[1] and there are many other effects at play, but the basic logic holds: as the parameter count increases, the number of eigenvalues that must all be positive increases, and hence the total probability must decrease unless the correlation between eigenvalues is exactly 1.
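The Bernoulli toy model, and a slightly less naive random-matrix version of the same argument, fit in a few lines. The random-matrix part rests on an assumption: it treats the Hessian at a critical point as a symmetrized Gaussian matrix (a GOE-style ensemble), which real network Hessians are not. It then simply counts how often all eigenvalues come out positive. The function names are hypothetical.

```python
import numpy as np

def bernoulli_min_probability(p, d):
    """Toy model from the text: each of d Hessian eigenvalues is independently
    positive with probability p, so a critical point is a minimum with prob. p**d."""
    return p ** d

def goe_min_fraction(d, trials=20000, seed=0):
    """Fraction of random symmetric 'Hessians' (symmetrized Gaussian matrices,
    an assumption) whose eigenvalues are all positive."""
    rng = np.random.default_rng(seed)
    g = rng.standard_normal((trials, d, d))
    hessians = (g + np.swapaxes(g, 1, 2)) / 2
    eigs = np.linalg.eigvalsh(hessians)  # batched over the first axis -> (trials, d)
    return float(np.mean(np.all(eigs > 0, axis=1)))

for d in (1, 2, 4, 8):
    print(d, bernoulli_min_probability(0.5, d), goe_min_fraction(d))
```

With $p = 1/2$ the Bernoulli model halves the chance of a minimum with every added parameter. In the random-matrix model the fraction starts at 0.5 for $d = 1$, drops to roughly 0.15 for $d = 2$, and is essentially zero by $d = 8$. Random-matrix theory suggests this probability falls off even faster than exponentially, roughly like $e^{-c d^2}$.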

Why is this crucial? Because saddle points are vastly easier to escape than local minima. Local minima are dynamically stable: if you perturb the parameters a bit, all gradients point back towards the minimum. This is due to the positive curvature; you are in a bowl with walls on every side. A saddle point, by contrast, is dynamically unstable. If you are perturbed in a negative-curvature direction, you are no longer in the basin of attraction of the saddle point. You are not in a bowl but have been tipped over an infinitesimal crest and are now rolling downhill with slowly gathering speed. What is more, perturbations are guaranteed to happen, both because a finite learning rate means you never land exactly on the critical point and because of the minibatch noise in stochastic gradient descent. A number of papers have indeed shown, theoretically and empirically, that saddle points do not pose a problem for optimization in practice and can be escaped easily.
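A two-dimensional sketch of the difference between the two kinds of critical point, using the textbook examples $f(x, y) = x^2 + y^2$ (a minimum) and $f(x, y) = x^2 - y^2$ (a saddle with negative curvature along $y$). A perturbation of size $10^{-8}$ stands in for minibatch noise or finite-step error. The helper names and step size are hypothetical.

```python
import numpy as np

def run_gd(grad_fn, start, lr=0.1, steps=150):
    """Plain gradient descent from `start`; returns the final point."""
    x = np.array(start, dtype=float)
    for _ in range(steps):
        x = x - lr * grad_fn(x)
    return x

def bowl_grad(p):
    """Gradient of f(x, y) = x**2 + y**2: a local minimum (both curvatures positive)."""
    return np.array([2 * p[0], 2 * p[1]])

def saddle_grad(p):
    """Gradient of f(x, y) = x**2 - y**2: a saddle (negative curvature along y)."""
    return np.array([2 * p[0], -2 * p[1]])

tiny = 1e-8  # stand-in for minibatch noise or finite-step error
print(run_gd(bowl_grad, [1.0, tiny]))    # the perturbation decays: back into the minimum
print(run_gd(saddle_grad, [1.0, 0.0]))   # exactly on the saddle's stable axis: converges to it
print(run_gd(saddle_grad, [1.0, tiny]))  # the same perturbation grows by 1.2x per step
```

Started exactly on $y = 0$, gradient descent converges to the saddle. The same tiny perturbation that the bowl erases is multiplied by $1 + 2\eta = 1.2$ per step at the saddle, so after 150 steps it has grown to several thousand.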

Together, this means that as long as SGD is stable and can make progress, and the model is expressive enough to represent a good global optimum, then the optimization process will inexorably make progress towards and will ultimately converge to a region of very low loss around this global optimum. In the overparametrized case where the global optimum is an optimal manifold of solutions, the optimization process will converge to somewhere on this manifold.

This then raises the next question: why does this solution generalize? Specifically, in the overparametrized case there is a whole manifold of optimal points for the training loss. Why, typically, do we find ones that generalize out of distribution? There are a number of answers to this: we use regularization such as weight decay to avoid overfitting, in many cases we don't actually minimize the training loss to 0, and SGD naturally finds ‘flat minima’, which tend to generalize better than sharp minima. I want to zoom in on this last one a little more. The flat-minima claim has some empirical backing. You can also show that in some simple cases flat minima do provably have better generalization bounds, provided the test and train sets come from the same distribution and can essentially be thought of as different noisy draws from approximately the same range, i.e. interpolation and not extrapolation.

However, people are typically at a loss as to why flat minima are preferred, and there is a lot of literature arguing that ending up at flatter minima is an ‘implicit bias of SGD’. I want to argue that this is not actually the main reason. Instead, SGD finding flat minima is a very natural consequence of the geometry of the loss landscape, which can be seen through a very simple argument.

Intuitively, the argument goes like this: flat minima should be surrounded by much larger volumes of low-loss points than sharp minima. This is simple. If there is low curvature in a direction, then the loss stays ‘flat’ for longer along it. This means that at any given distance from the minimum in that direction, the loss is lower than it would be at the same distance from a sharp minimum. Crucially, this effect becomes extremely large in high-dimensional spaces, since volume scales as the radius raised to the power of the dimension, so small per-direction differences in radius compound exponentially.

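What this ends up meaning is that, for a given loss tolerance $\epsilon$ above the minimum, flatter minima take up much more volume in parameter space than less flat minima. Naively, this means that if we initialize SGD randomly in parameter space, it is much more likely to be closest to (and thus likely to converge to) the flattest minima. Mathematically, assume a positive definite Hessian $H$ with eigenvalues $\lambda_1, \dots, \lambda_d$ (we are at a local minimum) and an error tolerance of $\epsilon$. To second order, the region where the loss is within $\epsilon$ of the minimum is the ellipsoid $\frac{1}{2}\Delta\theta^\top H \Delta\theta \le \epsilon$, with semi-axes $\sqrt{2\epsilon/\lambda_i}$. Its volume therefore scales exponentially with the dimension and inversely with the product of the square-rooted eigenvalues[2]:

$$V(\epsilon) \propto \frac{(2\epsilon)^{d/2}}{\prod_{i=1}^{d} \sqrt{\lambda_i}}$$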

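If we then take the ratio of the volumes of two such ellipsoids, with eigenvalues $\lambda_i^{(1)}$ and $\lambda_i^{(2)}$, all the constant factors cancel and we see that it simply depends on the ratio of the two eigenvalue products:

$$\frac{V_1}{V_2} = \frac{\prod_{i=1}^{d} \sqrt{\lambda_i^{(2)}}}{\prod_{i=1}^{d} \sqrt{\lambda_i^{(1)}}} = \sqrt{\frac{\prod_{i=1}^{d} \lambda_i^{(2)}}{\prod_{i=1}^{d} \lambda_i^{(1)}}}$$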

If we make the highly artificial assumption that all the eigenvalues within each ellipsoid are the same, $\lambda^{(1)}$ and $\lambda^{(2)}$ respectively, we get a volume ratio that is exponential in the dimension:

$$\frac{V_1}{V_2} = \left(\frac{\lambda^{(2)}}{\lambda^{(1)}}\right)^{d/2}$$

More generally, if the eigenvalues are sampled randomly, the log of each product is a sum of many random terms. We should therefore expect the ratio of products to follow an approximately log-normal distribution, which has extremely heavy tails. This means that in most cases the bulk of the volume should be taken up by a few extremely flat outliers with extremely large volumes.
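The volume scaling described above can be checked numerically. This sketch works in log space, because volumes overflow or underflow quickly as $d$ grows. It computes the sublevel-set volume of a quadratic minimum from its Hessian eigenvalues, including the unit-ball constant $\pi^{d/2}/\Gamma(d/2+1)$. It then samples hypothetical eigenvalue spectra to see how concentrated the total volume is. The i.i.d. log-normal eigenvalue distribution, the number of minima, the dimension, and the function names are all arbitrary assumptions chosen for illustration.

```python
import math
import numpy as np

def log_unit_ball_volume(d):
    """log of pi**(d/2) / Gamma(d/2 + 1), the volume of the unit ball in d dimensions."""
    return 0.5 * d * math.log(math.pi) - math.lgamma(0.5 * d + 1)

def log_sublevel_volume(eigenvalues, eps):
    """log-volume of {dx : 0.5 * dx^T H dx <= eps} for a positive definite Hessian H.

    The set is an ellipsoid with semi-axes sqrt(2 * eps / lambda_i), so its volume is
    unit_ball_volume(d) * (2 * eps)**(d / 2) / prod(sqrt(lambda_i)).
    """
    lam = np.asarray(eigenvalues, dtype=float)
    d = lam.size
    return log_unit_ball_volume(d) + 0.5 * d * math.log(2 * eps) - 0.5 * np.sum(np.log(lam))

def top_volume_share(num_minima=1000, dim=1000, top_frac=0.01, seed=0):
    """Share of total sublevel-set volume held by the flattest `top_frac` of minima.

    Eigenvalues are drawn i.i.d. log-normal (a hypothetical choice); the shared
    constants (unit-ball volume, eps) cancel, so only -0.5 * sum(log lambda) matters.
    """
    rng = np.random.default_rng(seed)
    lam = rng.lognormal(mean=0.0, sigma=1.0, size=(num_minima, dim))
    log_vol = -0.5 * np.log(lam).sum(axis=1)
    weights = np.sort(np.exp(log_vol - log_vol.max()))[::-1]  # largest weight is 1
    k = max(1, int(top_frac * num_minima))
    return float(weights[:k].sum() / weights.sum())

# Ten-times-flatter minimum in d = 100: volume ratio (1 / 0.1)**(100 / 2) = 1e50
log_ratio = log_sublevel_volume([0.1] * 100, 1e-3) - log_sublevel_volume([1.0] * 100, 1e-3)
print(log_ratio / math.log(10))  # log10 of the volume ratio
print([round(math.exp(log_unit_ball_volume(d)), 4) for d in (1, 2, 5, 20, 50)])
print(top_volume_share())
```

The unit-ball volume peaks around $d = 5$ and then collapses toward zero, which is the complication noted in footnote 2. In the sampled spectra, the flattest 1% of minima typically hold nearly all of the volume.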

This result can be extended to give a new perspective on grokking. For simplicity, let’s analyze SGD as Langevin dynamics (i.e. gradient descent with isotropic Gaussian noise). Langevin dynamics is known to be an (exceptionally bad) MCMC algorithm which technically traverses the posterior distribution of the parameters given the data. Specifically, one can show fairly straightforwardly that, under general conditions of smoothness and convergence, the stationary distribution of Langevin dynamics assigns each parameter setting a probability proportional to the negative exponential of the loss, $p(\theta) \propto \exp(-L(\theta)/T)$ where $T$ is the noise temperature, i.e. a Boltzmann distribution on the loss values. This means, naturally, that at equilibrium SGD spends exponentially more time in regions of lower loss than in regions of higher loss. Now consider the optimal loss manifold. As long as this manifold is connected and the noise is isotropic (i.e. the Langevin Markov chain is ergodic), the stationary distribution across the manifold is uniform. In effect, once the model has converged to the optimal manifold, it slowly and uniformly diffuses across it. Crucially, since we know that the flattest minima have exponentially larger volumes, SGD at equilibrium will spend exponentially more time in these highly generalizing ‘grokking modes’ than in the sharp, poorly generalizing modes. Thus, even if we train models which are unlucky and reach a poor minimum, if the model is overparametrized and we wait long enough, we will see a deterministic (if potentially extremely slow) convergence towards improved generalization.
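A one-dimensional toy version of this argument, under stated assumptions. Two minima have the same loss (zero): one wide (curvature 1) and one narrow (curvature 9), joined where their quadratics meet at $x = 0.5$. Each basin is cut off exactly 1.5 standard deviations of its Boltzmann width from its minimum, and the narrow basin is three times narrower. Its exact stationary mass is therefore $1/(1+3) = 1/4$. The sketch starts every chain in the narrow ('unlucky') minimum, runs overdamped Langevin dynamics, and compares the long-run fraction of time spent there with the exact Boltzmann mass. The loss, temperature, step size, and function names are illustrative choices, not taken from any real model.

```python
import numpy as np

A, B = 1.0, 9.0   # curvatures of the wide (x = -1) and narrow (x = +1) minima
X_BARRIER = 0.5   # A*(x+1)**2 == B*(x-1)**2 here, so the loss is continuous
TEMP = 2.0        # Langevin temperature (noise scale)

def loss(x):
    """Two minima with equal loss 0: wide at x = -1, narrow at x = +1."""
    return np.where(x < X_BARRIER, A * (x + 1) ** 2, B * (x - 1) ** 2)

def grad(x):
    return np.where(x < X_BARRIER, 2 * A * (x + 1), 2 * B * (x - 1))

def boltzmann_narrow_fraction():
    """Exact narrow-basin mass under p(x) proportional to exp(-loss(x) / TEMP), by a Riemann sum."""
    xs = np.linspace(-15.0, 10.0, 200_001)
    p = np.exp(-loss(xs) / TEMP)
    return float(p[xs >= X_BARRIER].sum() / p.sum())  # uniform grid: spacing cancels

def langevin_narrow_fraction(chains=1000, steps=15_000, burn_in=5_000, lr=2e-3, seed=0):
    """Overdamped Langevin, x <- x - lr * grad(x) + sqrt(2 * lr * TEMP) * noise.

    Every chain starts in the narrow ('unlucky') minimum; returns the fraction of
    post-burn-in time spent in the narrow basin, pooled over chains.
    """
    rng = np.random.default_rng(seed)
    x = np.ones(chains)
    noise_scale = np.sqrt(2 * lr * TEMP)
    in_narrow = counted = 0
    for t in range(steps):
        x = x - lr * grad(x) + noise_scale * rng.standard_normal(chains)
        if t >= burn_in:
            in_narrow += np.count_nonzero(x >= X_BARRIER)
            counted += chains
    return in_narrow / counted

print(boltzmann_narrow_fraction())
print(langevin_narrow_fraction())
```

The simulated fraction should come out close to 0.25 even though every chain started in the narrow minimum. The chains do not stay in the poor minimum; they equilibrate to a split set by basin volume.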

To sum up, why does deep learning work?

1.) Deep neural networks are perfectly designed to be able to express and optimize over computational circuits with shared subcomponents. This both mirrors the structure of reality, which we consider to be a sparse hierarchical factor graph, and enables searching exponential volumes of circuit space in only polynomial time.

2.) SGD works best with unit variance and we carefully design, initialize, and tune networks such that this condition is approximately fulfilled during optimization. This ensures that SGD can make meaningful progress through the landscape without exploding.

3.) Due to extremely high dimensionality, almost all critical points away from the optimum are saddle points which can be escaped easily.

4.) When SGD finds the basin near a global optimum, it is exponentially more likely to find a basin with low curvature and hence vastly greater volume.

5.) Low-curvature basins naturally generalize better due to greater robustness to noise in the data, and maybe for other reasons?

6.) Even if SGD does not find the largest flattest basin, given isotropic dataset noise, SGD is guaranteed to perform a random walk across the optimal manifold and hence spend exponentially more time in the generalizing flat regions (eventually) than in the non-generalizing sharp regions. This explains results like grokking[3].

  1. ^

    Other simple models such as using a Marchenko-Pastur eigenvalue distribution would also imply increasingly low probability of local minima with parameter dimensionality.

  2. ^

    An interesting and subtle complication is that the proportionality constant is $\frac{\pi^{d/2}}{\Gamma(d/2+1)}$ (the volume of the unit $d$-ball), which causes the actual volume as measured in terms of the unit hypercube to decrease to 0. Intuitively, this is because the hypercube circumscribing a sphere grows in volume ($2^d$ for the unit ball) much faster than the sphere or ellipsoid inscribed inside it, since in high dimensions almost all of the cube's volume is in its corners. Mathematically, this is the fact that the gamma function increases faster than exponentially.

  3. ^

    I presented an earlier version of this hypothesis here. That argument understood the importance of diffusion but didn’t answer why ‘anti-grokking’ does not occur: if the model can diffuse into the grokking region, it should equally be able to diffuse out of it. The answer, of course, is that the grokking region is exponentially larger, and hence the model is exponentially unlikely to spontaneously diffuse out of it.


You may want to check out Singular Learning Theory, which aims to provide a better explanation of the likelihood of ending up at different minima by studying singular minima and not just single-point minima.

Someone with better SLT knowledge might want to correct this, but more specifically:

Studying the "volume scaling" of near-min-loss parameters, as beren does here, is really core to SLT. The rate of change of this volume as you change your epsilon loss tolerance is called the "density of states" (DOS) function, and much of SLT basically boils down to an asymptotic analysis of this function. It also relates the terms in the asymptotic expansion to things you care about, like generalization performance.

You might wonder why SLT needs so much heavy machinery, since this sounds so simple - and it's basically because SLT can handle the case where the eigenvalues of the Hessian are zero, and the usual formula breaks down. This is actually important in practice, since IIRC real models often have around 90% zero eigenvalues in their Hessian. It also leads to substantially different theory - for instance the "effective number of parameters" (RLCT) can vary depending on the dataset.

Looks like I really need to study some SLT! I will say though that I haven't seen many cases in transformer language models where the eigenvalues of the Hessian are 90% zeros -- that seems extremely high.