I've noticed a lack of clear explanations of the fundamental idea behind the use of variational Bayesian methods, so I thought it would be worth writing something here on this topic.

Intractability problems

A problem that comes up routinely in Bayesian inference is the following: suppose that we have a model of an observed variable $x$ in terms of a latent variable $z$. Our model tells us $p(x \mid z)$ and $p(z)$ for any values of $x$ and $z$. We want to know $p(x)$, which our model does not explicitly give us.

The trivial idea is to use basic probability to express this as

$$p(x) = \int p(x \mid z)\, p(z)\, dz,$$

where we interpret the probabilities as probability densities when appropriate (and the integral as a sum when $z$ is discrete). While this is correct, and in principle we can compute $p(x)$ this way, in practice the latent space in which $z$ takes values can be high dimensional, which makes the integral intractable to compute. Even if this space is finite, its size often grows exponentially with the size of the problem we wish to study.

For instance, suppose we're trying to solve a clustering problem[1] with a fixed number of clusters, say two. For each point, the latent variable is a discrete variable taking values in the set $\{1, 2\}$, so you might think the integral (which in this case will just be a sum) will be easy to compute. However, if we have $n$ points that we wish to cluster into two groups, there are $2^n$ possible latent variable assignments over the entire set of points, so we'd have to compute a sum with $2^n$ terms! This is obviously a problematic situation.
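To make the counting concrete, here is a minimal Python sketch of the brute-force sum. It assumes a hypothetical toy version of the problem: one-dimensional points, clusters whose densities are unit-variance normals with known means, and a uniform prior over assignments. The names `normal_pdf` and `evidence_brute_force` are illustrative, and clusters are indexed from 0 rather than 1.

```python
import itertools
import math

def normal_pdf(x, mean):
    """Density of a one-dimensional normal with unit variance."""
    return math.exp(-0.5 * (x - mean) ** 2) / math.sqrt(2 * math.pi)

def evidence_brute_force(xs, means):
    """p(x) = sum over every assignment z of p(x | z) p(z), with p(z) uniform."""
    k, n = len(means), len(xs)
    prior = (1.0 / k) ** n  # uniform p(z) over all k**n assignments
    total, terms = 0.0, 0
    for z in itertools.product(range(k), repeat=n):
        likelihood = 1.0
        for x, zi in zip(xs, z):
            likelihood *= normal_pdf(x, means[zi])
        total += likelihood * prior
        terms += 1
    return total, terms

xs = [-2.1, -1.7, 0.3, 1.9, 2.4]
total, terms = evidence_brute_force(xs, means=[-2.0, 2.0])
print(terms)  # 32 = 2**5 assignments for 5 points
```

Each additional point doubles the number of terms the loop has to visit.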

After some thinking about what's going wrong, though, it's easy to come up with an idea for making this process more efficient. The problem is that we're integrating over all possible values of $z$, but most values of $z$ are unlikely to have produced $x$ and so contribute virtually nothing to the integral. So most of our computation time is spent evaluating the integrand at points that barely affect the final answer. Intuitively, we want to take this into account in a way that makes the calculation more tractable, even if it means we perform somewhat worse than the ideal pure Bayesian method.

How might we go about doing this? Suppose that we had an oracle that told us what $p(z \mid x)$ is. In that case, Bayes' theorem would let us easily compute

$$p(x) = \frac{p(x \mid z)\, p(z)}{p(z \mid x)}$$

for any value of $z$ such that $p(z \mid x) > 0$. In other words, we need to evaluate the integrand $p(x \mid z)\, p(z)$ only once, and we know exactly by how much to scale up that value to get our final answer.
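A quick numerical check of this identity, as a sketch on a hypothetical three-state toy model (the prior and likelihood values are made up for illustration): the posterior is computed the slow way, by summing over all $z$, to play the role of the oracle, and then every single $z$ recovers the same $p(x)$.

```python
# Hypothetical toy model with three latent states z in {0, 1, 2}
prior = {0: 0.5, 1: 0.3, 2: 0.2}          # p(z)
likelihood = {0: 0.10, 1: 0.40, 2: 0.05}  # p(x | z) for the single observed x

# The slow way: sum the integrand over every z (gives 0.18 up to rounding)
evidence = sum(likelihood[z] * prior[z] for z in prior)

# Pretend this posterior p(z | x) was handed to us by an oracle
posterior = {z: likelihood[z] * prior[z] / evidence for z in prior}

# With the oracle, one evaluation of p(x | z) p(z) at any z with p(z | x) > 0 suffices
estimates = {z: likelihood[z] * prior[z] / posterior[z]
             for z in prior if posterior[z] > 0}
for z, estimate in estimates.items():
    print(f"z={z}: p(x) = {estimate:.4f}")  # prints 0.1800 for every z
```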

This might seem useless because we don't have an oracle that tells us what $p(z \mid x)$ is. The key idea of variational Bayes is that while we would need a perfect oracle to get the exact Bayesian answer, we can get close to this answer with an imperfect one.

The imperfect oracle

Let's see how this works. Suppose that we have an "imperfect oracle" $q(z)$. First, we need to specify in which sense it's imperfect: for this we'll use its Kullback-Leibler divergence from the true distribution $p(z \mid x)$. So we assume that we know

$$D_{KL}\big(q(z) \,\|\, p(z \mid x)\big) = \mathbb{E}_{z \sim q}\left[\log \frac{q(z)}{p(z \mid x)}\right]$$

is small in some sense. If we have access to such a $q$, how can we leverage our knowledge to get a better estimate of $p(x)$?

First, as the KL divergence is essentially an average over the distribution $q$, it seems clear that, unlike in the calculation involving $p(z \mid x)$, we have to average over $q$ at some point in the argument: this is because we only know that $q(z)$ and $p(z \mid x)$ are close in expectation. We also need to somehow introduce logarithms into the argument, because the KL divergence tells us about the average value of a particular logarithm.

We have all the key ingredients now. We know we need to work with logarithms, so let's try to calculate $\log p(x)$ instead of $p(x)$: as we can easily go from one to the other, this is not a serious problem. Furthermore, as $z$ does not appear in $\log p(x)$, we have the trivial equality

$$\log p(x) = \mathbb{E}_{z \sim q}\big[\log p(x)\big].$$

Now we repeat the previous use of Bayes' theorem to rewrite this as

$$\log p(x) = \mathbb{E}_{z \sim q}\left[\log \frac{p(x \mid z)\, p(z)}{p(z \mid x)}\right].$$

It seems like we're stuck here, but we have one final card left to play. We need to introduce the KL divergence into the expression somehow. As we're already averaging over the right distribution $q$ and the argument of the expectation is a logarithm, all we need to do is multiply and divide by $q(z)$ to obtain

$$\log p(x) = \mathbb{E}_{z \sim q}\left[\log \frac{p(x \mid z)\, p(z)\, q(z)}{p(z \mid x)\, q(z)}\right].$$

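Now we can break this up into three terms:

$$\log p(x) = \mathbb{E}_{z \sim q}\big[\log p(x \mid z)\big] - \mathbb{E}_{z \sim q}\left[\log \frac{q(z)}{p(z)}\right] + \mathbb{E}_{z \sim q}\left[\log \frac{q(z)}{p(z \mid x)}\right] = \mathbb{E}_{z \sim q}\big[\log p(x \mid z)\big] - D_{KL}\big(q(z) \,\|\, p(z)\big) + D_{KL}\big(q(z) \,\|\, p(z \mid x)\big),$$

and we've reached the kind of expression we were looking for.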
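The decomposition can be checked numerically. This sketch uses a hypothetical three-state toy model (made-up prior and likelihood values) and an arbitrary, deliberately imperfect $q$; it confirms that $\mathbb{E}_q[\log p(x \mid z)] - D_{KL}(q \,\|\, p(z)) + D_{KL}(q \,\|\, p(z \mid x))$ equals $\log p(x)$ whatever $q$ we choose. The helper `kl` is illustrative.

```python
import math

def kl(q, p):
    """D_KL(q || p) for distributions given as dicts over the same finite support."""
    return sum(q[z] * math.log(q[z] / p[z]) for z in q if q[z] > 0)

# Hypothetical toy model
prior = {0: 0.5, 1: 0.3, 2: 0.2}          # p(z)
likelihood = {0: 0.10, 1: 0.40, 2: 0.05}  # p(x | z) for the observed x
evidence = sum(likelihood[z] * prior[z] for z in prior)               # p(x)
posterior = {z: likelihood[z] * prior[z] / evidence for z in prior}   # p(z | x)

q = {0: 0.2, 1: 0.7, 2: 0.1}  # an arbitrary "imperfect oracle"

expected_log_lik = sum(q[z] * math.log(likelihood[z]) for z in q if q[z] > 0)
first_two_terms = expected_log_lik - kl(q, prior)
rhs = first_two_terms + kl(q, posterior)

print(math.log(evidence), rhs)  # the same number, up to floating-point rounding
print(first_two_terms <= math.log(evidence))  # True: the third term is a KL divergence
```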

The evidence lower bound

It's not obvious what's going on with this expression initially, so let's try to figure it out. The third term is exactly what we initially assumed to be small, so let's ignore it for the moment and focus only on the expression

$$\mathbb{E}_{z \sim q}\big[\log p(x \mid z)\big] - D_{KL}\big(q(z) \,\|\, p(z)\big).$$

The question is whether this expression is easy to compute, and fortunately the answer is that it's quite easy! We can evaluate the first term by the simple Monte Carlo method of drawing many independent samples $z^{(1)}, \ldots, z^{(m)} \sim q$ and computing the empirical average $\frac{1}{m}\sum_{j=1}^{m} \log p(x \mid z^{(j)})$, as we know the distribution $q$ explicitly and it was presumably chosen to be easy to sample from. We can evaluate the second term in a similar way, or in some situations we'll even have explicit analytic expressions for $q(z)$ and $p(z)$ that give a closed form for the second term.
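Here is a minimal sketch of both routes, assuming a hypothetical one-dimensional Gaussian toy model that is not from the post: prior $z \sim \mathcal{N}(0, 1)$, likelihood $x \mid z \sim \mathcal{N}(z, 1)$, and a normal $q(z) = \mathcal{N}(m, s^2)$. The first term is estimated by Monte Carlo using Python's `random.Random.gauss`; the KL term uses the standard closed form for two normals. For comparison, `elbo_exact` evaluates the first term analytically as well.

```python
import math
import random

# Hypothetical toy model: z ~ N(0, 1), x | z ~ N(z, 1); variational q(z) = N(m, s^2).

def log_normal_pdf(v, mean, var):
    return -0.5 * math.log(2 * math.pi * var) - 0.5 * (v - mean) ** 2 / var

def kl_normal_std(m, s):
    """Closed-form D_KL(N(m, s^2) || N(0, 1))."""
    return 0.5 * (s ** 2 + m ** 2 - 1.0) - math.log(s)

def elbo_monte_carlo(x, m, s, num_samples=100_000, seed=0):
    rng = random.Random(seed)
    # First term: empirical average of log p(x | z) over samples z ~ q
    total = 0.0
    for _ in range(num_samples):
        z = rng.gauss(m, s)
        total += log_normal_pdf(x, z, 1.0)
    # Second term: analytic KL between q and the prior
    return total / num_samples - kl_normal_std(m, s)

def elbo_exact(x, m, s):
    # E_q[log N(x; z, 1)] = -1/2 log(2 pi) - 1/2 ((x - m)^2 + s^2)
    expected_log_lik = -0.5 * math.log(2 * math.pi) - 0.5 * ((x - m) ** 2 + s ** 2)
    return expected_log_lik - kl_normal_std(m, s)

x, m, s = 1.5, 0.5, 0.8
print(elbo_monte_carlo(x, m, s), elbo_exact(x, m, s))  # close, but not identical
```

The Monte Carlo error shrinks like one over the square root of the number of samples, so the estimate can be made as accurate as needed at a predictable cost.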

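There's something even nicer about the expression we obtained. Let's bring back the term we omitted because we had assumed it to be small, so we get the exact equality

$$\log p(x) = \mathbb{E}_{z \sim q}\big[\log p(x \mid z)\big] - D_{KL}\big(q(z) \,\|\, p(z)\big) + D_{KL}\big(q(z) \,\|\, p(z \mid x)\big).$$
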
Notice that this holds for any[2] distribution $q$, without assumptions on how close $q$ is to $p(z \mid x)$, so it's an unconditional equality. The key facts to note are that the left hand side is independent of $q$ and the third term is always nonnegative, as it's a KL divergence. So even if we initially don't have a good approximation to $p(z \mid x)$, we can pick some model for $q$ with parameters taking values in e.g. a real vector space, and because of this identity, minimizing $D_{KL}(q(z) \,\|\, p(z \mid x))$ with respect to the parameters will be equivalent to maximizing

$$\mathbb{E}_{z \sim q}\big[\log p(x \mid z)\big] - D_{KL}\big(q(z) \,\|\, p(z)\big).$$

So not only do we have an effective way to estimate $\log p(x)$ given a good approximation $q$ to $p(z \mid x)$, but our method also gives us a way to produce such approximations for free!
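To see the "for free" part in action, here is a sketch that fits $q(z) = \mathcal{N}(m, s^2)$ by plain gradient ascent on the evidence lower bound. It assumes a hypothetical conjugate toy model ($z \sim \mathcal{N}(0, 1)$, $x \mid z \sim \mathcal{N}(z, 1)$), chosen because the exact posterior $\mathcal{N}(x/2, 1/2)$ and the exact evidence $\mathcal{N}(x; 0, 2)$ are known, so we can check the result. The gradients are derived by hand from the closed-form bound; the step size and iteration count are arbitrary choices.

```python
import math

# Hypothetical toy model: z ~ N(0, 1), x | z ~ N(z, 1); variational family q(z) = N(m, s^2).

def elbo(x, m, s):
    """Closed-form E_q[log p(x | z)] - D_KL(q(z) || p(z))."""
    expected_log_lik = -0.5 * math.log(2 * math.pi) - 0.5 * ((x - m) ** 2 + s ** 2)
    kl_to_prior = 0.5 * (s ** 2 + m ** 2 - 1.0) - math.log(s)
    return expected_log_lik - kl_to_prior

def log_evidence(x):
    """Exact log p(x): marginally x ~ N(0, 2)."""
    return -0.5 * math.log(4 * math.pi) - x ** 2 / 4

def fit_q(x, m=0.0, s=1.0, lr=0.1, steps=500):
    """Gradient ascent on the bound with respect to the parameters (m, s) of q."""
    for _ in range(steps):
        grad_m = x - 2.0 * m        # derivative of the bound with respect to m
        grad_s = 1.0 / s - 2.0 * s  # derivative of the bound with respect to s
        m += lr * grad_m
        s += lr * grad_s
    return m, s

x = 1.5
m, s = fit_q(x)
print(m, s)  # approaches the exact posterior: mean x/2, standard deviation sqrt(1/2)
print(elbo(x, m, s), log_evidence(x))  # the bound becomes tight once q matches p(z | x)
```

Because $\log p(x)$ does not depend on $(m, s)$, every step that raises the bound lowers $D_{KL}(q(z) \,\|\, p(z \mid x))$ by the same amount.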

As we have

$$\log p(x) \geq \mathbb{E}_{z \sim q}\big[\log p(x \mid z)\big] - D_{KL}\big(q(z) \,\|\, p(z)\big),$$

the right hand side is called the "evidence lower bound": it's a lower bound on the log of the evidence $p(x)$. So, for example, in an application where $p(x \mid z)$ itself is a parametrized model, we can get approximate maximum likelihood estimates for its parameters in a quasi-Bayesian way by simply maximizing the right hand side of this inequality instead of $\log p(x)$. This is essentially how variational autoencoders work, after taking care to set things up so that gradients of the evidence lower bound can be backpropagated.

Application: naive k-means clustering

To finish the post off, I want to illustrate how the methods developed so far can be used to tackle a concrete problem. Let's take the clustering setup from earlier on: we have some points $x_1, \ldots, x_n \in \mathbb{R}^d$ for some $d$. We want to find the best way to split these up into $k$ clusters, where a cluster roughly means "a collection of points that are clumped together". How can we do this?

Problem setup

Well, let's first specify what we expect the data generating process to look like. We can reasonably assume that each cluster is drawn from a simple unimodal distribution such as a normal distribution. In general we can allow these normal distributions to have arbitrary covariance matrices, but for simplicity here I'll assume all their covariance matrices are equal to the identity matrix, so all we have to figure out are their means. So we have $k$ free parameters $\mu_1, \ldots, \mu_k \in \mathbb{R}^d$ that we wish to identify; write $\mu = (\mu_1, \ldots, \mu_k)$.

If we knew which points belonged to which clusters, we could do this easily, for instance by computing the empirical means (and, in the general case, covariances) of the points in each cluster. However, as we don't have this knowledge, there is a latent variable in the problem: the variable encoding which points belong to which clusters. Let's denote this latent variable by $z = (z_1, \ldots, z_n)$: it is an $n$-tuple of elements from $\{1, \ldots, k\}$, with $z_i$ representing the cluster that $x_i$ belongs to.

So we can assume that $p(x \mid z, \mu)$ is given by the product of the normal densities of each point, computed using $\mu$ and the cluster that $z$ says that point belongs to, i.e. $p(x \mid z, \mu) = \prod_{i=1}^{n} \mathcal{N}(x_i; \mu_{z_i}, I_d)$; and that $p(z)$ is independent of $\mu$ and is simply the uniform distribution over all $k^n$ possible values of $z$. We're searching for the maximum likelihood estimate of the parameters $\mu$.[3] To do this, we need to actually be able to compute $p(x \mid \mu)$ so that we can maximize it, but doing so runs into the intractability problems discussed in the first section.
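For concreteness, here is a sketch of the joint log-density $\log p(x, z \mid \mu) = \log p(x \mid z, \mu) + \log p(z)$ of this model in plain Python, with points and means as tuples of floats and clusters indexed from 0. The function name `log_joint` is hypothetical. Evaluating it for one $z$ is cheap; the expensive part is that $p(x \mid \mu)$ requires summing its exponential over all $k^n$ assignments.

```python
import math

def log_joint(points, z, means):
    """log p(x, z | mu) for the identity-covariance normal mixture with uniform p(z).

    points: list of d-dimensional tuples x_1..x_n
    z:      tuple of cluster indices in range(k), one per point (0-based here)
    means:  list of k d-dimensional tuples mu_1..mu_k
    """
    n, k, d = len(points), len(means), len(points[0])
    log_prior = -n * math.log(k)  # uniform over all k**n assignments
    log_lik = 0.0
    for x, zi in zip(points, z):
        sq_dist = sum((a - b) ** 2 for a, b in zip(x, means[zi]))
        # log N(x; mu_zi, I_d) = -d/2 log(2 pi) - |x - mu_zi|^2 / 2
        log_lik += -0.5 * d * math.log(2 * math.pi) - 0.5 * sq_dist
    return log_lik + log_prior

points = [(0.0, 0.0), (3.0, 0.0)]
means = [(0.0, 0.0), (3.0, 0.0)]
print(log_joint(points, (0, 1), means) > log_joint(points, (0, 0), means))  # True
```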

Applying variational methods

Well, we now have the machinery to deal with this! A reasonable choice of $q$ is the following: take it to also be parametrized by points $\nu = (\nu_1, \ldots, \nu_k)$, and for simplicity take $q(z \mid \nu)$ to be a crude approximation which, for each point $x_i$, finds the closest point to $x_i$ among $\nu_1, \ldots, \nu_k$ and assigns $x_i$ to that cluster with probability $1$; every other assignment gets probability $0$.

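Now let's think about what happens if we try to maximize the evidence lower bound here. Recall that the bound was

$$\mathbb{E}_{z \sim q(\cdot \mid \nu)}\big[\log p(x \mid z, \mu)\big] - D_{KL}\big(q(z \mid \nu) \,\|\, p(z)\big).$$
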
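If we hold $\mu$ fixed, then it's easy to maximize the bound with respect to $\nu$, by virtue of the fact that we've picked $q$ to be deterministic, i.e. to give probability $1$ to one particular assignment of points to clusters and $0$ to everything else. This is because the evidence lower bound has the alternative expression

$$\mathbb{E}_{z \sim q}\big[\log p(x \mid z, \mu)\big] - D_{KL}\big(q(z \mid \nu) \,\|\, p(z)\big) = \mathbb{E}_{z \sim q}\big[\log p(x, z \mid \mu)\big] + H\big(q(\cdot \mid \nu)\big),$$

where $H$ denotes entropy, and because $q$ is deterministic its entropy vanishes, so we can drop the second term here. This leaves us with

$$\mathbb{E}_{z \sim q}\big[\log p(x, z \mid \mu)\big],$$

which is just $\log p(x, z \mid \mu)$ evaluated at the single assignment $z$ to which $q(\cdot \mid \nu)$ gives probability $1$.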
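What does maximizing this with respect to $\nu$ look like? Well, we need to pick $\nu$ such that the deterministic clustering $z$ implied by $q(\cdot \mid \nu)$ maximizes $\log p(x, z \mid \mu)$. Since $p(z)$ is uniform, this amounts to maximizing $\log p(x \mid z, \mu) = -\frac{1}{2}\sum_{i=1}^{n} \lVert x_i - \mu_{z_i} \rVert^2 + \text{const}$, which is done by assigning each point to the cluster of the point in $\mu$ that is closest to it (for instance by taking $\nu = \mu$). This relies crucially on the assumption that the covariance matrices are all equal to the identity (or, more generally, to the same multiple of it); otherwise the optimal assignment is no longer a nearest-point rule in the Euclidean metric. So, holding $\mu$ fixed, maximizing the evidence lower bound looks like assigning each point to the cluster identified by the point in $\mu$ closest to it.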
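The claim that nearest-mean assignment maximizes $\log p(x \mid z, \mu)$ for fixed $\mu$ can be checked by brute force on a tiny example. This sketch drops the additive constants (they don't affect which $z$ wins), uses made-up 2D points with 0-based cluster indices, and compares the nearest-mean rule against an exhaustive search over all $k^n$ assignments.

```python
import itertools

def sq_dist(a, b):
    return sum((ai - bi) ** 2 for ai, bi in zip(a, b))

def log_lik(points, z, means):
    # log p(x | z, mu) up to an additive constant (identity covariances)
    return -0.5 * sum(sq_dist(x, means[zi]) for x, zi in zip(points, z))

def assign(points, means):
    """Send each point to its nearest mean."""
    return tuple(min(range(len(means)), key=lambda j: sq_dist(x, means[j])) for x in points)

points = [(0.0, 0.1), (0.3, -0.2), (4.0, 4.2), (3.8, 3.9), (2.5, 1.9)]
means = [(0.0, 0.0), (4.0, 4.0)]

# Exhaustive search over every assignment, feasible only because n and k are tiny
best = max(itertools.product(range(len(means)), repeat=len(points)),
           key=lambda z: log_lik(points, z, means))
print(assign(points, means), best)  # (0, 0, 1, 1, 1) (0, 0, 1, 1, 1)
```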

Now we have to do the opposite case: maximizing the bound with respect to $\mu$ while holding $\nu$ fixed. As $D_{KL}(q(z \mid \nu) \,\|\, p(z))$ does not depend on $\mu$, by the assumption that $p(z)$ is independent of $\mu$ and uniform, we just need to find the $\mu$ that maximizes the likelihood of observing the points conditional on the clusters each point belongs to, and we already know how to do this! We just set each $\mu_j$ equal to the mean of the points in cluster $j$.
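A matching sketch of the update step, on made-up 2D data with the assignment held fixed (clusters indexed from 0): each mean becomes the centroid of its cluster, and nudging a centroid away only lowers $\log p(x \mid z, \mu)$, again up to constants. What to do with a cluster that has no points is not discussed above; this sketch assumes, as one common convention, that such a cluster keeps its previous mean.

```python
def cluster_objective(points, z, means):
    """log p(x | z, mu) up to an additive constant: -1/2 * sum of squared distances."""
    return -0.5 * sum(
        sum((a - b) ** 2 for a, b in zip(x, means[zi])) for x, zi in zip(points, z)
    )

def update_means(points, z, old_means):
    """Set each mean to the centroid of the points currently assigned to it."""
    new_means = []
    for j, old in enumerate(old_means):
        members = [x for x, zi in zip(points, z) if zi == j]
        if not members:  # assumption: an empty cluster keeps its old mean
            new_means.append(old)
            continue
        new_means.append(tuple(sum(coords) / len(members) for coords in zip(*members)))
    return new_means

points = [(0.0, 0.1), (0.3, -0.2), (4.0, 4.2), (3.8, 3.9), (2.5, 1.9)]
z = (0, 0, 1, 1, 1)
means = update_means(points, z, [(0.0, 0.0), (4.0, 4.0)])
print(means)  # the centroids of the two clusters

# Perturbing a centroid lowers the objective, since the centroid is its unique maximizer
nudged = [(means[0][0] + 0.1, means[0][1]), means[1]]
print(cluster_objective(points, z, means) > cluster_objective(points, z, nudged))  # True
```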

The final algorithm

The algorithm we end up with is the following:

  1. We're given a collection of points $x_1, \ldots, x_n$ with each $x_i \in \mathbb{R}^d$.
  2. Initialize the cluster center vector $\mu = (\mu_1, \ldots, \mu_k)$ randomly.
  3. For each point $x_i$, find the $j$ that minimizes the Euclidean distance between $x_i$ and $\mu_j$. Set $z_i = j$.
  4. For each $j$, set $\mu_j$ equal to the mean of all points currently assigned to cluster $j$, i.e. the mean of all $x_i$ such that $z_i = j$.
  5. If neither the desired degree of accuracy nor a previously specified number of iterations has been reached, go back to step 3; otherwise stop.
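The five steps translate almost line for line into Python; the same procedure is also widely known as Lloyd's algorithm. This is a minimal, dependency-free sketch rather than a reference implementation, and three details are assumptions: the random initialization picks $k$ distinct data points as starting centers (the post only says "randomly"), the stopping rule is "assignments stopped changing" as one concrete reading of "desired degree of accuracy", and an empty cluster keeps its previous center, a case the steps above don't address.

```python
import random

def sq_dist(a, b):
    """Squared Euclidean distance between two points given as tuples."""
    return sum((ai - bi) ** 2 for ai, bi in zip(a, b))

def naive_kmeans(points, k, max_iters=100, seed=0):
    """Alternate nearest-center assignment (step 3) and centroid updates (step 4)."""
    rng = random.Random(seed)
    # Step 2 (assumption): start from k distinct data points picked at random
    centers = [tuple(p) for p in rng.sample(points, k)]
    labels = None
    for _ in range(max_iters):
        # Step 3: assign every point to its nearest center
        new_labels = [min(range(k), key=lambda j: sq_dist(x, centers[j])) for x in points]
        # Step 5 (assumption): stop once the assignments no longer change
        if new_labels == labels:
            break
        labels = new_labels
        # Step 4: move each center to the mean of its assigned points
        for j in range(k):
            members = [x for x, zj in zip(points, labels) if zj == j]
            if members:  # assumption: an empty cluster keeps its previous center
                centers[j] = tuple(sum(coords) / len(members) for coords in zip(*members))
    return centers, labels

points = [(0.0, 0.0), (0.2, 0.1), (-0.1, 0.2), (5.0, 5.0), (5.2, 4.9), (4.8, 5.1)]
centers, labels = naive_kmeans(points, k=2)
print(centers, labels)  # the first three points share one label, the last three the other
```

Like any local search, the result can depend on the initialization; rerunning with a few different seeds and keeping the best assignment is a common remedy.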

This algorithm is sometimes known as the naive k-means clustering algorithm and has very good performance in practice: you can see it in action in this video. We can also see the reason why it would not be optimal: the choice we made for was extremely crude and low capacity. However, the evidence lower bound is sufficiently powerful that it can make even this low-quality approximation yield surprisingly good results.

  1. This means that we have a set of points $x_1, \ldots, x_n \in \mathbb{R}^d$ for some $d$ and we wish to split them "naturally" into two clusters.

  2. With some reasonable assumption about the support of $q$, so we don't have problems with dividing by, or taking the logarithm of, zero.

  3. Note that as we're doing maximum likelihood estimation, which is a frequentist method, what I'm outlining is quasi-Bayesian at best.



Nice writeup. I wasn't even aware k-means clustering can be viewed from the Variational Bayes framework. In case more perspectives are useful to any readers: When I first tried to learn about this, I found the Pyro Introduction very helpful; because it is split up over a lot of files, I put together these slides for Bayesian Neural Networks, which also start out with a motivation for Variational Bayes.