I've noticed a lack of clear explanations of the fundamental idea behind the use of variational Bayesian methods, so I thought it would be worth writing something here on this topic.
Intractability problems
A problem that comes up routinely in Bayesian inference is the following: suppose that we have a model of an observed variable x in terms of a latent variable z. Our model tells us P(z) and P(x∣z) for any values of x,z. We want to know P(x), which our model does not explicitly give us.
The trivial idea is to use basic probability to express this as
$$P(x) = \int_z P(x, z)\, dz = \int_z P(x \mid z) \cdot P(z)\, dz$$
where we interpret the probabilities as probability densities when appropriate. While this is correct and in principle we can compute P(x) this way, in practice the latent space in which z takes values can be high dimensional, which makes the integral intractable to compute. Even if the latent space is finite, its size often grows exponentially with the size of the problem we wish to study.
For instance, suppose we're trying to solve a clustering problem[1] with a fixed number k = 2 of clusters. For a single point, the latent cluster assignment is a discrete variable taking values in the set {1, 2}, so you might think the integral (which in this case is just a sum) will be easy to compute. However, if we have $N$ points that we wish to cluster into two groups, there are $2^N$ possible latent variable assignments over the entire set of points, so we'd have to compute a sum with $2^N$ terms! This is obviously a problematic situation.
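To make the blow-up concrete, here is a brute-force computation of P(x) for a toy one-dimensional version of this problem; the particular data, means, and unit-variance Gaussian components below are my own illustrative choices, not anything from the post.

```python
import itertools
import numpy as np
from scipy.stats import norm

# Toy 1-D setup: N points, 2 clusters modeled as unit-variance Gaussians
# at known means, and a uniform prior over cluster assignments.
rng = np.random.default_rng(0)
x = rng.normal(loc=[0.0] * 5 + [4.0] * 5, scale=1.0)   # N = 10 points
means = np.array([0.0, 4.0])
N = len(x)

# Brute-force P(x) = sum over all 2^N assignments z of P(x | z) * P(z).
p_x = 0.0
for z in itertools.product([0, 1], repeat=N):           # 2^N terms!
    p_x_given_z = np.prod(norm.pdf(x, loc=means[list(z)], scale=1.0))
    p_x += p_x_given_z * 0.5 ** N                        # uniform P(z)
print(p_x)  # fine for N = 10, hopeless once N is in the hundreds
```

Even this tiny example already loops over 1024 terms; doubling the number of points squares the amount of work.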
After some thought about what's going wrong, though, it's easy to see how to make this process more efficient. The problem is that we're integrating over all possible values of z, even though most values of z are unlikely to have produced x and so contribute virtually nothing to the integral. Most of the computation is therefore wasted evaluating the integrand at points that barely affect the final answer. Intuitively, we want to exploit this in a way that makes the calculation more tractable, even if it means we perform somewhat worse than the ideal pure Bayesian method.
How might we go about doing this? Suppose that we had an oracle that told us what P(z∣x) is. In that case, we could easily compute
$$P(x) = \frac{P(x \mid z)\, P(z)}{P(z \mid x)}$$
for any value of z such that P(z∣x)≠0. In other words, we need to evaluate the integrand only once, and we know exactly by how much to scale up that probability if we want to get our final answer.
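For completeness, this identity is nothing more than Bayes' theorem rearranged to isolate P(x):

$$P(z \mid x) = \frac{P(x \mid z)\, P(z)}{P(x)} \quad\Longrightarrow\quad P(x) = \frac{P(x \mid z)\, P(z)}{P(z \mid x)}.$$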
This might seem useless because we don't have an oracle that tells us what P(z∣x) is. The key idea of variational Bayes is that while we would need an oracle to get the perfect Bayesian answer, we can get close to this answer with an imperfect oracle.
The imperfect oracle
Let's see how this works. Suppose that we have an "imperfect oracle" Q(z∣x). First, we need to specify in which sense it's imperfect: for this we'll use its Kullback-Leibler divergence with the true distribution P(z∣x). So we assume that we know
$$D_{\mathrm{KL}}(Q(z \mid x) \,\|\, P(z \mid x)) = \int_z Q(z \mid x) \log\left(\frac{Q(z \mid x)}{P(z \mid x)}\right) dz$$
is small in some sense. If we have access to such a Q, how can we leverage our knowledge to get a better estimate of P(x)?
First, as the KL divergence is essentially an average over the distribution Q, it seems clear that, unlike in the calculation that used the exact P(z∣x), we will have to average over z∼Q(z∣x) at some point in the argument: after all, we only know that Q and P are close in expectation. We also need to introduce logarithms somewhere, because the KL divergence tells us about the average value of a particular logarithm.
We have all the key ingredients now. We know we need to work with logarithms, so let's try to calculate logP(x) instead of P(x): as we can easily go from one to the other this is not a serious problem. Furthermore, as z does not appear in logP(x), we have the trivial equality
$$\log P(x) = \mathbb{E}_{z \sim Q(z \mid x)}[\log P(x)]$$
Now we repeat the previous use of Bayes' theorem to rewrite this as
$$= \mathbb{E}_{z \sim Q(z \mid x)}\left[\log\left(\frac{P(x \mid z)\, P(z)}{P(z \mid x)}\right)\right]$$
It seems like we're stuck here, but we have one final card left to play. We need to introduce the KL divergence into the expression somehow. As we're already averaging z over the right target and the argument of the expectation is a logarithm, what we need to do is to multiply and divide by Q(z∣x) to obtain
$$= \mathbb{E}_{z \sim Q(z \mid x)}\left[\log\left(\frac{P(x \mid z)\, P(z)}{P(z \mid x)} \cdot \frac{Q(z \mid x)}{Q(z \mid x)}\right)\right]$$
Now we can break this up into three terms:
$$= \mathbb{E}_{z \sim Q(z \mid x)}[\log P(x \mid z)] - D_{\mathrm{KL}}(Q(z \mid x) \,\|\, P(z)) + D_{\mathrm{KL}}(Q(z \mid x) \,\|\, P(z \mid x))$$
and we've reached the kind of expression we were looking for.
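For readers who want the algebra spelled out: regrouping the factors inside the logarithm gives

$$\log\left(\frac{P(x \mid z)\, P(z)}{P(z \mid x)} \cdot \frac{Q(z \mid x)}{Q(z \mid x)}\right) = \log P(x \mid z) - \log\frac{Q(z \mid x)}{P(z)} + \log\frac{Q(z \mid x)}{P(z \mid x)}$$

and taking the expectation of each term over z∼Q(z∣x) turns the second and third terms into exactly the two KL divergences above.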
The evidence lower bound
It's not obvious what's going on with this expression initially, so let's try to figure it out. The third term is exactly what we initially assumed to be small, so let's ignore it for the moment and focus only on the expression
$$\approx \mathbb{E}_{z \sim Q(z \mid x)}[\log P(x \mid z)] - D_{\mathrm{KL}}(Q(z \mid x) \,\|\, P(z))$$
The question is whether this expression is easy to compute or not, and fortunately the answer is that it's quite easy! We can evaluate the first term by the simple Monte Carlo method of drawing many independent samples z∼Q(z∣x) and evaluating the empirical average, as we know the distribution Q(z∣x) explicitly and it was presumably chosen to be easy to draw samples from. We can evaluate the second term in a similar way, or in some situations we'll even have access to explicit analytic expressions for Q(z∣x) and P(z) that will permit us to obtain a closed form for the second term.
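As a concrete toy illustration of this estimate (the model and the particular Q below are my own choices, not anything from the post): take P(z) = N(0, 1), P(x ∣ z) = N(z, 1), a single observation x, and a Gaussian Q(z ∣ x). The first term is estimated by sampling, and the second term has a closed form because both distributions are Gaussian.

```python
import numpy as np
from scipy.stats import norm

# Toy model: P(z) = N(0, 1), P(x | z) = N(z, 1), one observed x,
# and an approximate posterior Q(z | x) = N(m, s^2).
x, m, s = 2.0, 1.0, 0.8
rng = np.random.default_rng(0)

# First term: Monte Carlo estimate of E_{z ~ Q}[log P(x | z)].
z = rng.normal(m, s, size=100_000)
term1 = norm.logpdf(x, loc=z, scale=1.0).mean()

# Second term: KL(Q(z|x) || P(z)) between two Gaussians, in closed form.
term2 = np.log(1.0 / s) + (s**2 + m**2 - 1.0) / 2.0

elbo = term1 - term2
log_p_x = norm.logpdf(x, loc=0.0, scale=np.sqrt(2.0))  # exact, since x ~ N(0, 2)
print(f"ELBO ~ {elbo:.4f} <= log P(x) = {log_p_x:.4f}")
```

In this conjugate toy case we can also compute log P(x) exactly, which lets us check that the estimate really does sit below it.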
There's something that's even nicer about the expression we obtained. Let's bring back the term we omitted because we had assumed it to be small, so we get the exact equality
$$\log P(x) = \mathbb{E}_{z \sim Q(z \mid x)}[\log P(x \mid z)] - D_{\mathrm{KL}}(Q(z \mid x) \,\|\, P(z)) + D_{\mathrm{KL}}(Q(z \mid x) \,\|\, P(z \mid x))$$
Notice that this holds for any[2] distribution Q without assumptions on how close Q is to P, so it's an unconditional equality. The key facts to note are that the left hand side is independent of Q and the third term is always nonnegative as it's a KL divergence. So even if we initially don't have a good approximation Q to P, we can pick some model $Q_\theta$ for Q with parameters θ taking values in e.g. a real vector space, and because of this identity it turns out that minimizing $D_{\mathrm{KL}}(Q_\theta(z \mid x) \,\|\, P(z \mid x))$ with respect to the parameters θ will be equivalent to maximizing
$$L(Q_\theta(\cdot \mid x)) = \mathbb{E}_{z \sim Q_\theta(z \mid x)}[\log P(x \mid z)] - D_{\mathrm{KL}}(Q_\theta(z \mid x) \,\|\, P(z))$$
So not only do we have an effective way to compute P(x) given a good approximation Q(z∣x) to P(z∣x), but our method also gives us a way to produce such approximations for free!
As we have
$$\log P(x) = L(Q(\cdot \mid x)) + D_{\mathrm{KL}}(Q(z \mid x) \,\|\, P(z \mid x)) \geq L(Q(\cdot \mid x))$$
the expression L(Q) is called the "evidence lower bound": it's a lower bound on the evidence logP(x). So for example in an application where P itself is a parametrized model, we can get the maximum likelihood estimates for those parameters in a quasi-Bayesian way by making use of the evidence lower bound and simply maximizing the right hand side of this inequality instead. This is essentially how variational autoencoders work, after taking care to set things up such that the evidence lower bound can be backpropagated through.
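To make the "approximation for free" point concrete, here is a minimal sketch of maximizing the evidence lower bound by gradient ascent on the same toy model as above (my own illustration, not a full variational autoencoder): $Q_\theta(z \mid x)$ is a Gaussian with parameters θ = (mean, log standard deviation), and we use reparameterized samples so the bound can be backpropagated through.

```python
import torch

# Toy model: P(z) = N(0, 1), P(x | z) = N(z, 1), one observed scalar x.
x = torch.tensor(2.0)

# Variational parameters theta = (mu, log_sigma) of Q(z | x) = N(mu, sigma^2).
mu = torch.zeros(1, requires_grad=True)
log_sigma = torch.zeros(1, requires_grad=True)
opt = torch.optim.Adam([mu, log_sigma], lr=0.05)
prior = torch.distributions.Normal(0.0, 1.0)

for step in range(2000):
    opt.zero_grad()
    q = torch.distributions.Normal(mu, log_sigma.exp())
    z = q.rsample((64,))                                   # reparameterized samples
    log_lik = torch.distributions.Normal(z, 1.0).log_prob(x).mean()
    kl = torch.distributions.kl_divergence(q, prior).sum()
    loss = -(log_lik - kl)                                 # negative ELBO
    loss.backward()
    opt.step()

# The exact posterior here is N(x/2, 1/2), so Q should end up close to it.
print(mu.item(), log_sigma.exp().item() ** 2)
```

After training, the fitted mean should be close to x/2 and the fitted variance close to 1/2, and the ELBO correspondingly close to log P(x): the third KL term has been driven towards zero purely by maximizing the lower bound.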
Application: naive k-means clustering
To finish the post off, I want to illustrate how the methods developed in the post so far can be used to tackle a concrete problem. Let's take the clustering setup from earlier on in the post: we have some points $x = \{x_1, x_2, \ldots, x_n\} \subset \mathbb{R}^d$ for some d. We want to find the best way to split these up into k clusters, where a cluster roughly means "a collection of points that are clumped together". How can we do this?
Problem setup
Well, let's first specify what we expect the data generating process here to look like. We can reasonably assume that the points in each cluster can be modeled as draws from a simple unimodal distribution such as a normal distribution. In general we can allow these normal distributions to have arbitrary covariance matrices, but for simplicity here I'll assume all their covariance matrices are equal to the identity matrix and all we have to figure out are their means. In summary, we have free parameters $\mu = (\mu_1, \mu_2, \ldots, \mu_k)$ that we wish to identify.
If we knew which points belonged to which clusters, we could do this easily by computing the empirical means and covariances of the points in the cluster, for instance. However, as we don't have this knowledge, there is a latent variable in the problem: the variable encoding which points belong to which clusters. Let's denote this latent variable by z: it is an n-tuple of elements from {1,2,…,k} representing which cluster each point belongs to.
So we can assume that $P_\mu(x \mid z)$ is the product, over the points, of the normal densities of each point under the cluster that z assigns it to (with mean given by the corresponding entry of μ); and $P_\mu(z)$ is independent of μ and is simply a uniform distribution over all possible values of z. We're searching for the maximum likelihood estimate of the parameters μ.[3] To do this, we need to actually be able to compute $P_\mu(x)$ so that we can maximize it, but doing so runs into the intractability problems discussed in the first section.
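In code, the two ingredients of this model might look as follows. This is a sketch under the stated assumptions; the array conventions (x as an n×d array of points, z as a length-n integer array of cluster labels, mu as a k×d array of means) are my own.

```python
import numpy as np

def log_p_x_given_z(x, z, mu):
    """log P_mu(x | z): each point is an independent unit-covariance Gaussian
    centered at the mean of the cluster that z assigns it to."""
    diffs = x - mu[z]                                          # shape (n, d)
    return -0.5 * (diffs ** 2).sum() - 0.5 * x.size * np.log(2 * np.pi)

def log_p_z(z, k):
    """log P(z): uniform over the k^n possible assignments, independent of mu."""
    return -len(z) * np.log(k)
```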
Applying variational methods
Well, we now have machinery to deal with this! A reasonable choice of Q is the following: take it to also be parametrized by k points $\nu = (\nu_1, \nu_2, \ldots, \nu_k)$, and for simplicity take $Q_\nu(z \mid x)$ to be a crude deterministic approximation: for each point $x_i$, find the closest of the k points in ν, put probability 1 on the single assignment z that sends every point to its closest such cluster, and probability 0 on every other assignment.
Now let's think what happens if we try to maximize the evidence lower bound here. Recall that the bound was
$$L(Q_\nu(\cdot \mid x)) = \mathbb{E}_{z \sim Q_\nu(z \mid x)}[\log P_\mu(x \mid z)] - D_{\mathrm{KL}}(Q_\nu(z \mid x) \,\|\, P_\mu(z))$$
If we hold μ fixed, then it's easy to maximize the bound over ν, by virtue of the fact that we've picked Q to be deterministic, i.e. to give probability 1 to one particular assignment of points to clusters and 0 to everything else. This is because the evidence lower bound has the alternative expression
$$= \mathbb{E}_{z \sim Q_\nu(z \mid x)}[\log P_\mu(x, z)] + H(Q_\nu(z \mid x))$$
and by virtue of Q being deterministic, its entropy will just vanish and we'll be able to drop the second term here. This leaves us with
$$L(Q_\nu(\cdot \mid x)) = \mathbb{E}_{z \sim Q_\nu(z \mid x)}[\log P_\mu(x, z)]$$
What does maximizing this with respect to ν look like? Well, we need to pick ν such that the implied deterministic clustering z maximizes $P_\mu(x, z)$, which here is the same as maximizing $P_\mu(x \mid z)$ since $P_\mu(z)$ is uniform. A little thought leads to the conclusion that in our specific case this problem is solved by choosing ν such that each point is assigned to the cluster of the point in μ that is closest to it (this relies crucially on the assumption that the covariance matrices are all identical and proportional to the identity). So holding μ fixed, maximizing over ν amounts to assigning each point to the cluster identified by the point in μ closest to it.
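As a minimal numpy sketch of this assignment step (same array conventions as the sketch above; the function name is mine):

```python
import numpy as np

def assign_clusters(x, mu):
    """Hard assignment implied by the deterministic Q: for each point,
    the index of the closest mean under Euclidean distance."""
    d2 = ((x[:, None, :] - mu[None, :, :]) ** 2).sum(axis=-1)  # (n, k) squared distances
    return d2.argmin(axis=1)                                   # (n,) cluster labels
```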
Now we have to do the opposite: maximize over μ holding ν fixed. As $P_\mu(x, z) \propto P_\mu(x \mid z)$ by the assumption that $P_\mu(z)$ is independent of μ and uniform, we just need to find the μ that maximizes the likelihood of observing the points x conditional on the clusters each point belongs to, and we already know how to do this! We just set $\mu_i$ equal to the mean of the points in cluster i.
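And the corresponding sketch of the mean update, holding the assignments fixed (the empty-cluster fallback is my own choice):

```python
import numpy as np

def update_means(x, z, mu):
    """Maximum-likelihood update: each mean becomes the average of the points
    currently assigned to its cluster."""
    new_mu = mu.copy()
    for i in range(len(mu)):
        members = x[z == i]
        if len(members) > 0:          # keep the old mean if a cluster went empty
            new_mu[i] = members.mean(axis=0)
    return new_mu
```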
The final algorithm
The algorithm we end up with is the following:

1. Initialize the cluster means μ, for instance at k randomly chosen data points.
2. Assignment step (maximizing over ν): assign each point to the cluster whose current mean is closest to it.
3. Update step (maximizing over μ): set each mean $\mu_i$ to the average of the points currently assigned to cluster i.
4. Repeat steps 2 and 3 until the assignments stop changing.
This algorithm is sometimes known as the naive k-means clustering algorithm and has very good performance in practice: you can see it in action in this video. We can also see the reason why it would not be optimal: the choice we made for Q was extremely crude and low capacity. However, the evidence lower bound is sufficiently powerful that it can make even this low-quality approximation yield surprisingly good results.
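Putting the two steps together, here is a self-contained sketch of the whole loop; the random initialization, fixed iteration count, and synthetic test data are my own choices.

```python
import numpy as np

def naive_kmeans(x, k, n_iters=100, seed=0):
    """Alternate the two maximizations derived above: hard-assign points to the
    nearest mean (optimizing nu), then move each mean to the average of its
    assigned points (optimizing mu)."""
    rng = np.random.default_rng(seed)
    mu = x[rng.choice(len(x), size=k, replace=False)].copy()  # init at random points
    for _ in range(n_iters):
        # Assignment step: nearest mean for each point.
        d2 = ((x[:, None, :] - mu[None, :, :]) ** 2).sum(axis=-1)
        z = d2.argmin(axis=1)
        # Update step: each mean becomes the average of its points.
        for i in range(k):
            if np.any(z == i):
                mu[i] = x[z == i].mean(axis=0)
    return mu, z

# Example on synthetic data: two well-separated 2-D blobs.
rng = np.random.default_rng(1)
pts = np.vstack([rng.normal(0.0, 1.0, size=(50, 2)),
                 rng.normal(5.0, 1.0, size=(50, 2))])
means, labels = naive_kmeans(pts, k=2)
print(means)   # should land near (0, 0) and (5, 5)
```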
This means that we have a set of points $p_1, p_2, \ldots, p_N \in \mathbb{R}^d$ for some d and we wish to split them "naturally" into two clusters. ↩︎
With some reasonable assumption about the support of Q so we don't have problems of dividing by, or taking the logarithm of, zero. ↩︎
Note that as we're doing maximum likelihood estimation, which is a frequentist method, what I'm outlining is quasi-Bayesian at best. ↩︎