I haven't walked through your math carefully, but I find this type of analysis interesting.
SGD is believed to have certain "bias" towards low-entropy models of the world. Part of this is a preference for "broader" rather than "narrower" minima in L. Now we have some tools which may allow us to understand this. Under this model, SGD is also biased towards regions of low variance in loss function.
This bias towards regions of low variance makes intuitive sense.
SGD's bias towards low-entropy models also has a simple explanation - good inits start it in a low entropy config, and SGD moves in an entropy efficient direction of maximizing loss decrease per unit weight change, which biases it strongly towards staying near the low entropy init. This becomes quite noticeable when you experiment with 2nd order optimizers which generally don't have this bias - they tend to overfit far more easily and need more explicit regularization.
My previous post about SGD was an intro to this model. That post concerned a model of a loss landscape on two "datapoints". In this post I attempt to build a new model of SGD and validate it, with mixed success, but it is sort of interesting.
Gradient Variance
We could model this another way. The expected change of $W$ on each step is $-T\frac{dL}{dW}$, but we should also expect variance, so $W$ will evolve over time through probability space. There are two competing "forces" here: the "spreading force" created by the variance in $\frac{dl_j}{dW}$ over all datapoints in the model, and the "descent force" exerted by gradient descent pushing $W$ back into the centre of a given local minimum.
I think it makes sense to introduce some new notation here.
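$$g_j(W) = \frac{dl_j}{dW}$$
$$G(W) = \frac{dL}{dW} = \operatorname{average}_{\text{all } j}\left(g_j(W)\right)$$
$$S^2(W) = \operatorname{variance}_{\text{all } j}\left(g_j(W)\right)$$
$$S(W) = +\sqrt{S^2(W)}$$
The $S^2$ notation should be thought of like the $\cos^2(x)$ notation.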
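As a concrete reading of these definitions, here is a minimal sketch that computes $G$, $S^2$ and $S$ at a single value of $W$ from a list of per-datapoint gradients. The function name `gradient_stats` and the two-datapoint example are illustrative, and I am assuming "variance" means the population variance (dividing by the number of datapoints, not by one less).

```python
from math import sqrt
from statistics import fmean, pvariance

def gradient_stats(grads):
    """Return (G, S2, S) for a list of per-datapoint gradients g_j(W)."""
    G = fmean(grads)        # average over all j
    S2 = pvariance(grads)   # population variance over all j (assumption)
    return G, S2, sqrt(S2)

# Illustrative two-datapoint system with constant gradients g_0 = -1, g_1 = +1.
G, S2, S = gradient_stats([-1.0, 1.0])
print(G, S2, S)
```

With these two gradients the descent force cancels on average while the spreading force does not.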
Plotting these for our current system:
Places where $G$ is zero and the gradient of $G$ is positive are the stable equilibrium points with regard to gradient descent on $L$ (at ~1 and 2). If $G$ and $S^2$ are both zero at the same place, then this is an equilibrium point with regard to SGD on $L$ (only at 2). The zero points for $G$ and $S^2$ around the pit at 1 are not quite in the same place.
It is possible to consider probability mass of W "moving" according to the following rule:
A "point" (Dirac $\delta$ distribution) of probability at $W$, between $t$ and $t+1$, changes to a distribution centred at $W - TG(W)$ with a variance of $T^2S^2(W)$.
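This rule can be checked directly for a small system by enumerating every equally likely SGD step from a single point. The sketch below does that; the two gradient functions, the starting point and the value of `T` are hypothetical choices made only for illustration.

```python
from statistics import fmean, pvariance

def sgd_step_outcomes(W, grad_fns, T):
    """All equally likely positions after one SGD step from W (one per datapoint)."""
    return [W - T * g(W) for g in grad_fns]

# Hypothetical per-datapoint gradients, chosen only for illustration.
grad_fns = [lambda W: 3.0 * W, lambda W: W - 1.0]
W, T = 0.5, 0.1

# G and S2 at the starting point (population variance over datapoints).
grads = [g(W) for g in grad_fns]
G, S2 = fmean(grads), pvariance(grads)

# Mean and variance of the distribution of W after one step.
outcomes = sgd_step_outcomes(W, grad_fns, T)
mean_after = fmean(outcomes)
var_after = pvariance(outcomes)

print(mean_after, W - T * G)   # these two should agree
print(var_after, T**2 * S2)    # and so should these
```

The agreement is exact for a single step from a point; the approximations only enter once the update is treated as continuous in time.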
Now that we have abstracted away $T$ from the actual process of discontinuous updates, we can try to factor out the discontinuity entirely. This will make the maths more manageable when it comes to generalizing to larger models. $T$ will likely be much smaller for larger models, but as long as $S(W)$ grows with the number of datapoints used, this will compensate.
(Point of notation: I will be using $d$ rather than $\partial$, even though the latter is arguably more correct. As we will never be "mixing" $W$ and $t$, it won't make a difference to our results.)
Instead of the probability distribution moving, we might now consider it flowing. This can be described by a probability current density $\rho$.
Consider a system with $S^2(W) = 0$ everywhere. The probability will just flow down the gradient:
$$\rho_G(W,t) = -TG(W)P(W,t)$$
Taking $\frac{dP}{dt} = -\frac{d\rho}{dW}$ we get (with dependencies removed for ease of reading):
$$\frac{dP}{dt}_G = T\left[G\frac{dP}{dW} + P\frac{dG}{dW}\right]$$
Now consider a system with $G(W) = 0$ everywhere. We effectively have the evolution of a probability distribution via a random walk, which gives a "spreading out" effect. With constant $S^2$ we have the following equation for $\rho$, borrowed from the heat equation. I will invoke the central limit theorem and assume that the gradients are normally distributed.
$$\rho_S(W,t) = -\frac{1}{2}T^2S^2\frac{dP(W,t)}{dW}$$
Based on the fundamental solution of the heat equation, this will increase our variance by $T^2S^2$ each step of $t$.
Which gives us (the second term vanishes when $S^2$ is constant):
$$\frac{dP}{dt}_S = \frac{1}{2}T^2S^2\frac{d^2P}{dW^2} + \frac{1}{2}T^2\frac{d(S^2)}{dW}\frac{dP}{dW}$$
But the speed of "spreading out" is proportional to $S(W)$, which changes the equation. The slower the "spreading out", the higher the probability of $W$ being there. This makes $S(W)$ act like a "heat capacity" of the location for $P(W,t)$, for which $\rho$ is a conserved current. We might be able to borrow more from heat equations. In this case $P(W,t)S(W)$ acts as the "temperature" of a region.
$$\rho_S(W,t) = -k(W)\frac{d[P(W,t)S(W)]}{dW}$$
$$\rho_S(W,t) = -k(W)\left[S(W)\frac{dP(W,t)}{dW} + P(W,t)\frac{dS(W)}{dW}\right]$$
Calculating $k$ based on our previous equation gives $k(W) = \frac{1}{2}T^2S(W)$, which gives:
$$\frac{dP}{dt}_S = -\frac{d}{dW}\left[-\frac{1}{2}T^2S^2(W)\frac{dP(W,t)}{dW} - \frac{1}{2}T^2S(W)P(W,t)\frac{dS(W)}{dW}\right]$$
This can be reduced to the rather unwieldy equation (removing function dependencies for clarity):
$$\frac{dP}{dt}_S = T^2\left[\frac{3}{2}S\frac{dS}{dW}\frac{dP}{dW} + \frac{1}{2}S^2\frac{d^2P}{dW^2} + \frac{1}{2}\left(\frac{dS}{dW}\right)^2P + \frac{1}{2}SP\frac{d^2S}{dW^2}\right]$$
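But these can be expressed in terms of $S^2$ rather than $S$, which is good when $S$ is pathological in some way (for example, where $S^2$ is zero above, $S$ has a discontinuous derivative). It also makes sense that our equation shouldn't depend on our choosing positive rather than negative $S$.
$$\rho_S = -T^2\left[\frac{1}{2}S^2\frac{dP}{dW} + \frac{1}{4}P\frac{d(S^2)}{dW}\right]$$
$$\frac{dP}{dt}_S = T^2\left[\frac{3}{4}\frac{d(S^2)}{dW}\frac{dP}{dW} + \frac{1}{2}S^2\frac{d^2P}{dW^2} + \frac{1}{4}\frac{d^2(S^2)}{dW^2}P\right]$$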
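The rewrite from $S$ to $S^2$ rests on two product-rule identities. As a sketch of the substitution, using only symbols already defined here, the terms map across as follows:

```latex
\begin{align*}
S\frac{dS}{dW} &= \frac{1}{2}\frac{d(S^2)}{dW} \\
\left(\frac{dS}{dW}\right)^2 + S\frac{d^2S}{dW^2}
  &= \frac{d}{dW}\left(S\frac{dS}{dW}\right)
   = \frac{1}{2}\frac{d^2(S^2)}{dW^2} \\
\frac{3}{2}S\frac{dS}{dW}\frac{dP}{dW}
  &= \frac{3}{4}\frac{d(S^2)}{dW}\frac{dP}{dW} \\
\frac{1}{2}\left(\frac{dS}{dW}\right)^2 P + \frac{1}{2}SP\frac{d^2S}{dW^2}
  &= \frac{1}{4}\frac{d^2(S^2)}{dW^2}P
\end{align*}
```

The $\frac{1}{2}S^2\frac{d^2P}{dW^2}$ term is already written in terms of $S^2$ and carries over unchanged.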
Finally giving our master equations:
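$$\rho = -T^2\left[\frac{1}{2}S^2\frac{dP}{dW} + \frac{1}{4}\frac{d(S^2)}{dW}P\right] - T\left[GP\right]$$
$$\frac{dP}{dt} = T^2\left[\frac{3}{4}\frac{d(S^2)}{dW}\frac{dP}{dW} + \frac{1}{2}S^2\frac{d^2P}{dW^2} + \frac{1}{4}\frac{d^2(S^2)}{dW^2}P\right] + T\left[G\frac{dP}{dW} + P\frac{dG}{dW}\right]$$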
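To make the master equations concrete, here is a minimal sketch of one way they could be stepped numerically: compute $\rho$ on the faces between grid cells, then update $P$ by $-\frac{d\rho}{dW}$ with explicit Euler steps. This is not the code behind the plots in this post. The grid, time step, initial distribution, no-flux boundaries, and the choice of $G = W - 2$ with constant $S^2 = 1$ are all assumptions made for illustration.

```python
from math import exp

def flux(P, W, G, S2, dS2, T):
    """Probability current rho on the faces between neighbouring grid cells."""
    dW = W[1] - W[0]
    rho = []
    for i in range(len(W) - 1):
        Wf = 0.5 * (W[i] + W[i + 1])     # face position
        Pf = 0.5 * (P[i] + P[i + 1])     # P interpolated to the face
        dP = (P[i + 1] - P[i]) / dW      # dP/dW at the face
        spread = -T**2 * (0.5 * S2(Wf) * dP + 0.25 * dS2(Wf) * Pf)
        descent = -T * G(Wf) * Pf
        rho.append(spread + descent)
    return rho

def step(P, W, G, S2, dS2, T, dt):
    """One explicit Euler step of dP/dt = -d(rho)/dW with no-flux boundaries."""
    dW = W[1] - W[0]
    rho = [0.0] + flux(P, W, G, S2, dS2, T) + [0.0]
    return [P[i] - dt * (rho[i + 1] - rho[i]) / dW for i in range(len(P))]

# Illustrative setup (all values are assumptions, not taken from the plots).
N = 81
W = [4.0 * i / (N - 1) for i in range(N)]
dW = W[1] - W[0]
P = [exp(-0.5 * ((w - 1.0) / 0.2) ** 2) for w in W]   # narrow bump near W = 1
total = sum(P) * dW
P = [p / total for p in P]
mass_before = sum(P) * dW

T, dt = 0.5, 0.005
for _ in range(200):
    P = step(P, W, lambda w: w - 2.0, lambda w: 1.0, lambda w: 0.0, T, dt)

mass_after = sum(P) * dW
mean_after = sum(w * p for w, p in zip(W, P)) * dW
print(mass_before, mass_after, mean_after)
```

Writing the update in flux form means total probability is conserved by construction, which is a useful sanity check. An explicit scheme like this one needs a small time step relative to the grid spacing, and I would expect it to struggle where $S^2$ has large derivatives.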
Validation of the First Term of the Equations
Let's start with the first equation, and simulate using our G function from before.
T = 0.02, no stochasticity yet.
Here's W on the y-axis, and t on the x-axis. This is what the evolution of W looks like for a series of initial W values:
Now let's pick a couple of initial distributions and see how they evolve over time:
Time evolution with steps of Δt=10:
This looks about right!
Now let's plot the mean of this over time, and compare to the mean and standard deviation of a Monte Carlo simulation of gradient descent. The Monte Carlo simulation starts with 1000 $W$ values chosen to form a normal distribution with roughly the same mean and standard deviation (0.5 and 0.175 respectively) as our initial $P(W,t)$ distribution.
Our first equation is an accurate description of non-stochastic gradient descent. The rest of the difference in the standard deviation is most likely due to imperfect matching of our initial data (P is a truncated normal distribution but our Monte Carlo uses a normal distribution with matched mean and variance to the truncated P, so some elements are <0 where the gradient is small).
Validation of the Second Term of the Equations
Let's take our first example as a distribution spreading out.
$$g_0 = -1,\quad g_1 = 1,\quad G = 0,\quad S^2 = 1$$
And compare standard deviations to our Monte Carlo simulation:
Looking good; the errors here may also be due to truncation.
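For reference, a Monte Carlo check of pure spreading can be sketched as below. It is a minimal version under my own assumptions (every walker starts at $W = 0$, `T = 0.02`, 100 steps, 2000 walkers, a fixed seed), not the simulation used for the comparison above. With gradients of $-1$ and $+1$ each step adds $T^2$ to the variance, so the prediction after $n$ steps is $nT^2$.

```python
import random
from statistics import pvariance

def simulate_sgd(W0, grad_fns, T, n_steps, rng):
    """Run SGD from W0, picking one datapoint gradient uniformly at each step."""
    W = W0
    for _ in range(n_steps):
        g = rng.choice(grad_fns)
        W -= T * g(W)
    return W

rng = random.Random(0)                             # fixed seed for repeatability
grad_fns = [lambda W: -1.0, lambda W: 1.0]         # G = 0 and S^2 = 1 everywhere
T, n_steps, n_walkers = 0.02, 100, 2000            # assumed values

finals = [simulate_sgd(0.0, grad_fns, T, n_steps, rng) for _ in range(n_walkers)]
var_mc = pvariance(finals)
var_predicted = n_steps * T**2

print(var_mc, var_predicted)
```

The two numbers should agree up to sampling noise, which shrinks as the number of walkers grows.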
One final validation step: take $g_0 = 1.2W - 2$, $g_1 = 0.8W - 2$, $G = W - 2$, $S^2 = 0.04W^2$, $T = 0.5$. This model will be used to assess a few things: the model's ability to perform well at higher $T$, its ability to predict the correct form of the counterbalancing "concentrating" and "spreading" forces of $G$ and $S^2$, and its ability to predict the concentration of probability mass in regions of lower $S^2$.
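The Monte Carlo side of this comparison can be sketched as follows, using the two gradient functions and $T$ given above. The initial distribution (normal with mean 0.5 and standard deviation 0.175), the number of steps, the number of walkers and the seed are my assumptions. Since the expected update is $W \to W - T(W - 2)$, the mean should settle at $W = 2$, while the spread settles at a nonzero value because the noise does not vanish there.

```python
import random
from statistics import fmean, pstdev

def g0(W):
    return 1.2 * W - 2.0

def g1(W):
    return 0.8 * W - 2.0

def run_sgd(W0, T, n_steps, rng):
    """SGD on the two-datapoint system, picking g0 or g1 uniformly each step."""
    W = W0
    for _ in range(n_steps):
        g = rng.choice((g0, g1))
        W -= T * g(W)
    return W

rng = random.Random(0)          # fixed seed for repeatability
T = 0.5
n_steps, n_walkers = 50, 2000   # assumed values

# Assumed initial distribution: normal with mean 0.5 and std 0.175.
finals = [run_sgd(rng.gauss(0.5, 0.175), T, n_steps, rng) for _ in range(n_walkers)]
mean_mc, std_mc = fmean(finals), pstdev(finals)

print(mean_mc, std_mc)
```

At $T = 0.5$ each step moves $W$ a large fraction of the way to the minimum, so this is a fairly demanding test of the continuous approximation.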
Unfortunately the computational modelling seems to fall apart when applied to the original system. The large first and second derivatives of $S^2$ lead to a lot of instability. This means I can't validate it much more than this. High values of $T$ also cause the model to break down, as the gradient might change a lot in the span of a step. I think this can be remedied by (for example) picking a $g$ to update on and updating with multiple small steps before changing $g$.
I'm no master programmer and I don't have much experience working with unstable PDEs. So I can't do much more here.
Solving for End-States
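For an end-state, $\rho = 0$ everywhere. This means:
$$0 = -T^2\left[\frac{1}{2}S^2\frac{dP}{dW} + \frac{1}{4}\frac{d(S^2)}{dW}P\right] - T\left[GP\right]$$
$$T\frac{1}{2}S^2\frac{dP}{dW} = -T\frac{1}{4}\frac{d(S^2)}{dW}P - GP$$
$$\frac{dP}{dW}\frac{1}{P} = -\frac{1}{2S^2}\frac{d(S^2)}{dW} - \frac{2G}{TS^2}$$
$$\frac{d(\log(P))}{dW} = -\frac{1}{2S^2}\frac{d(S^2)}{dW} - \frac{2G}{TS^2}$$
$$\frac{d(\log(P))}{dW} = -\frac{1}{2}\frac{d(\log(S^2))}{dW} - \frac{2G}{TS^2}$$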
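Where $S^2$ is strictly positive, the last line can be integrated numerically to get the end-state density up to normalisation. The sketch below does this with the trapezoid rule. The function name, the grid, and the example ($G = W - 2$, constant $S^2 = 1$, $T = 0.1$) are assumptions for illustration. In that example the equation gives $P \propto e^{-(W-2)^2/T}$, a Gaussian with variance $T/2$.

```python
from math import exp, log

def stationary_density(W, G, S2, T):
    """Integrate d(log P)/dW = -0.5 d(log S2)/dW - 2G/(T S2) on a uniform grid.

    Assumes S2(w) > 0 at every grid point.
    """
    dW = W[1] - W[0]
    drift = [-2.0 * G(w) / (T * S2(w)) for w in W]
    logP = [-0.5 * log(S2(W[0]))]
    acc = 0.0
    for i in range(1, len(W)):
        acc += 0.5 * (drift[i - 1] + drift[i]) * dW   # trapezoid rule
        logP.append(-0.5 * log(S2(W[i])) + acc)
    top = max(logP)                                   # avoid overflow in exp
    P = [exp(lp - top) for lp in logP]
    Z = sum(P) * dW
    return [p / Z for p in P]

# Illustrative example: quadratic bowl with constant gradient variance.
T = 0.1
W = [4.0 * i / 400 for i in range(401)]
dW = W[1] - W[0]
P = stationary_density(W, lambda w: w - 2.0, lambda w: 1.0, T)

mean = sum(w * p for w, p in zip(W, P)) * dW
var = sum((w - mean) ** 2 * p for w, p in zip(W, P)) * dW
print(mean, var, T / 2)
```

The requirement that $S^2$ stays positive on the grid is exactly where this approach stops working for the original system.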
This shows our problem. When $S^2$ vanishes, our equations don't work terribly well. We might have to hope that the two opposing $S^2$ terms cancel out and it works, but who knows. This is probably the source of instability in our equations.
But around some minimum it lets us interpret something. If $\log(P)$ is decreasing linearly then $P$ decreases exponentially. Let's consider the $\frac{2G}{TS^2}$ term now. If we have two minima (with a maximum between them) around which the loss landscapes are exactly the same, except one is twice as wide (in all $l_i$), then the $G$ component will be halved in the wider one, but the $S^2$ part will be quartered. This means the integral $\int -\frac{2G}{TS^2}\,dW$ from the centre of the wider one to the maximum will be four times that of the narrower one. Therefore the probability density at the centre of the wider minimum's basin will be a lot higher. (Edit: this originally read "$e^4 = 56$ times".)
What's the point?
Reasoning about stochastic processes is difficult. Reasoning about differential equations is also difficult, but the tools to analyse differential equations are different and might be able to solve different problems.
SGD is believed to have a certain "bias" towards low-entropy models of the world. Part of this is a preference for "broader" rather than "narrower" minima in $L$. Now we have some tools which may allow us to understand this. Under this model, SGD is also biased towards regions of low variance in the loss function.
Further Investigation
I think there's something like a metric acting on a space here. $S^2$ looks like a metric, and perhaps it's actually more correct to consider the space of $W$ with the metric such that $S^2 = 1$ everywhere. For higher dimensions we get the following transformations:
$$W \to \vec{W}$$
$$G \to \vec{G}$$
$$S^2 \to \mathbf{S}^2$$
Now $\vec{W}$ and $\vec{G}$ are vectors and $\mathbf{S}^2$ is a matrix. This extends nicely, as we can choose our metric such that $\mathbf{S}^2 = I$. It might be useful to define some sort of function like an "energy" over the landscape of $\vec{W}$ in terms of $G$, $S$, and $T$ alone which describes the final probability distribution. In fact such a function must exist assuming SGD converges, as $\log(P(W,\infty))$ is well-defined. Working out the actual form of this function would take some effort, and it may not be at all easily described. This whole process is very reminiscent of both chemical dynamical modelling and finding the minimum-energy configuration of a quantum energy landscape, as both consist of a "spreading" term and an "energy" term.
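One possible reading of the higher-dimensional quantities is that $\vec{G}$ is the mean of the per-datapoint gradient vectors and $\mathbf{S}^2$ is their covariance matrix. The sketch below computes both under that assumption. The function name and the three example gradient vectors are hypothetical, and as in one dimension I am assuming the population form (dividing by the number of datapoints).

```python
def gradient_mean_and_cov(grads):
    """Return (G, S2): the mean vector and covariance matrix of gradient vectors."""
    n, d = len(grads), len(grads[0])
    G = [sum(g[k] for g in grads) / n for k in range(d)]
    S2 = [
        [sum((g[a] - G[a]) * (g[b] - G[b]) for g in grads) / n for b in range(d)]
        for a in range(d)
    ]
    return G, S2

# Hypothetical per-datapoint gradient vectors at one point in a 2D weight space.
grads = [[1.0, 0.0], [-1.0, 2.0], [0.0, -2.0]]
G_vec, S2_mat = gradient_mean_and_cov(grads)

print(G_vec)
print(S2_mat)
```

The off-diagonal entries of the matrix are what a change of metric to $\mathbf{S}^2 = I$ would have to remove.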
While it is quite interesting, I don't consider this a research priority for myself. About 90% of this post has been sitting in my drafts for the past 3 months. Even if powerful AI is created using SGD, I'm not convinced that this sort of model will be hugely useful. It might be possible to wrangle some selection-theorem-ish-thing out of this but I don't think I'll focus on it.