In this note I will discuss some computations and observations that I have seen in other posts about "basin broadness/flatness". I am mostly working off the content of the posts "Information Loss --> Basin flatness" and "Basin broadness depends on the size and number of orthogonal features". I will attempt to give one rigorous and unified narrative for the core mathematical parts of these posts, and I will also attempt to explain my reservations about some aspects of these approaches. This post started out as a series of comments that I had already made on the posts, but I felt it might be worthwhile for me to spell out my position and give my own explanations.
Work completed while the author was a SERI MATS scholar under the mentorship of Evan Hubinger.
Basic Notation and Terminology
We will imagine fixing some model architecture and thinking about the loss landscape from a purely mathematical perspective. We will not concern ourselves with the realities of training.
Let $\Theta$ denote the parameter space of a deep neural network model $F$. This means that each element $\theta \in \Theta$ is a complete set of weights and biases for the model. And suppose that when a set of parameters $\theta \in \Theta$ is fixed, the network maps from an input space $I$ to an output space $O$. When it matters below, we will take $O = \mathbb{R}$, but for now let us leave it abstract. So we have a function

$$F : \Theta \times I \longrightarrow O$$

such that for any $\theta \in \Theta$, the function $F(\theta, \cdot) : I \to O$ is a fixed input-output function implemented by the network.
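Let $D = \{x_1, \dots, x_k\} \subset I$ be a dataset of $k$ training examples. We can then define a function $\mathcal{F} : \Theta \to O^k$ by

$$\mathcal{F}(\theta) = \big(F(\theta, x_1), \dots, F(\theta, x_k)\big).$$

This takes as input a set of parameters $\theta$ and returns the behaviour of $F(\theta, \cdot)$ on the training data.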
We will think of the loss function as a map $l : O^k \to \mathbb{R}$.
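Example. We could have $O = \mathbb{R}$, $O^k = \mathbb{R}^k$, and

$$l(o_1, \dots, o_k) = \frac{1}{2} \sum_{i=1}^k (o_i - y_i)^2, \tag{*}$$

where $y_i$ denotes the label of the training input $x_i$.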
We also then define what we will call the total loss

$$L := l \circ \mathcal{F} : \Theta \to \mathbb{R}, \qquad L(\theta) = l\big(\mathcal{F}(\theta)\big).$$

This is just the usual thing: the total loss over the training dataset for a given set of weights and biases. So the graph of $L$ is what one might call the 'loss landscape'.
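To make the notation concrete, here is a minimal NumPy sketch of the objects defined so far. It assumes a hypothetical one-hidden-layer tanh network $F(\theta, x) = w_2 \cdot \tanh(w_1 x + b_1) + b_2$ with scalar input and output, so that $\Theta = \mathbb{R}^N$ with $N = 3H + 1$. The width `H`, the inputs `xs`, the labels `ys` and the function names are illustrative choices, not anything taken from the linked posts.

```python
import numpy as np

H = 3  # hypothetical hidden width; Theta = R^N with N = 3 * H + 1 = 10


def behaviour(theta, xs):
    """calF(theta) = (F(theta, x_1), ..., F(theta, x_k)) for the hypothetical
    network F(theta, x) = w2 . tanh(w1 * x + b1) + b2, theta = (w1, b1, w2, b2)."""
    w1, b1, w2, b2 = theta[:H], theta[H:2 * H], theta[2 * H:3 * H], theta[3 * H]
    return np.tanh(np.outer(xs, w1) + b1) @ w2 + b2  # shape (k,)


def loss(o, ys):
    """The example loss (*): l(o) = 1/2 * sum_i (o_i - y_i)^2."""
    return 0.5 * np.sum((o - ys) ** 2)


def total_loss(theta, xs, ys):
    """The total loss L = l o calF, a function Theta -> R."""
    return loss(behaviour(theta, xs), ys)


rng = np.random.default_rng(0)
xs = np.linspace(-1.0, 1.0, 5)      # k = 5 training inputs x_1, ..., x_k
ys = np.sin(np.pi * xs)             # hypothetical labels y_1, ..., y_k
theta = rng.normal(size=3 * H + 1)  # one point of parameter space
print(behaviour(theta, xs))         # the behaviour of F(theta, .) on the training data
print(total_loss(theta, xs, ys))    # the height of the loss landscape at theta
```

Note that nothing in `total_loss` uses the fact that `behaviour` comes from a neural network; any smooth map $\Theta \to \mathbb{R}^k$ could be substituted.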
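By a behaviour manifold (see [Hebbar]), we mean a set of the form

$$\mathcal{F}^{-1}(\mathbf{o}) = \{\theta \in \Theta : \mathcal{F}(\theta) = \mathbf{o}\},$$

where $\mathbf{o} = (o_1, \dots, o_k) \in O^k$ is a tuple of possible outputs. The idea here is that for a fixed behaviour manifold $\mathcal{F}^{-1}(\mathbf{o})$, all of the models given by parameter sets $\theta \in \mathcal{F}^{-1}(\mathbf{o})$ have identical behaviour on the training data.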
Assume that $\Theta$ is an appropriately smooth $N$-dimensional space and let us now assume that $O = \mathbb{R}$, so that $\mathcal{F} : \Theta \to \mathbb{R}^k$.
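Suppose that $k \leq N$. In this case, at a point $\theta$ at which the Jacobian matrix $J_{\mathcal{F}}(\theta)$ has full rank (i.e. rank $k$), the map $\mathcal{F}$ is a submersion; and since having full rank is an open condition, $\mathcal{F}$ is in fact a submersion on a whole neighbourhood of $\theta$. The submersion theorem (which, in this context, is little more than the implicit function theorem) tells us that given $\mathbf{o} \in \mathbb{R}^k$, if $\mathcal{F}$ is a submersion in a neighbourhood of a point $\theta \in \mathcal{F}^{-1}(\mathbf{o})$, then $\mathcal{F}^{-1}(\mathbf{o})$ is a smooth $(N-k)$-dimensional submanifold in a neighbourhood of $\theta$. So we conclude that in a neighbourhood of a point in parameter space at which the Jacobian of $\mathcal{F}$ has full rank, the behaviour manifold is an $(N-k)$-dimensional smooth submanifold.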
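The full-rank condition can be checked numerically at a given point. The sketch below assumes a hypothetical tanh network $F(\theta, x) = w_2 \cdot \tanh(w_1 x + b_1) + b_2$ of width `H = 3` (so $N = 10$) evaluated on $k = 5$ illustrative inputs. It computes the $k \times N$ Jacobian of $\mathcal{F}$ in closed form for this particular network and uses `np.linalg.matrix_rank`, whose default SVD tolerance is a numerical heuristic; the output is evidence about the rank at one point, not a proof.

```python
import numpy as np

H = 3  # hypothetical hidden width; N = 3 * H + 1 = 10 parameters


def behaviour(theta, xs):
    """calF(theta) for the hypothetical network F(theta, x) = w2 . tanh(w1 * x + b1) + b2."""
    w1, b1, w2, b2 = theta[:H], theta[H:2 * H], theta[2 * H:3 * H], theta[3 * H]
    return np.tanh(np.outer(xs, w1) + b1) @ w2 + b2


def jacobian(theta, xs):
    """Closed-form k x N Jacobian of calF, columns ordered like theta = (w1, b1, w2, b2)."""
    w1, b1, w2 = theta[:H], theta[H:2 * H], theta[2 * H:3 * H]
    t = np.tanh(np.outer(xs, w1) + b1)  # (k, H): dF/dw2_j = tanh(w1_j x_i + b1_j)
    s = (1.0 - t ** 2) * w2             # (k, H): dF/db1_j = w2_j sech^2(w1_j x_i + b1_j)
    # dF/dw1_j = x_i * dF/db1_j and dF/db2 = 1
    return np.hstack([s * xs[:, None], s, t, np.ones((xs.size, 1))])


rng = np.random.default_rng(0)
xs = np.linspace(-1.0, 1.0, 5)          # k = 5 <= N = 10
theta = rng.normal(size=3 * H + 1)
J = jacobian(theta, xs)
k, N = J.shape
rank = np.linalg.matrix_rank(J)         # SVD-based numerical rank
if rank == k:
    print(f"rank {rank} = k: near theta, the behaviour manifold is a smooth {N - k}-dimensional submanifold")
else:
    print(f"rank {rank} < k: the submersion theorem gives no conclusion at theta")
```

For instance, at $\theta = 0$ every column except the one for $b_2$ vanishes, so `jacobian(np.zeros(10), xs)` has rank 1 and the submersion theorem says nothing there.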
Firstly, I want to emphasize that when the Jacobian of $\mathcal{F}$ does not have full rank, it is generally difficult to draw conclusions about the geometry of the level set, i.e. about the set that is called the behaviour manifold in this setting.
Examples. The following simple examples are to emphasize that there is not a straightforward intuitive relationship that says "when the Jacobian has less than full rank, there are fewer directions in parameter space along which the behaviour changes and therefore the behaviour manifold is bigger than $(N-k)$-dimensional":
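- Consider $f : \mathbb{R}^2 \to \mathbb{R}$ given by $f(x, y) = x^2 + y^2$. We have $\nabla f(x, y) = (2x, 2y)$. This has rank 1 everywhere except the origin: at the point $(0, 0)$ it has less than full rank. And at that point, the level set $f^{-1}(0) = \{(0, 0)\}$ is just a single point, i.e. it is 0-dimensional.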
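- Consider $f : \mathbb{R}^2 \to \mathbb{R}$ given by $f(x, y) = x^2$. We have $\nabla f(x, y) = (2x, 0)$. Again, this has less than full rank at the point $(0, 0)$ (indeed at every point of the $y$-axis). And at that point, the level set $f^{-1}(0)$ is the entire $y$-axis, i.e. it is 1-dimensional.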
- Consider $f : \mathbb{R}^2 \to \mathbb{R}$ given by $f(x, y) = 0$. We of course have $\nabla f = (0, 0)$. This has less than full rank everywhere, and the only non-empty level set $f^{-1}(0)$ is the entire $\mathbb{R}^2$, i.e. it is 2-dimensional.
Remark. We note further, just for the sake of intuition about these kinds of issues, that the geometry of the level set of a smooth function can in general be very bad: every closed subset of $\mathbb{R}^n$ is the zero set of some smooth function, i.e. given any closed set $C \subset \mathbb{R}^n$, there exists a smooth function $f : \mathbb{R}^n \to \mathbb{R}$ with $f^{-1}(0) = C$. Knowing that a level set is closed is an extremely basic fact, and yet without using specific information about the function you are looking at, you cannot conclude anything else.
Secondly, the use of the submersion theorem here only makes sense when $k \leq N$. But this is not even commonly the case. It is common to have many more data points ($k$ of them) than parameters ($N$ of them), ultimately meaning that the dimension of the codomain $\mathbb{R}^k$ is much, much larger than the dimension of the domain of $\mathcal{F}$. This suggests a slightly different perspective, which I briefly outline next.
When the codomain is a higher-dimensional space than the domain, we more commonly picture the image of a function, as opposed to the graph. For example, if I say to consider a smooth function $g : \mathbb{R} \to \mathbb{R}^2$, one more naturally pictures the curve $g(\mathbb{R})$ in the plane, as a kind of 'copy' of the line $\mathbb{R}$, as opposed to the graph of $g$. So if one were to try to continue along these lines, one might instead imagine the image $\mathcal{F}(\Theta)$ of parameter space in the behaviour space $\mathbb{R}^k$. We think of each point of $\mathbb{R}^k$ as a complete specification of possible outputs on the dataset. Then the image $\mathcal{F}(\Theta)$ is (loosely speaking) an $N$-dimensional submanifold of this space, which we should think of as having large codimension $k - N$. And each point on this submanifold is the tuple of outputs of an actual model with parameters $\theta$. In this setting, a point $\theta$ at which the Jacobian $J_{\mathcal{F}}(\theta)$ has full rank (now rank $N$, so that $\mathcal{F}$ is an immersion there) has a neighbourhood $U$ whose image $\mathcal{F}(U)$ is a smooth embedded $N$-dimensional submanifold of $\mathbb{R}^k$.
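In the regime $k > N$, "full rank" means rank $N$, i.e. $\mathcal{F}$ is an immersion at $\theta$. A minimal sketch of that check, assuming the same kind of hypothetical tanh network $F(\theta, x) = w_2 \cdot \tanh(w_1 x + b_1) + b_2$ of width `H = 3` ($N = 10$), now evaluated on $k = 50$ illustrative inputs, with the Jacobian written in closed form for this particular network:

```python
import numpy as np

H = 3  # hypothetical hidden width; N = 3 * H + 1 = 10 parameters


def jacobian(theta, xs):
    """Closed-form k x N Jacobian of calF for the hypothetical network
    F(theta, x) = w2 . tanh(w1 * x + b1) + b2, columns ordered like theta = (w1, b1, w2, b2)."""
    w1, b1, w2 = theta[:H], theta[H:2 * H], theta[2 * H:3 * H]
    t = np.tanh(np.outer(xs, w1) + b1)  # dF/dw2_j at each x_i
    s = (1.0 - t ** 2) * w2             # dF/db1_j at each x_i
    return np.hstack([s * xs[:, None], s, t, np.ones((xs.size, 1))])


rng = np.random.default_rng(0)
xs = np.linspace(-1.0, 1.0, 50)         # k = 50 > N = 10: more data points than parameters
theta = rng.normal(size=3 * H + 1)
J = jacobian(theta, xs)
k, N = J.shape
rank = np.linalg.matrix_rank(J)         # SVD-based numerical rank
if rank == N:
    print(f"rank {rank} = N: calF is an immersion at theta, so it maps a small neighbourhood "
          f"of theta onto an embedded {N}-dimensional submanifold of R^{k}")
else:
    print(f"rank {rank} < N: calF is not an immersion at theta")
```

As with any numerical rank computation, a result of `rank == N` only says that the smallest singular value sits above `matrix_rank`'s default tolerance.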
The Hessian of the Total Loss
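A computation of the Hessian of $L$ appears in both "Information Loss --> Basin flatness" and "Basin broadness depends on the size and number of orthogonal features", under slightly different assumptions. Let us carefully go over that computation here, at a slightly greater level of generality. We continue with $O = \mathbb{R}$, in which case $O^k = \mathbb{R}^k$, and we work in local coordinates $\theta = (\theta_1, \dots, \theta_N)$ on $\Theta$. The function we are going to differentiate is:

$$L(\theta) = l\big(\mathcal{F}_1(\theta), \dots, \mathcal{F}_k(\theta)\big), \qquad \text{where } \mathcal{F}_i(\theta) = F(\theta, x_i).$$

And since each $\mathcal{F}_i : \mathbb{R}^N \to \mathbb{R}$ for $i = 1, \dots, k$, we should think of the Jacobian $J_{\mathcal{F}}$ as a $k \times N$ matrix, the general entry of which is $\partial \mathcal{F}_i / \partial \theta_j$.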
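We want to differentiate $L$ twice with respect to $\theta$. Firstly, for $j = 1, \dots, N$ we have

$$\frac{\partial L}{\partial \theta_j}(\theta) = \sum_{i=1}^k \frac{\partial l}{\partial o_i}\big(\mathcal{F}(\theta)\big)\, \frac{\partial \mathcal{F}_i}{\partial \theta_j}(\theta).$$

Then for $m = 1, \dots, N$ we differentiate again:

$$\frac{\partial^2 L}{\partial \theta_m \partial \theta_j}(\theta) = \sum_{i=1}^k \sum_{p=1}^k \frac{\partial^2 l}{\partial o_p \partial o_i}\big(\mathcal{F}(\theta)\big)\, \frac{\partial \mathcal{F}_p}{\partial \theta_m}(\theta)\, \frac{\partial \mathcal{F}_i}{\partial \theta_j}(\theta) + \sum_{i=1}^k \frac{\partial l}{\partial o_i}\big(\mathcal{F}(\theta)\big)\, \frac{\partial^2 \mathcal{F}_i}{\partial \theta_m \partial \theta_j}(\theta).$$

In matrix form, writing $\nabla^2$ for Hessians, this reads

$$\nabla^2 L(\theta) = J_{\mathcal{F}}(\theta)^{\top}\, \nabla^2 l\big(\mathcal{F}(\theta)\big)\, J_{\mathcal{F}}(\theta) + \sum_{i=1}^k \frac{\partial l}{\partial o_i}\big(\mathcal{F}(\theta)\big)\, \nabla^2 \mathcal{F}_i(\theta). \tag{1}$$

This is now an equation of $N \times N$ matrices.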
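Equation (1) can be sanity-checked numerically. The sketch below assumes a hypothetical width-3 tanh network $F(\theta, x) = w_2 \cdot \tanh(w_1 x + b_1) + b_2$, the example loss (*) with randomly drawn hypothetical labels, and central finite differences (the helpers `jacobian_fd` and `hessian_fd` are defined here for illustration) in place of exact derivatives. It evaluates both sides of (1) at a generic point that is not a minimum, so the second term is genuinely present; agreement is only up to finite-difference error.

```python
import numpy as np

H = 3  # hypothetical hidden width; N = 3 * H + 1 = 10 parameters


def behaviour(theta, xs):
    """calF(theta) for the hypothetical network F(theta, x) = w2 . tanh(w1 * x + b1) + b2."""
    w1, b1, w2, b2 = theta[:H], theta[H:2 * H], theta[2 * H:3 * H], theta[3 * H]
    return np.tanh(np.outer(xs, w1) + b1) @ w2 + b2


def jacobian_fd(f, theta, eps=1e-6):
    """Central-difference Jacobian of f: R^N -> R^k at theta, returned as a k x N matrix."""
    return np.stack([(f(theta + d) - f(theta - d)) / (2 * eps) for d in np.eye(theta.size) * eps], axis=1)


def hessian_fd(f, theta, eps=1e-4):
    """Central-difference second derivatives of f at theta; shape (N, N) + shape of f(theta)."""
    e = np.eye(theta.size) * eps
    out = np.zeros((theta.size, theta.size) + np.shape(f(theta)))
    for a in range(theta.size):
        for b in range(theta.size):
            out[a, b] = (f(theta + e[a] + e[b]) - f(theta + e[a] - e[b])
                         - f(theta - e[a] + e[b]) + f(theta - e[a] - e[b])) / (4 * eps ** 2)
    return out


rng = np.random.default_rng(0)
xs = np.linspace(-1.0, 1.0, 4)          # k = 4 training inputs
ys = rng.normal(size=xs.size)           # hypothetical labels
theta = rng.normal(size=3 * H + 1)      # a generic point, not a minimum of L


def calF(t):
    return behaviour(t, xs)


def L(t):
    """Total loss L = l o calF with the example loss (*)."""
    return 0.5 * np.sum((calF(t) - ys) ** 2)


J = jacobian_fd(calF, theta)            # k x N, entries d calF_i / d theta_j
grad_l = calF(theta) - ys               # gradient of (*) at calF(theta)
hess_l = np.eye(xs.size)                # Hessian of (*) is the k x k identity
hess_F = hessian_fd(calF, theta)        # (N, N, k): hess_F[:, :, i] is the Hessian of calF_i

lhs = hessian_fd(L, theta)
# hess_F @ grad_l contracts over i, giving sum_i (dl/do_i) * Hessian(calF_i)
rhs = J.T @ hess_l @ J + hess_F @ grad_l  # right-hand side of (1)
print(np.max(np.abs(lhs - rhs)))          # small, up to finite-difference error
```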
At A Local Minimum of The Loss Function
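If $\theta^*$ is such that $\mathcal{F}(\theta^*)$ is a local minimum for $l$ (which means that the parameters are such that the output of the network on the training data is a local minimum for the loss function), then the second term on the right-hand side of (1) vanishes (because the term includes the first derivatives of $l$, which are zero at a minimum). Therefore, if $\mathcal{F}(\theta^*)$ is a local minimum for $l$, we have:

$$\nabla^2 L(\theta^*) = J_{\mathcal{F}}(\theta^*)^{\top}\, \nabla^2 l\big(\mathcal{F}(\theta^*)\big)\, J_{\mathcal{F}}(\theta^*).$$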
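If, in addition, the Hessian of $l$ is equal to the identity matrix (by which we mean $\nabla^2 l\big(\mathcal{F}(\theta^*)\big) = I_k$, as is the case for the example loss function given above in (*)), then we would have:

$$\nabla^2 L(\theta^*) = J_{\mathcal{F}}(\theta^*)^{\top} J_{\mathcal{F}}(\theta^*), \quad \text{i.e.} \quad \frac{\partial^2 L}{\partial \theta_m \partial \theta_j}(\theta^*) = \sum_{i=1}^k \frac{\partial F}{\partial \theta_m}(\theta^*, x_i)\, \frac{\partial F}{\partial \theta_j}(\theta^*, x_i). \tag{2}$$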
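Equation (2) can be checked numerically by picking a point $\theta^*$ and then choosing hypothetical labels $y_i := F(\theta^*, x_i)$, so that $\mathcal{F}(\theta^*)$ is the global minimum of the loss (*). The sketch assumes a hypothetical width-3 tanh network $F(\theta, x) = w_2 \cdot \tanh(w_1 x + b_1) + b_2$ and illustrative central-difference helpers `jacobian_fd` and `hessian_fd`. Column $j$ of the Jacobian is the vector $\big(\partial F / \partial \theta_j(\theta^*, x_i)\big)_{i=1}^k$, so each entry of `gram` is an inner product of two such columns over the training inputs.

```python
import numpy as np

H = 3  # hypothetical hidden width; N = 3 * H + 1 = 10 parameters


def behaviour(theta, xs):
    """calF(theta) for the hypothetical network F(theta, x) = w2 . tanh(w1 * x + b1) + b2."""
    w1, b1, w2, b2 = theta[:H], theta[H:2 * H], theta[2 * H:3 * H], theta[3 * H]
    return np.tanh(np.outer(xs, w1) + b1) @ w2 + b2


def jacobian_fd(f, theta, eps=1e-6):
    """Central-difference Jacobian of f: R^N -> R^k at theta, returned as a k x N matrix."""
    return np.stack([(f(theta + d) - f(theta - d)) / (2 * eps) for d in np.eye(theta.size) * eps], axis=1)


def hessian_fd(f, theta, eps=1e-4):
    """Central-difference Hessian of a scalar function f at theta, as an N x N matrix."""
    e = np.eye(theta.size) * eps
    out = np.zeros((theta.size, theta.size))
    for a in range(theta.size):
        for b in range(theta.size):
            out[a, b] = (f(theta + e[a] + e[b]) - f(theta + e[a] - e[b])
                         - f(theta - e[a] + e[b]) + f(theta - e[a] - e[b])) / (4 * eps ** 2)
    return out


rng = np.random.default_rng(1)
xs = np.linspace(-1.0, 1.0, 4)          # k = 4 training inputs
theta_star = rng.normal(size=3 * H + 1)
ys = behaviour(theta_star, xs)          # hypothetical labels chosen so that calF(theta_star) = y


def L(t):
    """Total loss with the example loss (*); calF(theta_star) = y is its global minimum."""
    return 0.5 * np.sum((behaviour(t, xs) - ys) ** 2)


J = jacobian_fd(lambda t: behaviour(t, xs), theta_star)  # row i: gradient of F(., x_i) at theta_star
gram = J.T @ J                          # (m, j) entry: sum_i dF/dtheta_m(theta*, x_i) dF/dtheta_j(theta*, x_i)
hess_L = hessian_fd(L, theta_star)
print(np.max(np.abs(hess_L - gram)))    # small, up to finite-difference error
```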
In "Basin broadness depends on the size and number of orthogonal features", the expression on the right-hand side of equation (2) above is referred to as an inner product of "the features over the training data set". I do not understand the use of the word 'features' here and in the remainder of their post. The phrase seems to imply that a function of the form

$$x_i \longmapsto \frac{\partial F}{\partial \theta_j}(\theta^*, x_i), \qquad i = 1, \dots, k,$$

defined on the inputs of the training dataset, is what constitutes a feature. No further explanation is really given. It's completely plausible that I have missed something (and perhaps other readers do not or will not share my confusion), but I would like to see an attempt at a clear and detailed explanation of exactly how this notion is supposed to be the same notion of feature that (say) Anthropic use in their interpretability work (as was claimed to me).
I'd like to tentatively try to give some higher-level criticism of these kinds of approaches. This is a tricky thing to do, I admit; it's generally very hard to say that a certain approach is unlikely to yield results, but I will at least try to explain where my skepticism is coming from.
The perspective and the computations that are presented here (which in my opinion are representative of the mathematical parts of the linked posts and of various other unnamed posts) do not use any significant facts about neural networks or their architecture. In particular, in the mathematical framework that is set up, the function $\mathcal{F}$ is more or less just any smooth function. And the methods used are just a few lines of calculus and linear algebra applied to abstract smooth functions. If these are the principal ingredients, then I am naturally led to expect that the conclusions will be relatively straightforward facts that will hold for more or less any smooth function $\mathcal{F}$.
Such facts may be useful as part of bigger arguments - of course many arguments in mathematics do yield truly significant results using only 'low-level' methods - but in my experience one is extremely unlikely to end up with significant results in this way without it ultimately being clear after the fact where the hard work has happened or what the significant original insight was.
So, naively, my expectation at the moment is that in order to arrive at better results about this sort of thing, arguments that start like these ones do must quickly bring to bear substantial mathematical facts about the network, e.g. random initialization, gradient descent, the structure of the network's layers, activations etc. One has to actually use something. I feel (again, speaking naively) that after achieving more success with a mathematical argument along these lines, one's hands would look dirtier. In particular, for what it's worth, I do not expect my suggestion to look at the image of the parameter space in 'behaviour space' to lead (by itself) to any further non-trivial progress. (And I say 'naively' in the preceding sentences here because I do not claim myself to have produced any significant results of the form I am discussing).