The idea of 'simplicity bias' is a popular one in informal discussions about the functions that neural networks implement. I recently tried to think a little more about what it means. This brief note was written mostly for my own benefit, but it may be useful to others. That said, it is quite possible that I have made significant errors, and if so I welcome all good-faith attempts to point them out.
It seems not uncommon for a surface-level or informal understanding of the idea to have come from Chris Mingard's Medium posts "Neural networks are fundamentally (almost) Bayesian" and "Deep Neural Networks are biased, at initialisation, towards simple functions". Mingard attempts to summarize the results of a handful of papers on the topic, two of the most important of which he first-authored. For what it's worth (and in case this post seems like only criticism), I am of the opinion that the work is generally of high quality, and I do not think that Mingard is attempting to misrepresent it. However, I do think that, whether Mingard or his readers are most responsible, attempts to distill technical research into bite-sized, quotable 'takeaways' can leave one with decontextualized nuggets of information that are liable to be misunderstood.
Fix a deep neural network architecture and let $W$ denote the parameter space, by which I mean that each point $w \in W$ is a complete set of weights and biases for the network. The architecture defines a map between $W$ and the space $\mathcal{F}$ of functions that the architecture is capable of implementing, i.e.
$$\Phi : W \longrightarrow \mathcal{F},$$
and each $f \in \mathcal{F}$ is a specific input-output function that the network is capable of implementing. We will denote the complexity of $f$ by $C(f)$, so that
$$C : \mathcal{F} \longrightarrow \mathbb{R}$$
(and I will avoid getting into a discussion of different notions or different measures of complexity).
Given a probability measure $\mu$ on $W$, we can think of picking $w \in W$ at random according to $\mu$; this is random initialization of the network. We can then also think of $\Phi(w)$ as an $\mathcal{F}$-valued random variable, and indeed of $C(\Phi(w))$ as a real-valued one. Writing $P(f) := \mu(\Phi^{-1}(f))$ for the probability that random initialization selects the function $f$, the cleanest way to explain what is meant by 'simplicity bias' is the following statement:
$$C(f_1) < C(f_2) \implies P(f_1) > P(f_2). \qquad (1)$$
That is, if the complexity of $f_1$ is less than the complexity of $f_2$, then the probability of $f_1$ being selected by random initialization is greater than the probability of $f_2$ being selected. (Mingard does say this: "We call this a simplicity bias — because P(f) is higher for simple functions.")
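To make the parameter-to-function map, the initialization measure and $P(f)$ concrete, here is a minimal Python sketch under entirely illustrative assumptions: a hypothetical toy network with 3 Boolean inputs, 4 ReLU hidden units and a thresholded output; i.i.d. standard Gaussian weights and biases standing in for the initialization measure; and the number of bit flips in the truth table standing in for complexity. None of these choices is taken from Mingard et al. The sketch estimates $P(f)$ by Monte Carlo: sample many parameter vectors, map each one to its truth table, and count how often each truth table appears.

```python
import itertools
from collections import Counter

import numpy as np

# Hypothetical toy architecture: 3 Boolean inputs -> 4 ReLU hidden units -> 1 thresholded output.
N_INPUTS, N_HIDDEN = 3, 4
# All 2^3 = 8 Boolean inputs, one per row; a function is identified with its 8-bit truth table.
INPUTS = np.array(list(itertools.product([0, 1], repeat=N_INPUTS)), dtype=float)


def sample_params(rng):
    """Draw one parameter vector w (assumed: i.i.d. standard Gaussian entries)."""
    return (
        rng.standard_normal((N_INPUTS, N_HIDDEN)),  # first-layer weights
        rng.standard_normal(N_HIDDEN),              # first-layer biases
        rng.standard_normal(N_HIDDEN),              # output weights
        rng.standard_normal(),                      # output bias
    )


def phi(w):
    """Parameter-to-function map: the network's truth table as an 8-character bit string."""
    W1, b1, W2, b2 = w
    hidden = np.maximum(0.0, INPUTS @ W1 + b1)  # shape (8, N_HIDDEN)
    out = hidden @ W2 + b2                      # shape (8,)
    return "".join("1" if o > 0 else "0" for o in out)


def complexity(f):
    """Crude stand-in for C(f): number of adjacent bit flips in the truth table."""
    return sum(a != b for a, b in zip(f, f[1:]))


def estimate_p(n_samples=20_000, seed=0):
    """Monte Carlo estimate of P(f), the probability that random initialization selects f."""
    rng = np.random.default_rng(seed)
    counts = Counter(phi(sample_params(rng)) for _ in range(n_samples))
    return {f: c / n_samples for f, c in counts.items()}


if __name__ == "__main__":
    p = estimate_p()
    # Most frequently selected functions, with their proxy complexity.
    for f, prob in sorted(p.items(), key=lambda kv: -kv[1])[:10]:
        print(f, complexity(f), round(prob, 4))
```

Whether the estimated frequencies for this toy actually decrease with the proxy complexity is something to check by running it; the sketch only shows what the objects in (1) are, not that (1) holds.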
I want to refer to statements like (1) as 'pointwise' simplicity bias. The crucial thing to note is that pointwise statements of the form (1) are not the same as claims about the distribution of $C(\Phi(w))$: I cannot easily turn (1) into a statement about $\mathbb{P}(C(\Phi(w)) = c)$, i.e. about the probability that the complexity of the selected function equals $c$ (or indeed into a statement about $\mathbb{P}(C(\Phi(w)) \le c)$ or $\mathbb{E}[C(\Phi(w))]$, etc.). The reason is that $\mathbb{P}(C(\Phi(w)) = c) = \sum_{f \in \mathcal{F} : C(f) = c} \mu(\Phi^{-1}(f))$ depends both on the size of each term and on how many terms there are. Intuitively, although a given, specific, low-complexity function has a higher probability of being selected than a given, specific, high-complexity function, there may well be vastly more high-complexity functions than low-complexity ones. It may then be the case that the probability of ending up with some high-complexity function is still greater than, or at least comparable to, the probability of ending up with some low-complexity function.
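A tiny, entirely made-up example shows how the pointwise statement and the distribution of complexity can come apart. The hypothetical toy below has one function of complexity 0, ten of complexity 1 and a hundred of complexity 2, and each individual function's probability strictly decreases with its complexity, so (1) holds. Exact fractions are used so that the totals are not blurred by rounding.

```python
from fractions import Fraction

# Hypothetical toy: complexity level -> (number of functions at that level, probability of each one).
# Per-function probability strictly decreases with complexity, so the pointwise statement (1) holds.
TOY = {
    0: (1, Fraction(1, 10)),
    1: (10, Fraction(3, 100)),
    2: (100, Fraction(6, 1000)),
}


def complexity_distribution(toy):
    """P(C = c) for each level c: (number of functions at level c) * (probability of each one)."""
    return {c: n * p for c, (n, p) in toy.items()}


dist = complexity_distribution(TOY)
assert sum(dist.values()) == 1  # sanity check: the toy is a genuine probability distribution
for c, mass in sorted(dist.items()):
    print(f"C = {c}: P(single function) = {TOY[c][1]}, P(C = {c}) = {mass}")
```

Here the simplest function is the single most likely one, yet the most complex level carries 3/5 of the total probability, simply because it contains a hundred times as many functions.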
Notice moreover that in many contexts it is not unreasonable to expect many more high-complexity objects than low-complexity objects, simply because of the nature of measures of complexity that are based on, or correlate with, entropy. A classic, straightforward example is to consider binary strings of length $n$: there are relatively few low-entropy strings (those with a very small or a very large number of ones), but there are exponentially many high-entropy strings, e.g. there are $\binom{n}{n/2} \approx 2^n \sqrt{2/(\pi n)}$ strings with $n/2$ ones. So, even under a probability measure in which the probability of selecting a string is bounded above by something that decays quickly with the entropy of the string, it is not clear whether or not you should expect a low-entropy outcome: the answer depends on whether that decay outpaces the exponential growth in the number of high-entropy strings.
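A quick numerical check of the string example, under an assumed measure that is purely illustrative: each string with $k$ ones gets probability proportional to $2^{-\lambda n H(k/n)}$, where $H$ is the binary entropy and $\lambda > 0$ is a hypothetical decay rate. Since $\binom{n}{k} \approx 2^{n H(k/n)}$ up to polynomial factors, the total mass on strings with $k$ ones behaves like $2^{(1-\lambda) n H(k/n)}$, so high-entropy outcomes should dominate when $\lambda < 1$ and low-entropy ones when $\lambda > 1$, even though every individual string's probability decays with its entropy in both cases.

```python
import math


def binary_entropy(q):
    """H(q) in bits, with the convention H(0) = H(1) = 0."""
    if q in (0.0, 1.0):
        return 0.0
    return -q * math.log2(q) - (1 - q) * math.log2(1 - q)


def mass_by_ones(n, lam):
    """Probability of drawing a string with k ones, for k = 0..n.

    Assumed (illustrative) measure: P(s) proportional to 2^(-lam * n * H(k/n)), where k is
    the number of ones in s. All C(n, k) strings with k ones share that individual weight.
    """
    weights = [math.comb(n, k) * 2 ** (-lam * n * binary_entropy(k / n)) for k in range(n + 1)]
    total = sum(weights)
    return [w / total for w in weights]


def low_entropy_mass(n, lam, threshold=0.5):
    """Total probability of strings whose entropy H(k/n) is below `threshold` bits."""
    masses = mass_by_ones(n, lam)
    return sum(m for k, m in enumerate(masses) if binary_entropy(k / n) < threshold)


for lam in (0.5, 1.5):
    print(f"lam = {lam}: P(low-entropy outcome) = {low_entropy_mass(100, lam):.4f}")
```

The entropy threshold of 0.5 bits for calling a string "low-entropy" is an arbitrary choice made for the illustration.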
Remark. If you want to think at this level about whether a deep neural network is (something like) 'biased towards simplicity', then you may be interested in some slightly different questions. If we take inspiration from what we would normally mean by 'bias' in other contexts in statistics and machine learning, then it seems that one would really like to understand things like the expected complexity $\mathbb{E}_\mu[C(\Phi(w))]$ of the selected function and the typical complexity of functions in $\mathcal{F}$ (for finite $\mathcal{F}$, say the average $\frac{1}{|\mathcal{F}|}\sum_{f \in \mathcal{F}} C(f)$). It's not clear to me exactly what relationship between these quantities, or what kind of statement, best represents the informal notion of being 'biased towards simplicity'. (Nor do I think we need to start with a vague phrase that we want to be true and then formalize it; in this case we can at least try asking more precise questions up front.)
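As one concrete (and by no means canonical) version of such a comparison, the sketch below takes a small, made-up set of functions and compares the expected complexity under the initialization measure with the average complexity when every function is weighted equally. Treating the uniform weighting as the "unbiased" baseline is an assumption for illustration only.

```python
from fractions import Fraction

# Hypothetical toy: complexity level -> (number of functions, probability mu assigns to each one).
TOY = {0: (1, Fraction(1, 10)), 1: (10, Fraction(3, 100)), 2: (100, Fraction(6, 1000))}


def expected_complexity_under_mu(toy):
    """E_mu[C]: sum over levels of c * (number of functions at level c) * (probability of each)."""
    return sum(c * n * p for c, (n, p) in toy.items())


def expected_complexity_uniform(toy):
    """Average complexity over all functions, each weighted equally (one possible baseline)."""
    n_total = sum(n for n, _ in toy.values())
    return Fraction(sum(c * n for c, (n, _) in toy.items()), n_total)


print(expected_complexity_under_mu(TOY))  # 3/2
print(expected_complexity_uniform(TOY))   # 70/37
```

In this toy the expected complexity under $\mu$ (3/2) is below the uniform average (70/37, about 1.89), so by this yardstick $\mu$ looks 'biased towards simplicity', even though the most complex level still carries most of the probability. Which of these facts deserves the label is exactly the unclear part.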
Bias Towards Simplicity
Now fix some training data and let $I \subseteq \mathcal{F}$ denote the subset of functions that have zero loss on this data. We call such functions interpolants. So $\Phi^{-1}(I) \subseteq W$ is the region of parameter space that corresponds to functions with zero loss. We can suppose (as Mingard et al. essentially do, though my setup and notation are slightly different) that there is another probability measure $\nu$ on $I$ such that $\nu(f)$ is the probability that $f$ is the function found by running an optimization algorithm (like SGD) until it achieves zero loss on the training data. The main empirical result of Mingard et al. is that for $f \in I$ we have
$$\nu(f) \approx \frac{\mu(\Phi^{-1}(f))}{\mu(\Phi^{-1}(I))}.$$
Somewhat informally, this says that the probability of the optimization algorithm finding a given interpolant is approximately equal to the probability of that interpolant being found by randomly choosing parameters from $\Phi^{-1}(I)$ (that is, sampling according to $\mu$ and conditioning on landing in $\Phi^{-1}(I)$).
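The right-hand side, "randomly choosing parameters from the zero-loss region", can be estimated by rejection sampling: draw parameters from the initialization measure, discard every draw whose function does not fit the training data, and renormalize the counts. The sketch below does this for a hypothetical toy network (3 Boolean inputs, 4 ReLU hidden units, thresholded output, i.i.d. standard Gaussian initialization) and a made-up training set of three labelled inputs; all of these are assumptions for illustration.

```python
import itertools
from collections import Counter

import numpy as np

# Hypothetical toy architecture: 3 Boolean inputs -> 4 ReLU hidden units -> 1 thresholded output.
N_INPUTS, N_HIDDEN = 3, 4
INPUTS = np.array(list(itertools.product([0, 1], repeat=N_INPUTS)), dtype=float)

# Hypothetical training data: row index into INPUTS -> target label.
# Rows 0, 3, 7 are the inputs (0,0,0), (0,1,1), (1,1,1).
TRAIN = {0: "0", 3: "1", 7: "1"}


def sample_function(rng):
    """Draw w from the (assumed Gaussian) initialization measure and return its truth table."""
    W1 = rng.standard_normal((N_INPUTS, N_HIDDEN))
    b1 = rng.standard_normal(N_HIDDEN)
    W2 = rng.standard_normal(N_HIDDEN)
    b2 = rng.standard_normal()
    out = np.maximum(0.0, INPUTS @ W1 + b1) @ W2 + b2
    return "".join("1" if o > 0 else "0" for o in out)


def is_interpolant(f):
    """Zero training loss: f agrees with every training label."""
    return all(f[i] == y for i, y in TRAIN.items())


def conditional_p(n_samples=20_000, seed=0):
    """Rejection sampling: keep only draws whose function is an interpolant, then renormalize.

    For each interpolant f this estimates mu(Phi^{-1}(f)) / mu(Phi^{-1}(I)).
    """
    rng = np.random.default_rng(seed)
    kept = [f for f in (sample_function(rng) for _ in range(n_samples)) if is_interpolant(f)]
    if not kept:
        raise ValueError("no sampled function fits the training data; increase n_samples")
    counts = Counter(kept)
    return {f: c / len(kept) for f, c in counts.items()}


if __name__ == "__main__":
    for f, prob in sorted(conditional_p().items(), key=lambda kv: -kv[1])[:10]:
        print(f, round(prob, 4))
```

Comparing this against the optimizer's side would additionally require training the toy network many times from random initializations and tabulating the interpolants it reaches; that part is not sketched here.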
In addition, there is an idea that deep neural networks trained on supervised learning tasks with real data tend to find simpler, rather than more complex, interpolants for the training data. Note that this is supposed to be a different statement from the kind of 'pointwise' simplicity bias discussed above: here I am really saying that we expect to see simpler learned interpolants.
With this in mind, one can produce an argument for the existence of a bias towards simplicity which relies on this last point. Here is an outline:
- Assumption 1: The optimization algorithm (such as SGD) that finds an interpolant is biased towards simplicity.
- Assumption 2: The probability of the optimization algorithm finding any given interpolant is approximately equal to the probability of that interpolant being found by randomly choosing parameters from $\Phi^{-1}(I)$.
- Conclusion: Therefore the process of randomly choosing parameters from $\Phi^{-1}(I)$ is biased towards simpler functions.
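The step from the two assumptions to the conclusion can be spelled out for a finite (or countable) set of functions. Write $\nu$ for the optimizer's distribution over interpolants, $\mu$ for the initialization measure, $\Phi$ for the parameter-to-function map and $I$ for the set of interpolants. Summing Assumption 2 over all interpolants of complexity at most $c$ gives the following; this is a sketch that assumes the pointwise approximation is good enough to survive summation over many functions, which an approximate equality for each individual $f$ does not by itself guarantee.

$$
\begin{aligned}
\mathbb{P}_\nu\big(C(f) \le c\big)
&= \sum_{f \in I,\; C(f) \le c} \nu(f) \\
&\approx \sum_{f \in I,\; C(f) \le c} \frac{\mu(\Phi^{-1}(f))}{\mu(\Phi^{-1}(I))} \\
&= \mathbb{P}_\mu\big(C(\Phi(w)) \le c \,\big|\, \Phi(w) \in I\big).
\end{aligned}
$$

So if Assumption 1 is read distributionally, as saying that the left-hand side is large for small $c$, then the same holds for randomly chosen parameters conditioned on landing in $\Phi^{-1}(I)$, which is the form the conclusion takes.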
What I have called Assumption 2 is the claim for which Mingard et al. provided a lot of empirical evidence. If the conclusion were true, then, more so than any pointwise simplicity-bias statement, it would suggest to me that deep neural network architectures harbour a bias towards simplicity. Looking at the argument as a whole, one way to phrase a correct conclusion is as follows: to the extent that you believe that deep neural networks tend to find simpler zero-loss functions in practice (when more complex ones would also work), you should attribute that effect to some intrinsic quality of the architecture and parameter space, rather than to something about the optimization algorithm.
Remark. What of the remaining tricky question about whether or not Assumption 1 is really true, i.e. to what extent do we believe that deep neural networks tend to find simpler functions in practice? Currently I don't have anything useful to add.