Laplace Approximation

by johnswentworth, 18th Jul 2019



The last couple posts compared some specific models for 20000 rolls of a die. This post will step back, and talk about more general theory for Bayesian model comparison.

The main problem is to calculate $P[\text{data}|\text{model}]$ for some model. The model will typically give the probability of observed data (e.g. die rolls) based on some unobserved parameter values $\theta$ (e.g. the die's face probabilities in the last two posts), along with a prior distribution over $\theta$. We then need to compute

$$P[\text{data}|\text{model}] = \int_\theta P[\text{data}|\theta] \, dP[\theta]$$

which will be a hairy high-dimensional integral.
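To make the integral concrete, here is a minimal sketch (assuming NumPy) for the simplest possible case: a coin model with one parameter $\theta$ = P(heads) and a uniform prior. In one dimension we can still brute-force the integral on a grid, and for this prior there's a closed form to check against. The function names and the 6000-heads/14000-tails counts are hypothetical, chosen only for illustration.

```python
# Illustrative sketch; function names and counts are hypothetical.
import math

import numpy as np


def log_evidence_grid(heads, tails, grid_size=200_001):
    """ln P[data|model] for a specific sequence of coin flips, uniform prior on theta.

    Brute-force midpoint-rule integration over theta in (0, 1).
    """
    theta = (np.arange(grid_size) + 0.5) / grid_size  # cell midpoints
    log_lik = heads * np.log(theta) + tails * np.log1p(-theta)
    # The raw likelihoods underflow to 0.0 in float64, so integrate in log space
    # (log-sum-exp): factor out the largest term before exponentiating.
    m = log_lik.max()
    return m + math.log(np.exp(log_lik - m).sum() / grid_size)


def log_evidence_exact(heads, tails):
    """Closed form for the uniform prior: heads! * tails! / (heads + tails + 1)!"""
    return (math.lgamma(heads + 1) + math.lgamma(tails + 1)
            - math.lgamma(heads + tails + 2))


print(log_evidence_grid(6000, 14000))
print(log_evidence_exact(6000, 14000))
```

The grid approach only works because there is one parameter: at the same resolution in $d$ dimensions it would need $200001^d$ evaluations of the likelihood.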

Some special model structures allow us to simplify the problem, typically by factoring the integral into a product of one-dimensional integrals. But in general, we need some method for approximating these integrals.

The two most common approximation methods used in practice are Laplace approximation around the maximum-likelihood point, and MCMC (see e.g. here for application of MCMC to Bayes factors). We'll mainly talk about Laplace approximation here - in practice MCMC mostly works well in the same cases, assuming the unobserved parameters are continuous.

Laplace Approximation

Here's the idea of Laplace approximation. First, posterior distributions tend to be very pointy. This is mainly because independent probabilities multiply, so probabilities tend to scale exponentially with the number of data points. Think of the probabilities we calculated in the last two posts - those are the typical case. For a sense of scale, a fair-die model assigns any particular sequence of 20000 rolls probability $6^{-20000} \approx 10^{-15563}$. If we're integrating over a function with values like that, we can basically just pay attention to the region around the highest value - other regions will have exponentially small weight.
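A quick numerical look at the pointiness, sketched with a hypothetical coin model (names are illustrative): the log-likelihood gap between the peak and a fixed nearby point grows linearly with the number of data points, so the likelihood ratio grows exponentially.

```python
# Illustrative sketch; function names and the 30%-heads data are hypothetical.
import math


def log_lik_coin(theta, heads, tails):
    """ln P[data|theta] for a specific sequence of coin flips."""
    return heads * math.log(theta) + tails * math.log1p(-theta)


def log_gap(n, frac_heads=0.3, offset=0.05):
    """Nats by which the peak theta* = frac_heads beats theta* + offset."""
    heads = frac_heads * n
    tails = n - heads
    return (log_lik_coin(frac_heads, heads, tails)
            - log_lik_coin(frac_heads + offset, heads, tails))


for n in (200, 2_000, 20_000):
    print(n, round(log_gap(n), 2))
```

With 30% heads, moving $\theta$ from 0.3 to 0.35 costs over 100 nats at 20000 flips, i.e. a likelihood ratio beyond $e^{100}$; ten times the data gives ten times the gap.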

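Laplace's trick is to use a second-order approximation within that high-valued region. Specifically, since probabilities naturally live on a log scale, we'll take a second-order approximation of the log likelihood around its maximum point $\theta^*$. The first-order term vanishes because the gradient is zero at a maximum, so:

$$\ln P[\text{data}|\theta] \approx \ln P[\text{data}|\theta^*] + \frac{1}{2}(\theta - \theta^*)^T \left(\frac{\partial^2 \ln P[\text{data}|\theta]}{\partial \theta^2}\bigg|_{\theta^*}\right)(\theta - \theta^*)$$

$$P[\text{data}|\text{model}] \approx P[\text{data}|\theta^*] \int_\theta \exp\left(\frac{1}{2}(\theta - \theta^*)^T \left(\frac{\partial^2 \ln P[\text{data}|\theta]}{\partial \theta^2}\bigg|_{\theta^*}\right)(\theta - \theta^*)\right) dP[\theta]$$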
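Here is a sketch checking the second-order expansion on a hypothetical coin model, where the second derivative at the maximum has the closed form $-n/(\theta^*(1-\theta^*))$. Near the peak the quadratic tracks the exact log likelihood closely; the error grows roughly like $(\theta - \theta^*)^3$ further out, but by then the likelihood is already small.

```python
# Illustrative sketch; names and counts are hypothetical.
import math


def log_lik_coin(theta, heads, tails):
    """ln P[data|theta] for a specific sequence of coin flips."""
    return heads * math.log(theta) + tails * math.log1p(-theta)


heads, tails = 6000, 14000
n = heads + tails
theta_star = heads / n                              # maximum-likelihood point
d2 = -n / (theta_star * (1 - theta_star))           # second derivative at theta_star
sd = math.sqrt(-1 / d2)                             # width of the approximate normal


def quadratic_approx(theta):
    """Second-order expansion around theta_star (the first-order term is zero)."""
    return log_lik_coin(theta_star, heads, tails) + 0.5 * d2 * (theta - theta_star) ** 2


for steps in (0.5, 1.0, 3.0):
    theta = theta_star + steps * sd
    exact = log_lik_coin(theta, heads, tails)
    print(f"{steps} sd: exact - quadratic = {exact - quadratic_approx(theta):.5f}")
```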

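If we assume that the prior is uniform (i.e. its density $p[\theta]$ is constant, so $dP[\theta] = p[\theta^*]\,d\theta$), then the integrand looks like a normal distribution on $\theta$ with mean $\theta^*$ and covariance given by the inverse of the negated Hessian matrix of the log likelihood. (It turns out that, even for non-uniform $dP[\theta]$, we can just transform $\theta$ so that the prior looks uniform near $\theta^*$, and transform it back when we're done.) The result:

$$P[\text{data}|\text{model}] \approx P[\text{data}|\theta^*] \, p[\theta^*] \, (2\pi)^{k/2} \det\left(-\frac{\partial^2 \ln P[\text{data}|\theta]}{\partial \theta^2}\bigg|_{\theta^*}\right)^{-1/2}$$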
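A minimal sketch of the full formula for a hypothetical one-parameter coin model with a uniform prior on $[0, 1]$ (so $p[\theta^*] = 1$ and $k = 1$), compared against the exact answer $h!\,t!/(h+t+1)!$. With thousands of flips the two agree to within about $10^{-4}$ in log space.

```python
# Illustrative sketch; function names and counts are hypothetical.
import math


def laplace_log_evidence_coin(heads, tails):
    """Laplace approximation to ln P[data|model] for the coin model, uniform prior on [0, 1]."""
    n = heads + tails
    theta_star = heads / n                                   # maximum-likelihood point
    log_max_lik = heads * math.log(theta_star) + tails * math.log1p(-theta_star)
    log_prior_density = 0.0                                  # p[theta*] = 1 on [0, 1]
    k = 1
    fisher = n / (theta_star * (1 - theta_star))             # -d^2 ln P / d theta^2 at theta*
    return (log_max_lik + log_prior_density
            + 0.5 * k * math.log(2 * math.pi) - 0.5 * math.log(fisher))


def exact_log_evidence_coin(heads, tails):
    """ln(heads! * tails! / (heads + tails + 1)!) for the uniform prior."""
    return (math.lgamma(heads + 1) + math.lgamma(tails + 1)
            - math.lgamma(heads + tails + 2))


print(laplace_log_evidence_coin(6000, 14000) - exact_log_evidence_coin(6000, 14000))
print(laplace_log_evidence_coin(3, 2) - exact_log_evidence_coin(3, 2))
```

With only 3 heads and 2 tails the peak is not very pointy and the error grows to about 0.13 nats.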

Let's walk through each of those pieces:

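  • $P[\text{data}|\theta^*]$ is the usual maximum likelihood: the largest probability assigned to the data by any particular value of $\theta$.
  • $p[\theta^*]$ is the prior probability density of the maximum-likelihood point.
  • $(2\pi)^{k/2}$ is that annoying constant factor which shows up whenever we deal with normal distributions; $k$ is the dimension of $\theta$.
  • $\det\left(-\frac{\partial^2 \ln P[\text{data}|\theta]}{\partial \theta^2}\big|_{\theta^*}\right)^{-1/2}$ is the inverse square root of the determinant of the "Fisher information matrix" (strictly, the observed Fisher information at $\theta^*$); it quantifies how wide or skinny the peak is.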
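To see the four pieces separately, here is a sketch for a six-sided die with a uniform (Dirichlet(1,...,1)) prior over the face probabilities, parameterized by the first five probabilities ($k = 5$; the sixth is one minus their sum). The rolls are simulated from a fair die with a fixed seed, not the data from the earlier posts, and the helper names are hypothetical. For this prior the exact evidence is $5!\prod_i n_i!/(N+5)!$, so we can check the sum of the pieces.

```python
# Illustrative sketch; helper names are hypothetical and the rolls are simulated.
import math

import numpy as np


def laplace_pieces_die(counts):
    """The four factors of the Laplace formula (in log space) for a six-sided die.

    Parameters theta = (p_1, ..., p_5); p_6 = 1 - sum(theta). Uniform prior on the simplex.
    """
    counts = np.asarray(counts, dtype=float)
    n = counts.sum()
    p_hat = counts / n                       # maximum-likelihood face probabilities
    k = len(counts) - 1                      # dimension of theta
    log_max_lik = float(np.sum(counts * np.log(p_hat)))
    log_prior = math.lgamma(len(counts))     # uniform density on the simplex is (6 - 1)! = 120
    log_const = 0.5 * k * math.log(2 * math.pi)
    # Negated Hessian of ln P in (p_1..p_5): diag(n_i / p_i^2) plus n_6 / p_6^2 in every entry.
    fisher = np.diag(counts[:k] / p_hat[:k] ** 2) + counts[k] / p_hat[k] ** 2
    sign, logdet = np.linalg.slogdet(fisher)
    assert sign > 0, "peak should be a maximum"
    log_width = -0.5 * logdet
    return log_max_lik, log_prior, log_const, log_width


def exact_log_evidence_die(counts):
    """ln[(m-1)! * prod(n_i!) / (N + m - 1)!] for a uniform Dirichlet prior over m faces."""
    m, n = len(counts), sum(counts)
    return math.lgamma(m) + sum(math.lgamma(c + 1) for c in counts) - math.lgamma(n + m)


rng = np.random.default_rng(0)
rolls = rng.integers(1, 7, size=20_000)          # simulated fair-die rolls
counts = np.bincount(rolls, minlength=7)[1:].tolist()

pieces = laplace_pieces_die(counts)
print(pieces)
print(sum(pieces), exact_log_evidence_die(counts))
```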

A bit more detail on that last piece: intuitively, each eigenvalue of the Fisher information matrix tells us the approximate width of the peak in a particular direction. Since the matrix is the inverse variance (in one dimension, $-\frac{d^2 \ln P[\text{data}|\theta]}{d\theta^2}\big|_{\theta^*} = \frac{1}{\sigma^2}$) of our approximate normal distribution, and the "width" of the peak of a normal distribution corresponds to the standard deviation $\sigma$, we use an inverse square root (i.e. the power of $-\frac{1}{2}$) to extract a width from each eigenvalue. Then, to find how much volume the peak covers, we multiply together the widths along each direction - thus the determinant, which is the product of eigenvalues.
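A small sketch of the eigenvalue-to-width step, using a hypothetical $2 \times 2$ Fisher information matrix: each eigenvalue $\lambda$ gives a width $\lambda^{-1/2}$, and the product of those widths equals $\det^{-1/2}$.

```python
# Illustrative sketch; the matrix is hypothetical.
import numpy as np

# Hypothetical 2x2 Fisher information (negated Hessian of ln P) at the peak.
fisher = np.array([[50.0, 30.0],
                   [30.0, 40.0]])

eigvals, eigvecs = np.linalg.eigh(fisher)  # symmetric matrix -> real eigenvalues
widths = eigvals ** -0.5                   # std deviation along each principal axis
print("widths:", widths)
print("product of widths:", np.prod(widths))
print("det^(-1/2):       ", np.linalg.det(fisher) ** -0.5)

# One-dimensional sanity check: Fisher information 1/sigma^2 gives width sigma.
sigma = 0.2
print((1 / sigma**2) ** -0.5)
```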

Why do we need eigenvalues? The diagram above shows the general idea: for the function shown, the two arrows would be eigenvectors of the Hessian at the peak. Under a second-order approximation, these are the principal axes of the function's level sets (the ellipses in the diagram), so they are the natural directions along which to measure the width. The eigenvalue associated with each eigenvector tells us the width along that axis (via its inverse square root), and taking the product of those widths (via the determinant) gives a volume. In the picture above, the inverse square root of the determinant would be proportional to the area of any of the ellipses.
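A numerical version of the same picture, sketched with a hypothetical $2 \times 2$ negated Hessian $A$: stepping $\lambda_i^{-1/2}$ along eigenvector $i$ lands exactly on the level set $x^T A x = 1$, and a Monte Carlo estimate of that ellipse's area matches $\pi \det(A)^{-1/2}$.

```python
# Illustrative sketch; the matrix and sampling box are hypothetical choices.
import numpy as np

# Hypothetical 2x2 negated Hessian at the peak; level sets of the quadratic
# approximation are the ellipses x^T A x = const.
A = np.array([[50.0, 30.0],
              [30.0, 40.0]])
eigvals, eigvecs = np.linalg.eigh(A)

# Step lambda_i^(-1/2) along each eigenvector (principal axis): both land on x^T A x = 1.
axis_points = [v * lam ** -0.5 for lam, v in zip(eigvals, eigvecs.T)]
for x in axis_points:
    print("x^T A x =", x @ A @ x)

# Monte Carlo area of the ellipse {x : x^T A x <= 1}, sampled from a box around it.
rng = np.random.default_rng(0)
half_side = 0.25                       # box [-0.25, 0.25]^2 contains the ellipse for this A
pts = rng.uniform(-half_side, half_side, size=(1_000_000, 2))
inside = np.einsum("ni,ij,nj->n", pts, A, pts) <= 1.0
mc_area = inside.mean() * (2 * half_side) ** 2

print("Monte Carlo area:", mc_area)
print("pi * det(A)^(-1/2):", np.pi * np.linalg.det(A) ** -0.5)
```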

Altogether, then, the Laplace approximation takes the height of the peak (i.e. $P[\text{data}|\theta^*]\,p[\theta^*]$) and multiplies it by the volume of $\theta$-space which the peak occupies, based on a second-order approximation of the likelihood around its peak.
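As a check on the "height times volume" reading, this sketch integrates a hypothetical 2-D peak $\exp(\ell^* - \frac{1}{2}x^T A x)$ on a grid and compares it with height $\times (2\pi)^{k/2}\det(A)^{-1/2}$. For an exactly quadratic log peak the Laplace formula is exact, so any discrepancy is just grid error.

```python
# Illustrative sketch; the peak height and curvature matrix are hypothetical.
import numpy as np

# Hypothetical peak: ln(height) = log_height, curvature given by the 2x2 matrix A.
A = np.array([[50.0, 30.0],
              [30.0, 40.0]])
log_height = -3.0
k = A.shape[0]

# Brute-force integral of height * exp(-x^T A x / 2) over a grid covering the peak.
xs = np.linspace(-2.0, 2.0, 801)
dx = xs[1] - xs[0]
X, Y = np.meshgrid(xs, xs, indexing="ij")
pts = np.stack([X, Y], axis=-1)                       # shape (801, 801, 2)
quad = np.einsum("...i,ij,...j->...", pts, A, pts)
numeric = np.exp(log_height - 0.5 * quad).sum() * dx * dx

# Height times "volume" (2 pi)^(k/2) det(A)^(-1/2).
height_times_volume = np.exp(log_height) * (2 * np.pi) ** (k / 2) * np.linalg.det(A) ** -0.5

print(numeric, height_times_volume)
```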

Laplace Complexity Penalty

The Laplace approximation contains our first example of an explicit complexity penalty.

The idea of a complexity penalty is that we first find the maximum log likelihood $\ln P[\text{data}|\theta^*]$, maybe add a term for our $\theta$-prior $\ln p[\theta^*]$, and that's the "score" of our model. But more general models, with more free parameters, will always score at least as high (any fit available to a simpler nested model is also available to the general one), leading to overfitting. To counterbalance that, we calculate some numerical penalty which is larger for more complex models (i.e. those with more free parameters) and subtract that penalty from the raw score.
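A sketch of why the raw score can't be trusted, using simulated rolls of a genuinely fair die (fixed seed; not the data from the earlier posts): a five-parameter "any die" model still achieves a higher maximum log likelihood than the zero-parameter fair-die model that actually generated the data.

```python
# Illustrative sketch; the rolls are simulated, not real data.
import numpy as np

rng = np.random.default_rng(0)
rolls = rng.integers(1, 7, size=20_000)          # simulated rolls of a genuinely fair die
counts = np.bincount(rolls, minlength=7)[1:]     # counts for faces 1..6
n = counts.sum()

# Zero free parameters: every face has probability 1/6.
max_log_lik_fair = n * np.log(1 / 6)

# Five free parameters: any die. Its maximum-likelihood point is p_i = n_i / n.
p_hat = counts / n
max_log_lik_free = np.sum(counts * np.log(p_hat))

# Non-negative for any data (it equals n times the KL divergence of p_hat from uniform),
# so the raw score always favors the bigger model.
print(max_log_lik_free - max_log_lik_fair)
```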

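In the case of Laplace approximation, a natural complexity penalty drops out as soon as we take the log of the approximation formula:

$$\ln P[\text{data}|\text{model}] \approx \ln P[\text{data}|\theta^*] + \ln p[\theta^*] + \frac{k}{2}\ln(2\pi) - \frac{1}{2}\ln\det\left(-\frac{\partial^2 \ln P[\text{data}|\theta]}{\partial \theta^2}\bigg|_{\theta^*}\right)$$
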
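The last two terms are the complexity penalty. As we saw above, they give the (log) volume of the likelihood peak in $\theta$-space. Since the peak is usually very narrow, this log volume is usually large and negative, so adding it pulls the score down. The wider the peak, the larger the chunk of $\theta$-space which actually predicts the observed data, and the smaller the penalty.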
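A sketch of the penalty in action for a hypothetical coin model (uniform prior, $k = 1$; the function name and counts are illustrative). The penalty drops by $\frac{1}{2}\ln 10 \approx 1.15$ nats for every tenfold increase in data, and it is large enough that a 50.5% heads rate over 20000 flips does not justify the extra parameter over a fair coin, while a 30% rate clearly does.

```python
# Illustrative sketch; function name and counts are hypothetical.
import math


def coin_scores(heads, tails):
    """Fair-coin log evidence vs. the Laplace terms for a biased coin (uniform prior)."""
    n = heads + tails
    fair = n * math.log(0.5)                           # no free parameters, no penalty
    theta_star = heads / n
    fit = heads * math.log(theta_star) + tails * math.log1p(-theta_star)
    log_prior = 0.0                                    # uniform density on [0, 1]
    fisher = n / (theta_star * (1 - theta_star))
    penalty = 0.5 * math.log(2 * math.pi) - 0.5 * math.log(fisher)  # the last two terms
    return fair, fit, log_prior, penalty


for heads, tails in [(10_100, 9_900), (6_000, 14_000), (600, 1_400)]:
    fair, fit, log_prior, penalty = coin_scores(heads, tails)
    biased = fit + log_prior + penalty
    print(f"{heads}/{tails}: fair {fair:.2f}  biased {biased:.2f}  (penalty {penalty:.2f})")
```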

There are two main problems with this complexity penalty:

  • First, there are the usual issues with approximating a posterior distribution by looking at a single point. Multimodal distributions are a problem, and insufficiently pointy distributions are a problem. These problems apply to basically any complexity penalty method.
  • Second, although the log determinant of the Hessian can be computed via backpropagation and linear algebra, that computation takes $O(k^3)$ time for $k$ parameters. That's a lot better than the exponential time required for high-dimensional integrals, but still too slow to be practical for large-scale models with millions of parameters.
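A sketch of the linear-algebra half of that computation, assuming the Hessian has already been obtained: the log determinant of a symmetric positive-definite $k \times k$ matrix via a Cholesky factorization, which costs $O(k^3)$, checked against NumPy's `np.linalg.slogdet`. The random matrix is a hypothetical stand-in for a real Fisher information matrix. At $k = 10^6$ parameters the matrix alone would have $10^{12}$ entries, about 8 TB in float64, before any $O(k^3)$ work begins.

```python
# Illustrative sketch; the matrix is a random hypothetical stand-in.
import numpy as np


def logdet_cholesky(a):
    """ln det(a) for a symmetric positive-definite matrix, via a = L L^T (O(k^3) work)."""
    chol = np.linalg.cholesky(a)  # raises LinAlgError if a is not positive definite
    return 2.0 * np.sum(np.log(np.diag(chol)))


rng = np.random.default_rng(0)
k = 500
b = rng.standard_normal((k, k))
fisher = b @ b.T + k * np.eye(k)   # hypothetical SPD stand-in for a k x k Fisher information matrix

sign, logdet_ref = np.linalg.slogdet(fisher)
print(logdet_cholesky(fisher), logdet_ref, sign)
```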

Historically, a third issue was the math/coding work involved in calculating a Hessian, but modern backprop tools like TensorFlow or autograd make that pretty easy; I expect in the next few years we'll see a lot more people using a Laplace-based complexity penalty directly. The runtime remains a serious problem for large-scale models, however, and that problem is unlikely to be solved any time soon: a method for computing the Hessian log determinant in time linear in the number of Hessian entries would yield an $O(n^2)$ algorithm for multiplying $n \times n$ matrices.

