Complexity Penalties in Statistical Learning

by michael_h, 6th Feb 2019


I am currently taking a course on statistical learning at the Australian Mathematical Sciences Institute (AMSI) Summer School. One idea that has come up many times in the course is that a more complicated model is likely to have many shortcomings. This is because complicated models tend to overfit the observed data: they often give explanatory value to parts of the observations that are simply random noise.

This is common knowledge for many aspiring rationalists. The term "complexity penalty" describes the practice of putting less credence in an explanation because it is more complex. In this blog post I aim to give a brief introduction to statistical learning and use an example to show how complexity penalties arise in this setting.

Statistical Learning

Broadly speaking, statistical learning is the process of using data to select a model and then using the model to make predictions about future data. So, in order to perform statistical learning, we need at least three things: some data, a class of models, and a way of measuring how well a model predicts future data. In this post we will look at the problem of polynomial regression.

The Data

For polynomial regression, our data is in the form of pairs of real numbers $(x_1, y_1), (x_2, y_2), \ldots, (x_n, y_n)$. Our goal is to find the relationship between the input values $x_i$ and the output values $y_i$ and then use this to predict future outputs given new inputs. For example, the input values could represent the average temperature on a particular day and the corresponding output value could be the number of ice creams sold that day. Going with this example, we can suppose our data looks something like this:

To simplify our analysis we will make some assumptions about the relationship between the inputs and outputs. We will assume that there exists an unknown function $f$ such that $y_i = f(x_i) + \varepsilon_i$, where $\varepsilon_i$ is a statistical error term with mean equal to 0 and variance equal to $\sigma^2$. This assumption is essentially saying that there is some true relationship between our input and our output, but that the output can fluctuate around this true value. Furthermore, we are assuming that these fluctuations are balanced in the positive and negative direction (since the mean of $\varepsilon_i$ is zero) and that the size of these fluctuations doesn't depend on the input (since the variance of $\varepsilon_i$ is constant).
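To make this assumption concrete, here is a minimal Python/NumPy sketch that simulates data of this form. The true function `true_f`, the input range and the choice of Gaussian noise are all assumptions made for illustration; the post itself only assumes that the errors have mean 0 and constant variance $\sigma^2$.

```python
import numpy as np

def true_f(x):
    # Hypothetical "true" relationship f; the post never specifies one.
    return 1.0 + 0.5 * x - 0.02 * x**2

def simulate_data(n, sigma, rng):
    """Draw n pairs (x_i, y_i) with y_i = f(x_i) + eps_i.

    Assumption: eps_i ~ Normal(0, sigma^2), drawn independently. The post
    only requires mean 0 and constant variance sigma^2.
    """
    x = rng.uniform(10.0, 35.0, size=n)    # hypothetical daily temperatures
    eps = rng.normal(0.0, sigma, size=n)   # same spread for every input
    return x, true_f(x) + eps

rng = np.random.default_rng(0)
x, y = simulate_data(n=50, sigma=1.0, rng=rng)
print(x.shape, y.shape)  # (50,) (50,)
```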

The Models

We want models that can take in a new number $x$ and predict what the corresponding $y$ should be. Thus our models will be functions that take in real numbers and return real numbers. Since we are doing polynomial regression, the classes of models we will be using will be different sets of polynomial functions. More specifically, let $\mathcal{P}_p$ be the set of polynomials of degree at most $p$. That is, $\mathcal{P}_p$ contains all functions of the form

$$g(x) = \beta_0 + \beta_1 x + \beta_2 x^2 + \cdots + \beta_p x^p, \qquad \beta_0, \ldots, \beta_p \in \mathbb{R}.$$

The parameter $p$ corresponds to the complexity of the class of models we are using. As we increase $p$, we are considering more and more complicated possible models (note that $\mathcal{P}_p \subseteq \mathcal{P}_{p+1}$).
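A model in this class is fully described by its $p + 1$ coefficients. The sketch below uses a hypothetical helper `eval_poly` to evaluate such a polynomial with Horner's rule; padding the coefficient list with a zero shows why a lower-degree polynomial also belongs to every larger class.

```python
from typing import Sequence

def eval_poly(beta: Sequence[float], x: float) -> float:
    """Evaluate g(x) = beta[0] + beta[1]*x + ... + beta[p]*x**p with Horner's rule."""
    result = 0.0
    for b in reversed(beta):
        result = result * x + b
    return result

# g(x) = 1 + 2x + 3x^2 is in P_2 (p + 1 = 3 coefficients).
print(eval_poly([1.0, 2.0, 3.0], 2.0))       # 17.0

# The same polynomial written with a zero x^3 coefficient is also in P_3.
print(eval_poly([1.0, 2.0, 3.0, 0.0], 2.0))  # 17.0
```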

Evaluating the models

We now have our data and our class of models; the remaining ingredient is a way to measure the performance of a particular model. Recall that our goal is to find a model that can take in a new number $x$ and predict which $y$ should be associated with it. Thus if we have a second set of data $(x_1', y_1'), \ldots, (x_m', y_m')$, one way to measure the performance is to look at the average distance between our guess $g(x_i')$ and the actual value $y_i'$. That is, the best model is the one which minimizes

$$\frac{1}{m}\sum_{i=1}^{m} \left| g(x_i') - y_i' \right|.$$

It turns out that looking at the average squared distance between our guesses and the actual values gives a better way to measure performance. By taking squares we are more forgiving when the model gets the answer almost right but much less forgiving when the model is way off. Taking squares also makes the mathematics more tractable. The best model now becomes the one which minimizes

$$\ell_{\text{test}}(g) = \frac{1}{m}\sum_{i=1}^{m} \left( g(x_i') - y_i' \right)^2.$$

The above average, $\ell_{\text{test}}(g)$, is called the test loss of the model $g$. From our assumptions about the type of data we're modeling, we know that even the perfect function $f$ will occasionally differ from the output we're given. Thus, most of the time, we won't be able to make the test loss much smaller than $\sigma^2$, which is the expected test loss of $f$ (since $\mathbb{E}[(f(x) - y)^2] = \mathbb{E}[\varepsilon^2] = \sigma^2$).
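A quick numerical check of this last point, as a sketch under stated assumptions: a hypothetical quadratic `true_f`, uniformly distributed inputs and Gaussian noise with $\sigma = 2$. Even the true function has an average squared error close to $\sigma^2 = 4$ on a large fresh sample.

```python
import numpy as np

def squared_error_loss(g, x, y):
    """Average squared distance between predictions g(x_i) and observed y_i."""
    return float(np.mean((g(x) - y) ** 2))

def true_f(x):
    # Hypothetical true function (an assumption for illustration).
    return 1.0 + 0.5 * x - 0.02 * x**2

rng = np.random.default_rng(1)
sigma = 2.0
x_new = rng.uniform(10.0, 35.0, size=100_000)
y_new = true_f(x_new) + rng.normal(0.0, sigma, size=x_new.size)

# The noise alone keeps the true f's test loss near sigma^2 = 4.
test_loss_of_f = squared_error_loss(true_f, x_new, y_new)
print(test_loss_of_f)
```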

Using the test loss to measure performance has one clear limitation: it requires a second batch of data to test our models. What do we do if we only have one batch? One solution is to divide our batch in two and set some data aside to test models. Another solution is to try to estimate the test loss. It turns out that complexity penalties arise naturally when exploring this second solution.
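The first solution, holding out part of the data, can be sketched as follows. `holdout_split` is a hypothetical helper, and the 30% test fraction is an arbitrary choice for illustration.

```python
import numpy as np

def holdout_split(x, y, test_fraction, rng):
    """Randomly set aside a fraction of the (x_i, y_i) pairs for testing."""
    n = len(x)
    perm = rng.permutation(n)
    n_test = int(round(test_fraction * n))
    test_idx, train_idx = perm[:n_test], perm[n_test:]
    return x[train_idx], y[train_idx], x[test_idx], y[test_idx]

rng = np.random.default_rng(2)
x = np.arange(10.0)
y = 2.0 * x
x_tr, y_tr, x_te, y_te = holdout_split(x, y, test_fraction=0.3, rng=rng)
print(len(x_tr), len(x_te))  # 7 3
```

Models are then fitted on the first part only, and the held-out part plays the role of the second batch when computing the test loss.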

Training loss

One way to try to estimate the test loss is to look at how well our model matches the data we've seen so far. This gives rise to the training loss, which is defined as

$$\ell_{\text{train}}(g) = \frac{1}{n}\sum_{i=1}^{n} \left( g(x_i) - y_i \right)^2.$$

Note that for the training loss we're using the original data points to test the performance of our model. This makes the training loss easy to calculate and easy to minimize within the class $\mathcal{P}_p$ (the set of all polynomials of degree at most $p$). Here is a plot of the polynomials of a few fixed degrees that minimize the training loss. The purple polynomial has degree 1, the green polynomial has degree 2 and the black polynomial has degree 15.
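Minimizing the training loss within $\mathcal{P}_p$ is an ordinary least-squares problem. Here is a sketch using NumPy's `Polynomial.fit` (which rescales the inputs internally for numerical stability) on hypothetical simulated data. Because each class contains the previous one, the smallest achievable training loss can only go down as $p$ grows.

```python
import numpy as np
from numpy.polynomial import Polynomial

def fit_polynomial(x, y, p):
    """Least-squares fit in P_p: the degree-<=p polynomial minimising the training loss."""
    return Polynomial.fit(x, y, deg=p)

def training_loss(g, x, y):
    """Average squared error of g on the data it was fitted to."""
    return float(np.mean((g(x) - y) ** 2))

rng = np.random.default_rng(3)
x = rng.uniform(10.0, 35.0, size=30)
# Hypothetical data: quadratic trend plus Gaussian noise.
y = 1.0 + 0.5 * x - 0.02 * x**2 + rng.normal(0.0, 1.0, size=30)

train = {p: training_loss(fit_polynomial(x, y, p), x, y) for p in range(0, 11)}
for p, loss in train.items():
    print(p, round(loss, 3))
```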

Since the training loss only uses the old data, it doesn't tell us much about how the model will perform on new data. For example, while the degree-15 polynomial matches the above initial data very well, it is overfitting: it does a poor job of matching some new independent data, as shown below.

In general, we'd expect the training loss to be smaller than the test loss, because the model has already been calibrated to the original data. Indeed, if we used polynomials of degree $n - 1$, we would be able to find a model that passes through every data point $(x_i, y_i)$ (provided the $x_i$ are distinct). Such a model would have a training loss of 0 but wouldn't generalize well to new data and would have a high test loss.
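This extreme case is easy to reproduce. The sketch below, on hypothetical noisy data with $f(x) = x^2$ and $\sigma = 0.1$, solves the square Vandermonde system for the degree-$(n-1)$ polynomial through all $n$ points. Its training loss is zero up to round-off, while its loss on fresh data from the same process is not.

```python
import numpy as np

rng = np.random.default_rng(4)
n = 8
x = np.linspace(-1.0, 1.0, n)              # n distinct inputs
y = x**2 + rng.normal(0.0, 0.1, size=n)    # hypothetical data: f(x) = x^2, sigma = 0.1

# Degree n - 1 interpolant: solve the square Vandermonde system V beta = y.
V = np.vander(x, N=n, increasing=True)     # columns 1, x, x^2, ..., x^(n-1)
beta = np.linalg.solve(V, y)
train_loss = float(np.mean((V @ beta - y) ** 2))

# Fresh data from the same process: the fitted noise does not recur.
x_new = rng.uniform(-1.0, 1.0, size=1000)
y_new = x_new**2 + rng.normal(0.0, 0.1, size=1000)
V_new = np.vander(x_new, N=n, increasing=True)
test_loss = float(np.mean((V_new @ beta - y_new) ** 2))

print(train_loss, test_loss)
```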

Approximating the test loss

Thus the training loss on its own is not a very informative way of estimating the test loss. However, the training loss is salvageable: we just need to add an extra term that makes up for how optimistic the training loss is. Note that we can write the test loss as

$$\ell_{\text{test}}(g) = \ell_{\text{train}}(g) + \big( \ell_{\text{test}}(g) - \ell_{\text{train}}(g) \big).$$

The difference between the test loss and the training loss of a model therefore quantifies how much the model is overfitting the training data. If we can estimate this difference, we can add it to the training loss to get an estimate of the test loss and so evaluate the performance of our model!

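It turns out that in our particular case estimating this difference isn't too tricky. Suppose that $g$ is the model in $\mathcal{P}_p$ that minimizes the training loss (so $g$ is a polynomial of degree at most $p$). Suppose further that we have a lot of data points (in particular, assume that $n$, the number of data points, is much larger than $p$). Then, under the above assumptions, we have the following approximation, which holds on average over the random data:

$$\ell_{\text{test}}(g) - \ell_{\text{train}}(g) \approx \frac{2\sigma^2 (p + 1)}{n}.$$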

Rewriting this, we have

$$\ell_{\text{test}}(g) \approx \ell_{\text{train}}(g) + \frac{2\sigma^2 (p + 1)}{n}.$$

Thus we can measure the performance of our models by calculating the training loss and then adding $2\sigma^2 (p + 1)/n$. This number is our complexity penalty, as it increases as the complexity parameter $p$ increases. It also increases with $\sigma^2$, the variance of the errors: the noisier our data, the more likely we are to overfit. Finally, the penalty decreases with $n$, the number of data points. This suggests that if we have enough data we can get away with using quite complicated models without worrying about overfitting: with enough data, the true relationship between the inputs and outputs becomes very clear, and even a fairly complicated model might not overfit it.
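The approximation can be checked by simulation. The following sketch assumes a hypothetical quadratic `true_f`, uniformly distributed inputs and Gaussian noise. It repeatedly draws a data set, fits the least-squares polynomial in $\mathcal{P}_p$, and averages the gap between the loss on a fresh sample and the training loss. For moderately large $n$ the average should land close to $2\sigma^2(p+1)/n$.

```python
import numpy as np
from numpy.polynomial import Polynomial

def true_f(x):
    # Hypothetical true function: a quadratic, so it lies in P_p whenever p >= 2.
    return 0.5 + x - 2.0 * x**2

def penalty(p, n, sigma):
    """Complexity penalty 2 * sigma^2 * (p + 1) / n for polynomials of degree at most p."""
    return 2.0 * sigma**2 * (p + 1) / n

def average_gap(p, n, sigma, reps, rng, n_fresh=2000):
    """Monte Carlo average of (loss on fresh data) - (training loss) for the least-squares fit in P_p."""
    gaps = np.empty(reps)
    for r in range(reps):
        x = rng.uniform(-1.0, 1.0, size=n)
        y = true_f(x) + rng.normal(0.0, sigma, size=n)
        g = Polynomial.fit(x, y, deg=p)                # minimises the training loss in P_p
        train = np.mean((g(x) - y) ** 2)
        x_new = rng.uniform(-1.0, 1.0, size=n_fresh)  # fresh, independent data
        y_new = true_f(x_new) + rng.normal(0.0, sigma, size=n_fresh)
        test = np.mean((g(x_new) - y_new) ** 2)
        gaps[r] = test - train
    return float(gaps.mean())

rng = np.random.default_rng(5)
for p in (2, 5):
    print(p, average_gap(p, n=200, sigma=1.0, reps=2000, rng=rng), penalty(p, 200, 1.0))
```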

One last interesting observation about this complexity penalty is the way it depends on a given model. Recall that a model $g \in \mathcal{P}_p$ is a polynomial of degree at most $p$, and that $\sigma^2$ and $n$ are parameters of the whole statistical learning problem. Thus the above complexity penalty depends on $g$ only via $p$. This gives us the following tractable way of finding the best model. For each $p$, we find the polynomial of degree at most $p$ that minimizes the training loss and record the training loss it achieves. We then compare polynomials of different degrees by adding the complexity penalty to the training loss, and choose the model that minimizes the sum of the training loss and the complexity penalty. The only downside to this method is that $\sigma^2$ is an unknown quantity, but hopefully some heuristics can be used to estimate it.
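Here is a sketch of that procedure. The noise variance is estimated with one common heuristic, which is an assumption rather than something from the post: fit a deliberately generous degree $q$ and divide the residual sum of squares by $n - (q + 1)$. The data, the choice $q = 8$ and the helper names are hypothetical. Which degree gets picked depends on the particular noisy sample, so it need not be exactly 2 every time.

```python
import numpy as np
from numpy.polynomial import Polynomial

def select_degree(x, y, max_degree, sigma2):
    """Return (best p, scores) where score(p) = training loss + 2*sigma2*(p+1)/n."""
    n = len(x)
    scores = {}
    for p in range(max_degree + 1):
        g = Polynomial.fit(x, y, deg=p)             # training-loss minimiser in P_p
        train = float(np.mean((g(x) - y) ** 2))
        scores[p] = train + 2.0 * sigma2 * (p + 1) / n
    best = min(scores, key=scores.get)
    return best, scores

def estimate_sigma2(x, y, q):
    """Heuristic noise estimate: residual sum of squares of a degree-q fit over n - (q + 1)."""
    g = Polynomial.fit(x, y, deg=q)
    rss = float(np.sum((g(x) - y) ** 2))
    return rss / (len(x) - (q + 1))

rng = np.random.default_rng(6)
x = rng.uniform(-1.0, 1.0, size=100)
y = 0.5 + x - 2.0 * x**2 + rng.normal(0.0, 0.3, size=100)   # hypothetical quadratic data

sigma2_hat = estimate_sigma2(x, y, q=8)
best_p, scores = select_degree(x, y, max_degree=10, sigma2=sigma2_hat)
print(sigma2_hat, best_p)
```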

Below is a plot of the training loss, the test loss and the approximation of the test loss from our example for different values of $p$. While the approximation isn't always exact, it follows the general trend of the test loss. Most importantly, both the test loss and the approximation have a minimum at $p = 2$. This shows that using the approximation would let us select the best model, which in this case is a quadratic.

Other examples of complexity penalties

Complexity penalties can be found all over statistical learning. In other problems the above estimate can be harder to calculate, so complexity penalties are used in a more heuristic manner. This gives rise to techniques such as ridge regression, LASSO regression and kernel methods. Model complexity is also an important factor when training neural networks: the number of layers and the size of each layer are both complexity parameters and must be tuned to avoid overfitting.
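As one example of the heuristic route, ridge regression adds a penalty $\lambda \lVert \beta \rVert^2$ on the coefficient vector instead of a penalty on the degree. The minimal sketch below works on hypothetical data and uses the closed-form solution $\beta = (X^\top X + \lambda I)^{-1} X^\top y$. For simplicity it penalizes the intercept too, which practical implementations usually avoid. Larger $\lambda$ shrinks the coefficients, trading a higher training loss for a simpler fit.

```python
import numpy as np

def ridge_polynomial(x, y, degree, lam):
    """Minimise ||X b - y||^2 + lam * ||b||^2, where X has columns 1, x, ..., x^degree."""
    X = np.vander(x, N=degree + 1, increasing=True)
    A = X.T @ X + lam * np.eye(degree + 1)
    # Solve (X^T X + lam I) b = X^T y directly rather than forming an inverse.
    return np.linalg.solve(A, X.T @ y)

rng = np.random.default_rng(7)
x = rng.uniform(-1.0, 1.0, size=40)
y = 0.5 + x - 2.0 * x**2 + rng.normal(0.0, 0.3, size=40)   # hypothetical data

norms = []
for lam in (0.0, 0.1, 1.0, 10.0):
    beta = ridge_polynomial(x, y, degree=8, lam=lam)
    norms.append(float(np.linalg.norm(beta)))
    print(lam, norms[-1])
```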

What makes the above example interesting is that the complexity penalty arose naturally out of trying to measure the performance of our model. It wasn't a heuristic but rather an approximation derived from our modelling assumptions, which gives a good estimate of the test loss when there is plenty of data. This in turn gives support to the heuristic complexity penalties used in situations where such formulas are more difficult to come by.


The ideas in this blog post are not my own and come from the AMSI Summer School course Mathematical Methods for Machine Learning taught by Zdravko Botev. The notes for the course will soon be published as a book by D. P. Kroese, Z. I. Botev, S. Vaisman and T. Taimre titled "Mathematical and Statistical Methods for Data Science and Machine Learning".

I made the plots myself, but they are based on similar plots from the course. You can access the data set I made and used to make the plots here.