Crossposted from the AI Alignment Forum. May contain more technical jargon than usual.

In this note I will discuss some computations and observations that I have seen in other posts about "basin broadness/flatness". I am mostly working off the content of the posts Information Loss --> Basin flatness and Basin broadness depends on the size and number of orthogonal features. I will attempt to give one rigorous and unified narrative for core mathematical parts of these posts and I will also attempt to explain my reservations about some aspects of these approaches. This post started out as a series of comments that I had already made on the posts, but I felt it may be worthwhile for me to spell out my position and give my own explanations. 


Work completed while author was a SERI MATS scholar under the mentorship of Evan Hubinger.

Basic Notation and Terminology

We will imagine fixing some model architecture and thinking about the loss landscape from a purely mathematical perspective. We will not concern ourselves with the realities of training.

Let $\Theta$ denote the parameter space of a deep neural network model $F$. This means that each element $\theta \in \Theta$ is a complete set of weights and biases for the model. Suppose that when a set of parameters $\theta$ is fixed, the network maps from an input space $X$ to an output space $O$. When it matters below, we will take $O = \mathbb{R}$, but for now let us leave it abstract. So we have a function

$$F : \Theta \times X \to O$$

such that for any $\theta \in \Theta$, the function $F(\theta, \cdot) : X \to O$ is a fixed input-output function implemented by the network.

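Let $D = \{(x_i, y_i)\}_{i=1}^{k}$ be a dataset of $k$ training examples. We can then define a function $F_D : \Theta \to O^k$ by

$$F_D(\theta) = \big(F(\theta, x_1), \dots, F(\theta, x_k)\big).$$

This takes as input a set of parameters $\theta$ and returns the behaviour of $F(\theta, \cdot)$ on the training data.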

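We will think of the loss function as $l : O^k \to \mathbb{R}$.

Example. We could have $O = \mathbb{R}$, labels $y_i \in \mathbb{R}$, and

$$l(o_1, \dots, o_k) = \frac{1}{2} \sum_{i=1}^{k} (o_i - y_i)^2. \tag{*}$$

We also then define what we will call the total loss

$$L : \Theta \to \mathbb{R}$$

by

$$L(\theta) = l(F_D(\theta)).$$

This is just the usual thing: the total loss over the training dataset for a given set of weights and biases. So the graph of $L$ is what one might call the 'loss landscape'.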
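To make the notation concrete, here is a minimal sketch of the model, the behaviour map on the dataset, the example squared-error loss, and the total loss. The two-parameter `tanh` model, the inputs `xs` and the labels `ys` are all made up for illustration and do not come from either of the linked posts.

```python
import math

def F(theta, x):
    # Hypothetical toy model (an assumption, not from the posts):
    # N = 2 parameters, scalar input, scalar output.
    return theta[0] * math.tanh(theta[1] * x)

def F_D(theta, xs):
    # Behaviour on the training inputs: one output per data point, i.e. a point in R^k.
    return [F(theta, x) for x in xs]

def l(outputs, ys):
    # Example loss: half the sum of squared errors against the labels.
    return 0.5 * sum((o - y) ** 2 for o, y in zip(outputs, ys))

def total_loss(theta, xs, ys):
    # Total loss: the loss function evaluated on the behaviour of the model.
    return l(F_D(theta, xs), ys)

xs = [-1.0, 0.0, 2.0]   # k = 3 made-up training inputs
ys = [1.0, 0.0, -1.0]   # made-up labels
theta = [2.0, 0.0]      # with theta[1] = 0 the model outputs zero on every input

print("behaviour on the dataset:", F_D(theta, xs))
print("total loss:", total_loss(theta, xs, ys))
```

Note that the total loss only ever sees the parameters through the behaviour on the dataset, which is the structural fact the rest of the note relies on.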
 

Behaviour Manifolds

By a behaviour manifold (see [Hebbar]), we mean a set $M \subset \Theta$ of the form

$$M = F_D^{-1}(o) = \{\theta \in \Theta : F_D(\theta) = o\},$$

where $o \in O^k$ is a tuple of possible outputs. The idea here is that for a fixed behaviour manifold $M$, all of the models given by parameter sets $\theta \in M$ have identical behaviour on the training data.

Assume that $\Theta$ is an appropriately smooth $N$-dimensional space and let us now assume that $O = \mathbb{R}$.

Suppose that $N \geq k$. In this case, at a point $\theta \in \Theta$ at which the Jacobian matrix $DF_D(\theta)$ has full rank, the map $F_D$ is a submersion. The submersion theorem (which, in this context, is little more than the implicit function theorem) tells us that given $o \in O^k$, if $F_D$ is a submersion in a neighbourhood of a point $\theta \in F_D^{-1}(o)$, then $F_D^{-1}(o)$ is a smooth $(N-k)$-dimensional submanifold in a neighbourhood of $\theta$. So we conclude that in a neighbourhood of a point in parameter space at which the Jacobian of $F_D$ has full rank, the behaviour manifold is an $(N-k)$-dimensional smooth submanifold.
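The rank condition can be checked numerically. The following is a rough sketch, assuming a hypothetical three-parameter model and two made-up data points (so there are more parameters than data points). It approximates the Jacobian of the behaviour map by central finite differences and estimates its rank by Gaussian elimination with a tolerance. The helper names `jacobian_F_D` and `numerical_rank` are illustrative only.

```python
import math

def F(theta, x):
    # Hypothetical toy model with 3 parameters (an assumption for illustration).
    return theta[0] * math.tanh(theta[1] * x + theta[2])

def jacobian_F_D(theta, xs, h=1e-6):
    # Central finite-difference approximation of the Jacobian of the behaviour map:
    # one row per data point, one column per parameter.
    rows = []
    for x in xs:
        row = []
        for j in range(len(theta)):
            up = list(theta)
            up[j] += h
            down = list(theta)
            down[j] -= h
            row.append((F(up, x) - F(down, x)) / (2 * h))
        rows.append(row)
    return rows

def numerical_rank(matrix, tol=1e-6):
    # Rank via Gaussian elimination with partial pivoting; entries below tol count as zero.
    m = [list(r) for r in matrix]
    n_rows, n_cols = len(m), len(m[0])
    rank = 0
    for col in range(n_cols):
        if rank == n_rows:
            break
        pivot = max(range(rank, n_rows), key=lambda r: abs(m[r][col]))
        if abs(m[pivot][col]) < tol:
            continue
        m[rank], m[pivot] = m[pivot], m[rank]
        for r in range(rank + 1, n_rows):
            factor = m[r][col] / m[rank][col]
            for c in range(col, n_cols):
                m[r][c] -= factor * m[rank][c]
        rank += 1
    return rank

xs = [1.0, 2.0]  # 2 made-up data points, 3 parameters
rank_generic = numerical_rank(jacobian_F_D([1.0, 0.5, 0.1], xs))
rank_degenerate = numerical_rank(jacobian_F_D([0.0, 0.5, 0.1], xs))

print("generic point: rank", rank_generic, "-> local dimension", 3 - len(xs))
print("point with theta[0] = 0: rank", rank_degenerate, "-> submersion theorem does not apply")
```

At the generic point the Jacobian has full rank, so the theorem gives a one-dimensional behaviour manifold locally. At the point with `theta[0] = 0` the rank drops, and the theorem says nothing about the level set there.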

Reservations

Firstly, I want to emphasize that when the Jacobian of $F_D$ does not have full rank, it is generally difficult to make conclusions about the geometry of the level set, i.e. about the set that is called the behaviour manifold in this setting.

Examples. The following simple examples are to emphasize that there is not a straightforward intuitive relationship that says "when the Jacobian has less than full rank, there are fewer directions in parameter space along which the behaviour changes and therefore the behaviour manifold is bigger than $(N-k)$-dimensional":

  1. Consider $f : \mathbb{R}^2 \to \mathbb{R}$ given by $f(x, y) = x^2 + y^2$. We have $Df = (2x, 2y)$. This has rank 1 everywhere except the origin: at the point $(0, 0)$ it has less than full rank. And at that point, the level set is just a single point, i.e. it is 0-dimensional.
  2. Consider $f : \mathbb{R}^2 \to \mathbb{R}$ given by $f(x, y) = x^2$. We have $Df = (2x, 0)$. Again, this has less than full rank at the point $(0, 0)$. And at that point, the level set is the entire $y$-axis, i.e. it is 1-dimensional.
  3. Consider $f : \mathbb{R}^2 \to \mathbb{R}$ given by $f(x, y) = 0$. We of course have $Df = (0, 0)$. This has less than full rank everywhere, and the only non-empty level set is the whole of $\mathbb{R}^2$, i.e. it is 2-dimensional.
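The three examples can be checked by brute force. This sketch takes the three functions to be $x^2 + y^2$, $x^2$ and the zero function; these concrete choices are an assumption, picked to match the level sets described in the list. It computes the rank of the derivative at the origin and counts how many points of a small integer grid lie on the level set through the origin.

```python
# Three functions on the plane, each paired with its hand-computed derivative.
# The concrete formulas are assumptions chosen to match the level sets described above.
examples = {
    "x^2 + y^2": (lambda x, y: x * x + y * y, lambda x, y: (2 * x, 2 * y)),
    "x^2": (lambda x, y: x * x, lambda x, y: (2 * x, 0)),
    "0": (lambda x, y: 0, lambda x, y: (0, 0)),
}

def rank_1x2(df):
    # A 1 x 2 matrix has rank 1 if any entry is non-zero, otherwise rank 0.
    return 1 if any(entry != 0 for entry in df) else 0

# 25 integer grid points, so the level-set membership test below is exact.
grid = [(x, y) for x in range(-2, 3) for y in range(-2, 3)]

results = {}
for name, (f, Df) in examples.items():
    level = f(0, 0)
    on_level_set = [p for p in grid if f(*p) == level]
    results[name] = (rank_1x2(Df(0, 0)), len(on_level_set))
    print(name, "| rank at origin:", results[name][0],
          "| grid points on the level set:", results[name][1])
```

All three derivatives have rank 0 at the origin, yet the level sets through the origin contain 1, 5 and 25 of the 25 grid points respectively, consistent with dimensions 0, 1 and 2. The rank deficiency alone does not determine the dimension.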

Remark. We note further, just for the sake of intuition about these kinds of issues, that the geometry of the level set of a smooth function can in general be very bad: every closed subset is the zero set of some smooth function, i.e. given any closed set $A \subset \mathbb{R}^n$, there exists a smooth function $f : \mathbb{R}^n \to \mathbb{R}$ with $f^{-1}(0) = A$. Knowing that a level set is closed is an extremely basic fact, and yet without using specific information about the function you are looking at, you cannot conclude anything else.

Secondly, the use of the submersion theorem here only makes sense when $N \geq k$. But this is not even commonly the case. It is common to have many more data points (the $x_i$) than parameters (the $\theta_j$), ultimately meaning that the dimension of $O^k$ is much, much larger than the dimension of the domain of $F_D$. This suggests a slightly different perspective, which I briefly outline next.

Behavioural Space

When the codomain is a higher-dimensional space than the domain, we more commonly picture the image of a function, as opposed to the graph. For example, if I say to consider a smooth function $\gamma : \mathbb{R} \to \mathbb{R}^2$, one more naturally pictures the curve $\gamma(\mathbb{R})$ in the plane, as a kind of 'copy' of the line $\mathbb{R}$, as opposed to the graph of $\gamma$. So if one were to try to continue along these lines, one might instead imagine the image $F_D(\Theta)$ of parameter space in the behaviour space $O^k$. We think of each point of $O^k$ as a complete specification of possible outputs on the dataset. Then the image $F_D(\Theta)$ is (loosely speaking) an $N$-dimensional submanifold of this space which we should think of as having large codimension. And each point $F_D(\theta)$ on this submanifold is the outputs of an actual model with parameters $\theta$. In this setting, the points $\theta \in \Theta$ at which the Jacobian $DF_D(\theta)$ has full rank map to points $F_D(\theta)$ which have neighbourhoods in which $F_D(\Theta)$ is smooth and embedded.
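As a small illustration of this picture, the sketch below uses a hypothetical one-parameter model evaluated on three made-up inputs, so that parameter space is a line and behaviour space is three-dimensional. It samples the image of parameter space as a curve and checks that the Jacobian (a single column here, computed by hand) is non-zero at every sampled parameter.

```python
import math

def F(theta, x):
    # Hypothetical one-parameter model, chosen only for illustration.
    return math.tanh(theta * x)

def F_D(theta, xs):
    # A point of behaviour space: the outputs on all training inputs.
    return tuple(F(theta, x) for x in xs)

def dF_D(theta, xs):
    # The single-column Jacobian, by hand: d/dtheta tanh(theta * x) = x * (1 - tanh(theta * x)^2).
    return tuple(x * (1.0 - math.tanh(theta * x) ** 2) for x in xs)

xs = [1.0, 2.0, 3.0]                        # three data points, one parameter
thetas = [i / 10 for i in range(-20, 21)]   # a sample of parameter space
curve = [F_D(t, xs) for t in thetas]        # sampled points of the image curve

# Full rank for a single column just means that the column is not the zero vector.
full_rank = [any(d != 0.0 for d in dF_D(t, xs)) for t in thetas]

print("sampled points on the curve:", len(curve))
print("Jacobian has full rank at every sample:", all(full_rank))
```

Each sampled point is the complete behaviour of one actual model, and almost all of the surrounding three-dimensional space consists of behaviours that no parameter value realises.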
 

The Hessian of the Total Loss

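A computation of the Hessian of $L$ appears in both Information Loss --> Basin flatness and Basin broadness depends on the size and number of orthogonal features, under slightly different assumptions. Let us carefully go over that computation here, in a slightly greater level of generality. We continue with $O = \mathbb{R}$, in which case $O^k = \mathbb{R}^k$. The function we are going to differentiate is:

$$L(\theta) = l(F_D(\theta)) = l\big(F(\theta, x_1), \dots, F(\theta, x_k)\big).$$

And since each $F(\theta, x_i) \in \mathbb{R}$ for $i = 1, \dots, k$, we should think of $DF_D(\theta)$ as a $k \times N$ matrix, the general entry of which is $\frac{\partial F}{\partial \theta_j}(\theta, x_i)$.

We want to differentiate $L$ twice with respect to $\theta$. Firstly, we have

$$\frac{\partial L}{\partial \theta_j}(\theta) = \sum_{i=1}^{k} \frac{\partial l}{\partial o_i}\big(F_D(\theta)\big)\, \frac{\partial F}{\partial \theta_j}(\theta, x_i)$$

for $j = 1, \dots, N$.

Then for $j' = 1, \dots, N$ we differentiate again:

$$\frac{\partial^2 L}{\partial \theta_{j'} \partial \theta_j}(\theta) = \sum_{i=1}^{k} \sum_{i'=1}^{k} \frac{\partial^2 l}{\partial o_{i'} \partial o_i}\big(F_D(\theta)\big)\, \frac{\partial F}{\partial \theta_{j'}}(\theta, x_{i'})\, \frac{\partial F}{\partial \theta_j}(\theta, x_i) + \sum_{i=1}^{k} \frac{\partial l}{\partial o_i}\big(F_D(\theta)\big)\, \frac{\partial^2 F}{\partial \theta_{j'} \partial \theta_j}(\theta, x_i). \tag{1}$$

This is now an equation of $N \times N$ matrices.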
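The two-term structure of this Hessian can be sanity-checked numerically. The sketch below assumes a hypothetical two-parameter exponential model with hand-computed first and second derivatives, the half-squared-error loss (so the Hessian of the loss function is the identity and its gradient is the vector of residuals), and made-up data. It compares the chain-rule expression against a finite-difference Hessian of the total loss.

```python
import math

def F(theta, x):
    # Hypothetical toy model (an assumption): F(theta, x) = theta[0] * exp(theta[1] * x).
    return theta[0] * math.exp(theta[1] * x)

def grad_F(theta, x):
    # First derivatives of F with respect to the two parameters, by hand.
    e = math.exp(theta[1] * x)
    return [e, theta[0] * x * e]

def hess_F(theta, x):
    # Second derivatives of F with respect to the two parameters, by hand.
    e = math.exp(theta[1] * x)
    return [[0.0, x * e], [x * e, theta[0] * x * x * e]]

def total_loss(theta, xs, ys):
    # Half the sum of squared errors over the dataset.
    return 0.5 * sum((F(theta, x) - y) ** 2 for x, y in zip(xs, ys))

def hessian_from_formula(theta, xs, ys):
    # First term: products of first derivatives of F (the Hessian of the loss is the identity).
    # Second term: residual (first derivative of the loss) times second derivatives of F.
    n = len(theta)
    H = [[0.0] * n for _ in range(n)]
    for x, y in zip(xs, ys):
        g, h2, residual = grad_F(theta, x), hess_F(theta, x), F(theta, x) - y
        for a in range(n):
            for b in range(n):
                H[a][b] += g[a] * g[b] + residual * h2[a][b]
    return H

def hessian_finite_diff(theta, xs, ys, h=1e-4):
    # Central second differences of the total loss.
    n = len(theta)
    def L_at(a, da, b, db):
        t = list(theta)
        t[a] += da
        t[b] += db
        return total_loss(t, xs, ys)
    return [[(L_at(a, h, b, h) - L_at(a, h, b, -h)
              - L_at(a, -h, b, h) + L_at(a, -h, b, -h)) / (4 * h * h)
             for b in range(n)] for a in range(n)]

xs, ys = [0.0, 0.5, 1.0], [1.0, 0.0, 2.0]   # made-up data, not fitted by the model
theta = [1.0, 0.5]
H_formula = hessian_from_formula(theta, xs, ys)
H_numeric = hessian_finite_diff(theta, xs, ys)
max_gap = max(abs(H_formula[a][b] - H_numeric[a][b]) for a in range(2) for b in range(2))
print("largest entrywise gap:", max_gap)
```

Because the labels here are not fitted by the model, the residuals are non-zero and the second term genuinely contributes; dropping it would make the two Hessians disagree.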

At A Local Minimum of The Loss Function

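If $\theta \in \Theta$ is such that $F_D(\theta)$ is a local minimum for $l$ (which means that the parameters are such that the output of the network on the training data is a local minimum for the loss function), then the second term on the right-hand side of (1) vanishes (because the term includes the first derivatives of $l$, which are zero at a minimum). Therefore: if $F_D(\theta)$ is a local minimum for $l$ we have:

$$\frac{\partial^2 L}{\partial \theta_{j'} \partial \theta_j}(\theta) = \sum_{i=1}^{k} \sum_{i'=1}^{k} \frac{\partial^2 l}{\partial o_{i'} \partial o_i}\big(F_D(\theta)\big)\, \frac{\partial F}{\partial \theta_{j'}}(\theta, x_{i'})\, \frac{\partial F}{\partial \theta_j}(\theta, x_i).$$

If, in addition, the Hessian of $l$ is equal to the identity matrix (by which we mean $\frac{\partial^2 l}{\partial o_{i'} \partial o_i} = \delta_{i i'}$, as is the case for the example loss function given above in (*)), then we would have:

$$\frac{\partial^2 L}{\partial \theta_{j'} \partial \theta_j}(\theta) = \sum_{i=1}^{k} \frac{\partial F}{\partial \theta_{j'}}(\theta, x_i)\, \frac{\partial F}{\partial \theta_j}(\theta, x_i). \tag{2}$$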
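Here is a minimal numerical sketch of the local-minimum case, assuming a hypothetical two-parameter exponential model and the half-squared-error loss. The labels are generated by the model itself at a chosen parameter `theta_star`, so the loss function is exactly minimised there. The code compares the Gram matrix of the parameter derivatives over the dataset with a finite-difference Hessian of the total loss.

```python
import math

def F(theta, x):
    # Hypothetical toy model (an assumption): F(theta, x) = theta[0] * exp(theta[1] * x).
    return theta[0] * math.exp(theta[1] * x)

def grad_F(theta, x):
    # First derivatives of F with respect to the two parameters, by hand.
    e = math.exp(theta[1] * x)
    return [e, theta[0] * x * e]

def total_loss(theta, xs, ys):
    # Half the sum of squared errors over the dataset.
    return 0.5 * sum((F(theta, x) - y) ** 2 for x, y in zip(xs, ys))

def hessian_finite_diff(theta, xs, ys, h=1e-4):
    # Central second differences of the total loss.
    n = len(theta)
    def L_at(a, da, b, db):
        t = list(theta)
        t[a] += da
        t[b] += db
        return total_loss(t, xs, ys)
    return [[(L_at(a, h, b, h) - L_at(a, h, b, -h)
              - L_at(a, -h, b, h) + L_at(a, -h, b, -h)) / (4 * h * h)
             for b in range(n)] for a in range(n)]

xs = [0.0, 0.5, 1.0]
theta_star = [1.0, 0.5]
ys = [F(theta_star, x) for x in xs]   # labels realised exactly, so the loss is zero at theta_star

# Jacobian of the behaviour map: one row per data point, one column per parameter.
J = [grad_F(theta_star, x) for x in xs]

# Gram matrix: entry (a, b) is the inner product over the dataset of the
# derivatives with respect to parameter a and parameter b.
gram = [[sum(row[a] * row[b] for row in J) for b in range(2)] for a in range(2)]

H_numeric = hessian_finite_diff(theta_star, xs, ys)
max_gap = max(abs(gram[a][b] - H_numeric[a][b]) for a in range(2) for b in range(2))
print("Gram matrix:", gram)
print("largest gap to the finite-difference Hessian:", max_gap)
```

The agreement depends on the residuals vanishing at `theta_star`. Away from a minimum of the loss function the second-derivative term of the model reappears.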

 

Reservations

In Basin broadness depends on the size and number of orthogonal features, the expression on the right-hand side of equation (2) above is referred to as an inner product of "the features over the training data set". I do not understand the use of the word 'features' here and in the remainder of their post. The phrase seems to imply that a function of the form

$$x \mapsto \frac{\partial F}{\partial \theta_j}(\theta, x),$$

defined on the inputs of the training dataset, is what constitutes a feature. No further explanation is really given. It's completely plausible that I have missed something (and perhaps other readers do not or will not share my confusion), but I would like to see an attempt at a clear and detailed explanation of exactly how this notion is supposed to be the same notion of feature that (say) Anthropic use in their interpretability work (as was claimed to me).


Criticism

I'd like to tentatively try to give some higher-level criticism of these kinds of approaches. This is a tricky thing to do, I admit; it's generally very hard to say that a certain approach is unlikely to yield results, but I will at least try to explain where my skepticism is coming from. 

The perspective and the computations that are presented here (which in my opinion are representative of the mathematical parts of the linked posts and of various other unnamed posts) do not use any significant facts about neural networks or their architecture. In particular, in the mathematical framework that is set up, the function $F$ is more or less just any smooth function. And the methods used are just a few lines of calculus and linear algebra applied to abstract smooth functions. If these are the principal ingredients, then I am naturally led to expect that the conclusions will be relatively straightforward facts that will hold for more or less any smooth function $F$.

Such facts may be useful as part of bigger arguments - of course many arguments in mathematics do yield truly significant results using only 'low-level' methods - but in my experience one is extremely unlikely to end up with significant results in this way without it ultimately being clear after the fact where the hard work has happened or what the significant original insight was. 

So, naively, my expectation at the moment is that in order to arrive at better results about this sort of thing, arguments that start like these ones do must quickly bring to bear substantial mathematical facts about the network, e.g. random initialization, gradient descent, the structure of the network's layers, activations etc. One has to actually use something. I feel (again, speaking naively) that after achieving more success with a mathematical argument along these lines, one's hands would look dirtier. In particular, for what it's worth, I do not expect my suggestion to look at the image of the parameter space in 'behaviour space' to lead (by itself) to any further non-trivial progress. (And I say 'naively' in the preceding sentences here because I do not claim myself to have produced any significant results of the form I am discussing).



I worry that using as the space of behaviors misses something important about the intuitive idea of robustness, making any conclusions about or or behavior manifolds harder to apply. A more natural space (to illustrate my point, not as something helpful for this post) would be , with a metric that cares about how outputs differ on inputs that fall within a particular base distribution , something like

The issue with $O^k$ is that models in a behavior manifold only need to agree on the training inputs, and always include all models with arbitrarily crazy behaviors at all inputs outside the dataset, even if we are talking about inputs very close to those in the dataset (which is what the metric above is supposed to prevent). So the behavior manifolds are more like cylinders than balls, ignoring crucial dimensions. Since generalization does work (so learning tends to find very unusual points of them), it's generally unclear how a behavior manifold as a whole is going to be relevant to what's actually going on.

I agree that the space $O^k$ may well miss important concepts and perspectives. As I say, it is not my suggestion to look at it, but rather just something that was implicitly being done in another post. The space $O^X$ may well be a more natural one. (It's of course the space of functions $X \to O$, and so a space in which 'model space' naturally sits in some sense.)

The perspective and the computations that are presented here (which in my opinion are representative of the mathematical parts of the linked posts and of various other unnamed posts) do not use any significant facts about neural networks or their architecture.

You're correct that the written portion of the Information Loss --> Basin flatness post doesn't use any non-trivial facts about NNs. The purpose of the written portion was to explain some mathematical groundwork, which is then used for the non-trivial claim. (I did not know at the time that there was a standard name "Submersion theorem". I had also made formal mistakes, which I am glad you pointed out in your comments. The essence was mostly valid though.) The non-trivial claim occurs in the video section of the post, where a sort of degeneracy occurring in ReLU MLPs is examined. I now no longer believe that the precise form of my claim is relevant to practical networks. An approximate form (where low rank is replaced with something similar to low determinant) seems salvageable, though still of dubious value, since I think I have better framings now.

Secondly, the use of the submersion theorem here only makes sense when $N \geq k$.

Agreed.  I was addressing the overparameterized case, not the underparameterized one.  In hindsight, I should have mentioned this at the very beginning of the post -- my bad.

(Sorry for the very late response)

All in all, I don't think my original post held up well.  I guess I was excited to pump out the concept quickly, before the dust settled.  Maybe this was a mistake?  Usually I make the ~opposite error of never getting around to posting things.