Crossposted from the AI Alignment Forum. May contain more technical jargon than usual.

TLDR: KL-regularised RL, widely used as part of RL from human feedback (RLHF), is equivalent to variational inference: approximating a Bayesian posterior which specifies how to update a prior LM to conform with evidence provided by the reward function. The Bayesian perspective makes it clear that KL penalties aren’t a hack; they have a principled justification. It also nicely separates the modelling problem (defining a target distribution specifying the desired behaviour of an LM) and the inference problem (approximating that target distribution). Finally, it suggests that RL is not a good formal framework for thinking about LM alignment.

Introduction

Large language models (LMs) tend to generate outputs that reflect undesirable features of their training data such as offensiveness, social bias, harmfulness or dishonesty. Correcting these biases and constraining LMs to be honest, helpful and harmless is an essential part of the problem of aligning LMs with human preferences (henceforth “LM alignment”). One intuitive approach to LM alignment is reinforcement learning (RL): capturing human preferences as a reward function and training the LM to maximise the reward expected under the LM’s distribution. A practical recipe for implementing this idea is RL from human feedback (RLHF): first, a reward model is trained to predict which of two texts a human prefers, and then a pretrained LM is fine-tuned to maximise the reward given by the reward model while being penalised for Kullback-Leibler (KL) divergence from its initial distribution. However, despite the immense popularity of RLHF, the motivation for this KL penalty is not widely understood.

In this blog post, we discuss an underappreciated perspective on KL-regularised RL (the objective employed by RLHF for aligning LMs) which explains its empirical success. We start by describing a problem that arises from naively applying the standard RL objective: distribution collapse. The optimal policy under the RL objective would be a minimal-entropy LM generating a small set of sequences that obtain the highest reward. Then, we discuss how KL-regularised RL avoids distribution collapse thanks to its KL penalty. This constraint, we argue, transforms the problem from RL to Bayesian inference: updating a prior to conform with evidence provided by the reward.

Moreover, KL-regularised RL is equivalent to a well-studied approach to solving this inference problem approximately: variational inference. This Bayesian perspective explains how KL-regularised RL avoids the distribution collapse problem and offers a first-principles derivation for its objective. It also moves KL-regularised RL closer to other divergence-minimisation-based approaches to fine-tuning LMs such as GDC, which are not equivalent to RL and naturally avoid the distribution collapse problem. In contrast, RL avoids distribution collapse only with a particular choice of reward function that makes it equivalent to Bayesian inference. This suggests that RL might not be an adequate formal framework for problems such as LM alignment.

Aligning language models via standard RL

Let $\mathcal{X}$ be the set of sequences of tokens from some vocabulary. An LM $\pi$ can be seen as a probability distribution over $\mathcal{X}$. While most modern LMs are autoregressive, for simplicity we will only talk about full sequences, e.g. $\pi(x)$ denotes the probability of a sequence $x \in \mathcal{X}$. Similarly, a reward function $r$ assigns sequences $x \in \mathcal{X}$ scalar rewards $r(x)$. In the context of LM alignment, $r$ represents the human preferences we want $\pi$ to be aligned with, e.g. a non-offensiveness reward would assign low values to sequences that are offensive.

If $\pi_\theta$ is our parametric LM (with parameters $\theta$), the RL objective for aligning it with our reward function $r$ is just the reward expected under the LM distribution:

$$J_{\text{RL}}(\theta) = \mathbb{E}_{x \sim \pi_\theta}\, r(x).$$

Intuitively, maximising $J_{\text{RL}}(\theta)$ means sampling a number of sequences from the LM and rewarding the LM for good sequences and penalising it for bad ones (e.g. offensive sentences). This approach to LM alignment is appealing in several ways, especially when compared with the standard self-supervised language modelling objective of predicting the next token in a static dataset. Because the samples come from the LM itself (as opposed to a static dataset), the sampling distribution naturally follows what the LM has already learned and the reward is only evaluated on the LM’s current best guesses about the correct behaviour. For instance, assume the reward is non-offensiveness and this reward involves, but is not limited to, avoiding curse words. Then, the LM could quickly learn to avoid curses and then focus on avoiding more elaborate forms of toxicity, wasting no time on samples containing curse words.

The problem with the RL objective is that it treats the LM as a policy, not as a generative model. While a generative model is supposed to capture a diverse distribution of samples, a policy is supposed to choose the optimal action. In the LM context, where we don’t have a notion of state, the RL objective reduces to searching for $x^*$, the sequence with the highest reward. If there is one, the optimal policy $\pi^*$ is a degenerate, deterministic generative model that puts its entire probability mass on that single sequence:

$$\pi^* = \arg\max_\theta J_{\text{RL}}(\theta) = \delta_{x^*},$$

where $\delta_{x^*}$ is a Dirac delta distribution centred on $x^*$. If there are multiple optimal sequences $x^*$, probability mass would be put only on them.
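The collapse can be reproduced on a toy problem. The following is a minimal sketch, not anything from the experiments discussed in this post: it assumes a hypothetical "LM" over five sequences, parameterised by softmax logits, and runs exact gradient ascent on the expected reward. The reward values, learning rate and step count are arbitrary illustrative choices.

```python
import numpy as np

def softmax(logits):
    """Turn logits into a probability distribution (numerically stable)."""
    z = logits - logits.max()
    p = np.exp(z)
    return p / p.sum()

def entropy(p):
    """Shannon entropy in nats; clipping avoids log(0)."""
    p = np.clip(p, 1e-12, 1.0)
    return float(-(p * np.log(p)).sum())

def maximise_expected_reward(rewards, steps=5000, lr=0.5):
    """Exact gradient ascent on J_RL(theta) = sum_x pi_theta(x) r(x)."""
    logits = np.zeros_like(rewards, dtype=float)  # start from a uniform policy
    for _ in range(steps):
        pi = softmax(logits)
        expected_reward = pi @ rewards
        # For a softmax policy, dJ/dlogit_x = pi(x) * (r(x) - E[r]).
        logits += lr * pi * (rewards - expected_reward)
    return softmax(logits)

# Hypothetical rewards for five sequences; index 3 is the unique optimum x*.
rewards = np.array([0.1, 0.5, 0.9, 1.0, 0.2])
pi_final = maximise_expected_reward(rewards)

print("final policy:", pi_final.round(4))
print("entropy: uniform =", round(entropy(np.full(5, 0.2)), 3),
      "final =", round(entropy(pi_final), 3))
```

Even though the second-best sequence is almost as good as the best one, nothing in the objective rewards keeping any mass on it, so the policy keeps drifting towards the Dirac delta on the best sequence.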

This failure mode is not purely theoretical. Empirically, distribution collapse induced by maximising reward manifests as decreased fluency and diversity of samples from the LM, which can be measured in terms of perplexity, entropy and the frequency of repetitions. Degeneration of this kind was observed in multiple language generation tasks ranging from translation, summarisation, story generation, video captioning, dialogue, to code generation and LM debiasing.

Figure 1: Samples from an LM fine-tuned using $J_{\text{RL}}(\theta)$ with reward $r(x) = 1$ if $x$ contains the word “Paris”, $r(x) = 0$ otherwise. Even though there are infinitely many sentences containing “Paris” and the LM is not rewarded for multiple mentions of “Paris”, it still converges to a very low-entropy policy mentioning Paris as often as possible, just in case. Figure adapted from Khalifa et al., 2021.

While the degeneration problem is exacerbated by RL failure modes such as insufficient exploration or reward hacking, it is distinct from the exploration-exploitation trade-off or reward misspecification. Even with perfect exploration (if we sampled sequences uniformly from $\mathcal{X}$ as opposed to sampling from $\pi_\theta$), the optimal policy $\pi^*$ will still put all probability mass on $x^*$. Similarly, even if $r$ is a smooth, real-valued function that perfectly captures human preferences across the whole space of possible sequences $\mathcal{X}$, and if $x^*$ is truly the best thing, we still wouldn’t want the LM to generate only $x^*$. Essentially, the distribution collapse problem arises from the fact that the RL objective for LM alignment is flawed: it doesn’t care about preserving distributional properties of an LM and will always penalise the LM for putting any probability mass on non-optimal sequences until the LM collapses into a degenerate distribution.

Fine-tuning language models via KL-regularised RL

Couldn’t we somehow include preserving distributional properties of an LM as part of the reward function? The notion of preserving distributional properties of an LM $\pi_\theta$ can be formalised as penalising for Kullback-Leibler (KL) divergence between $\pi_\theta$ and some other, pretrained LM $\pi_0$ (e.g. publicly available GPT2). Typically, $\pi_\theta$ is initialised to $\pi_0$ and then fine-tuned to maximise the following objective:

$$J_{\text{KL-RL}}(\theta) = \mathbb{E}_{x \sim \pi_\theta}[r(x)] - \beta\, \text{KL}(\pi_\theta, \pi_0),$$

where the KL is defined as

$$\text{KL}(\pi_\theta, \pi_0) = \mathbb{E}_{x \sim \pi_\theta} \log \frac{\pi_\theta(x)}{\pi_0(x)}.$$

The first term in $J_{\text{KL-RL}}(\theta)$ is equivalent to $J_{\text{RL}}(\theta)$ while the second additionally constrains $\pi_\theta$ to stay close (in terms of KL) to $\pi_0$. Almost always some reward needs to be sacrificed for that; the coefficient $\beta$ determines the trade-off of how much reward is needed to justify departing from $\pi_0$ by a certain distance. This objective is commonly used as part of a popular recipe for fine-tuning LMs termed “RL from Human Feedback” (RLHF) and works surprisingly well in practice.

$J_{\text{KL-RL}}(\theta)$ can easily be reformulated as just expected reward, the standard RL objective. We only have to define a new reward function $r'_\theta(x)$ which incorporates the original reward $r$ and the KL penalty, using the definition of KL divergence:

$$J_{\text{KL-RL}}(\theta) = \mathbb{E}_{x \sim \pi_\theta}[r'_\theta(x)],$$

where

$$r'_\theta(x) = r(x) + \beta\,(\log \pi_0(x) - \log \pi_\theta(x)).$$

This new reward function additionally rewards sequences likely under $\pi_0$ (therefore fluent) and unlikely under $\pi_\theta$ itself (an entropy bonus). But even in this formulation, $J_{\text{KL-RL}}(\theta)$ is not a standard RL objective: now the reward depends on policy parameters $\theta$, which makes it non-stationary and coupled with $\pi_\theta$. But is framing the maximisation of $J_{\text{KL-RL}}(\theta)$ as RL really necessary? In the next section, we will develop an alternative view of this objective, as an approximate solution to a Bayesian inference problem, and argue that it is more appealing than the RL framing.
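The reformulation is easy to check numerically. Below is a small sketch on a toy space of four sequences; the distributions, rewards and the value of the coefficient are made-up illustrative numbers, and the function names are ours. It computes the KL-regularised objective directly and as the expectation of the policy-dependent reward.

```python
import numpy as np

def kl(p, q):
    """KL(p, q) = E_{x~p} log(p(x) / q(x)) for strictly positive p and q."""
    return float(np.sum(p * (np.log(p) - np.log(q))))

def j_kl_rl(pi_theta, pi_0, r, beta):
    """Expected reward minus beta times the KL penalty."""
    return float(pi_theta @ r) - beta * kl(pi_theta, pi_0)

def shaped_reward(pi_theta, pi_0, r, beta):
    """r'_theta(x) = r(x) + beta * (log pi_0(x) - log pi_theta(x))."""
    return r + beta * (np.log(pi_0) - np.log(pi_theta))

pi_0 = np.array([0.4, 0.3, 0.2, 0.1])      # hypothetical pretrained LM
pi_theta = np.array([0.1, 0.2, 0.3, 0.4])  # hypothetical fine-tuned LM
r = np.array([0.0, 1.0, 2.0, 3.0])         # hypothetical rewards
beta = 0.5

direct = j_kl_rl(pi_theta, pi_0, r, beta)
as_expected_reward = float(pi_theta @ shaped_reward(pi_theta, pi_0, r, beta))
print(direct, as_expected_reward)  # the two values agree up to rounding

# The shaped reward is not a fixed function of x: it moves with the policy.
other_policy = np.array([0.25, 0.25, 0.25, 0.25])
print(shaped_reward(pi_theta, pi_0, r, beta))
print(shaped_reward(other_policy, pi_0, r, beta))
```

The last two lines make the non-stationarity concrete: the same sequence receives a different reward depending on which policy is being evaluated.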

KL-regularised RL as variational inference

Aligning a pretrained LM $\pi_0$ with preferences encoded by a reward function $r$ is essentially a Bayesian inference problem. Intuitively, Bayesian inference is the problem of updating a distribution to conform with new evidence. Given the prior probability $p(h)$ of a hypothesis $h$ and the likelihood $p(e|h)$ of evidence $e$ assuming $h$, the posterior probability of $h$ is given by Bayes’ theorem: $p(h|e) \propto p(e|h)\,p(h)$. In our setting, we’re updating $\pi_\theta$, which is initially equal to a prior $\pi_0$, to conform with evidence provided by the assumption that $\pi_\theta$ is optimal in terms of $r$. A reward function can be represented as a distribution over $\mathcal{X}$ that makes high-reward sequences more likely than low-reward sequences. A simple way of doing that is exponentiating the reward $r$ and then rescaling it to be a normalised probability distribution. Then, the posterior is given by:

$$\pi^*_{\text{KL-RL}}(x) = \frac{1}{Z}\, \pi_0(x) \exp(r(x)/\beta),$$

where $\pi_0$ is the prior, $\exp(r(x)/\beta)$ is the evidence provided by the reward function (scaled by temperature $\beta$) and $Z$ is a constant ensuring that $\pi^*_{\text{KL-RL}}$ is a normalised probability distribution. $\pi^*_{\text{KL-RL}}$ represents a version of $\pi_0$ updated to account for the reward $r$. It also happens to coincide with the optimal policy for $J_{\text{KL-RL}}$:

$$\pi^*_{\text{KL-RL}} = \arg\max_\theta J_{\text{KL-RL}}(\theta).$$

Moreover, the KL-regularised RL objective can be cast as minimising the KL divergence between the LM $\pi_\theta$ and this target distribution $\pi^*_{\text{KL-RL}}$:

$$J_{\text{KL-RL}}(\theta) = -\beta\, \text{KL}(\pi_\theta, \pi^*_{\text{KL-RL}}) + \beta \log Z,$$

where the additive term $\beta \log Z$ does not depend on $\theta$.

That’s a different KL than the KL penalty term $\text{KL}(\pi_\theta, \pi_0)$ we’ve seen before. Minimising this new KL is equivalent to variational inference, a well-known approach to approximating Bayesian inference. More formally, $J_{\text{KL-RL}}(\theta)$ is the evidence lower bound (ELBO) on the log likelihood of $\pi_\theta$ being optimal under $r$, assuming a prior $\pi_0$. Maximising this bound makes $\pi_\theta$ approximate the true posterior $\pi^*_{\text{KL-RL}}$. A derivation of these equalities can be found in the appendix below.

Why is this picture insightful? For one, it explains where the KL penalty term $\beta\, \text{KL}(\pi_\theta, \pi_0)$ in KL-regularised RL's original objective comes from. It is necessary to transform the problem from RL to minimising a divergence from a target distribution $\pi^*_{\text{KL-RL}}$. This in turn makes the distributional character of an LM a first-class citizen, which explains why KL-regularised RL is able to maintain the fluency and diversity of the original LM $\pi_0$.
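On a small discrete space the target distribution can be written down exactly, which makes these claims checkable. The sketch below is an illustration under made-up numbers (prior, rewards and temperatures are all hypothetical): it builds the posterior in closed form, verifies that no randomly drawn distribution scores higher on the KL-regularised objective, and shows how the temperature controls how far the posterior moves from the prior.

```python
import numpy as np

def kl(p, q):
    """KL(p, q) = E_{x~p} log(p(x) / q(x)) for strictly positive p and q."""
    return float(np.sum(p * (np.log(p) - np.log(q))))

def entropy(p):
    return float(-np.sum(p * np.log(p)))

def j_kl_rl(pi, pi_0, r, beta):
    return float(pi @ r) - beta * kl(pi, pi_0)

def posterior(pi_0, r, beta):
    """pi*(x) = pi_0(x) * exp(r(x) / beta) / Z; returns (pi*, Z)."""
    weights = pi_0 * np.exp(r / beta)
    z = float(weights.sum())
    return weights / z, z

pi_0 = np.array([0.4, 0.3, 0.2, 0.1])  # hypothetical prior LM
r = np.array([0.0, 1.0, 2.0, 3.0])     # hypothetical rewards
beta = 0.5
pi_star, z = posterior(pi_0, r, beta)

# Compare the posterior against many random candidate distributions.
rng = np.random.default_rng(0)
candidates = rng.dirichlet(np.ones(4), size=1000)
best_candidate_value = max(j_kl_rl(pi, pi_0, r, beta) for pi in candidates)
posterior_value = j_kl_rl(pi_star, pi_0, r, beta)

# Check J(pi) = -beta * KL(pi, pi*) + beta * log Z for every candidate.
identity_gaps = [
    abs(j_kl_rl(pi, pi_0, r, beta) - (beta * np.log(z) - beta * kl(pi, pi_star)))
    for pi in candidates
]

# Small beta trusts the reward (low entropy); large beta trusts the prior.
entropies = {b: entropy(posterior(pi_0, r, b)[0]) for b in (10.0, 1.0, 0.1)}
print(posterior_value, best_candidate_value, max(identity_gaps), entropies)
```

In this toy setting, shrinking the temperature towards zero recovers the collapsed policy of plain reward maximisation, while any positive temperature keeps some of the prior's spread.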

Separation of modelling and inference

In the last section, we argued that KL-regularised RL is secretly variational inference and that this vantage point elegantly explains why it works. Here, we explore a different advantage of the Bayesian perspective. Essentially, what it says is that aligning an LM with human preferences is a two-step process:

  1. First, you define a distribution specifying the desired behaviour of your LM. A principled way of doing that is using Bayes’ rule to define a posterior like $\pi^*_{\text{KL-RL}}$.
  2. Second, you figure out how to sample from your posterior.

These two steps roughly correspond to what’s known as modelling and inference in probabilistic programming. Modelling is encoding your knowledge in probabilistic terms (usually by defining a probabilistic graphical model) while inference corresponds to using this model to answer queries. It’s hard to overstate how useful (theoretically and practically) separating these two concerns could be. Let’s discuss these two steps separately below.

Modelling. For LMs, the modelling step is relatively easy: our LM is natively a probability distribution and autoregressive models are great for both sampling and evaluating likelihoods. Most modelling decisions are usually around interpreting human preferences in probabilistic terms. Turning a reward function $r$ into a distribution by exponentiating it ($\frac{1}{Z}\exp(r(x))$) is one idea, but there are other ones. Here are a few:

  1. A standard reward model assigns each sample $x$ a single, scalar score $r(x)$. Maybe we’d like instead to have a model that captures a distribution of human preferences associated with a single sample $x$ and use that as part of our posterior?
  2. A simpler variant of the previous idea is to use one of multiple ways of eliciting uncertainty estimates from a standard (scalar) reward model. What’s nice about uncertainties is that they tell the LM that some rewards $r(x)$ are high-precision (therefore, the LM should update a lot based on them) while others are uncertain (perhaps $x$ is out of distribution for the reward model) and the LM should tread lightly.
  3. Finally, maybe our preferences are binary, e.g. the LM can never, ever say anything very offensive but is free to behave normally otherwise. Then, we could define $\pi^*(x) = \frac{1}{Z}\pi_0(x)\,b(x)$ where $b(x) = 0$ if $x$ contains a curse and $b(x) = 1$ otherwise. Then, sequences $x$ containing curses have probability zero according to $\pi^*$ (hence $\pi^*$ is non-offensive) but all other strings keep their original probability $\pi_0(x)$ up to the normalising constant $Z$ (hence no degeneration).

All the posteriors mentioned above are non-parametric: they exist as mathematical objects, but we don’t know the set of Transformer weights $\theta$ that corresponds to them. Moreover, in general these posteriors lie outside the class of probability distributions representable by a Transformer LM. Figuring out an actual piece of code that generates samples matching this posterior distribution constitutes the inference problem.

Inference. Broadly, there are two classes of algorithms for inference on probabilistic graphical models: variational inference and sampling-based approaches. Variational inference tries to find the set of Transformer weights that give rise to a distribution closest (in terms of KL) to the true posterior. Sampling-based techniques, such as MCMC, don’t represent the true posterior explicitly, but compute samples from a distribution resembling the true posterior.

In the previous section, we’ve shown that KL-regularised RL corresponds to inference via variational inference. But sampling-based inference algorithms also have analogues for LMs as decoding-time methods. Decoding-time methods boil down to simulating a posterior, aligned LM by modifying the generation procedure applied on top of the original LM. The simplest example is also the most popular alignment method used in multiple production systems: filtering (also known as rejection sampling). You can simulate a non-offensive LM by using the following procedure: if the LM generates an offensive sample, you discard it and try again. More elaborate decoding-time methods include weighted decoding and PPLM.
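Filtering can be sketched in a few lines. The example below is a toy illustration, not a production recipe: the "LM" is a hypothetical table of four strings with made-up probabilities, and the offensiveness check is a hypothetical banned-word lookup. It compares the empirical distribution of filtered samples with the exact binary-preference posterior (prior times a 0/1 indicator, renormalised).

```python
import random
from collections import Counter

# Hypothetical prior "LM": a tiny table of sequences and their probabilities.
PRIOR = {"hello there": 0.4, "nice day": 0.3, "damn it": 0.2, "good grief": 0.1}
BANNED_WORDS = {"damn"}  # hypothetical curse list

def b(x):
    """Binary preference: 0 if x contains a banned word, 1 otherwise."""
    return 0 if any(word in BANNED_WORDS for word in x.split()) else 1

def sample_prior(rng):
    """Draw one sequence from the prior."""
    return rng.choices(list(PRIOR), weights=list(PRIOR.values()), k=1)[0]

def sample_filtered(rng, max_tries=1000):
    """Rejection sampling: resample until the sequence passes the filter."""
    for _ in range(max_tries):
        x = sample_prior(rng)
        if b(x):
            return x
    raise RuntimeError("no acceptable sample found")

def exact_posterior():
    """pi*(x) = pi_0(x) * b(x) / Z, computed by enumeration."""
    z = sum(p * b(x) for x, p in PRIOR.items())
    return {x: p * b(x) / z for x, p in PRIOR.items()}

rng = random.Random(0)
n = 20000
counts = Counter(sample_filtered(rng) for _ in range(n))
empirical = {x: counts[x] / n for x in PRIOR}
target = exact_posterior()
for x in PRIOR:
    print(f"{x!r}: empirical={empirical[x]:.3f} target={target[x]:.3f}")
```

No weights are trained here; the cost is paid at generation time instead, since on average one in every $Z$-th fraction of draws is accepted (here $Z = 0.8$, so about one draw in five is discarded).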

To summarise, we’ve seen that the Bayesian view provides a nice unifying perspective on fine-tuning and decoding-time approaches to LM alignment. They mirror variational inference and sampling-based inference algorithms for probabilistic graphical models, respectively, with their respective trade-offs (training efficiency vs generation efficiency). But a more fundamental advantage, to our mind, is what we started with: the separation of concerns between defining a desired behaviour of an LM and approximating it. The choice of posterior is independent of how you’d like to approximate it. You can therefore separate two failure modes: misspecifying the model (i.e. not capturing the preferences) and failing to approximate the model well enough. In principle, you could try to approximate KL-regularised RL’s posterior using a fancy decoding algorithm and validate whether this distribution indeed captures your preferences, without doing costly training. If there’s an efficient way of doing that, then maybe training an actual LM (one allowing for fast generation) could be delayed until prototyping the posterior is done.

Is RL a good framework for LM alignment?

Let me end with a more philosophical implication of the Bayesian perspective on KL-regularised RL. If it’s the Bayesian perspective that justifies theoretically using KL-regularised RL, is the original perspective — the RL perspective — still useful?

There is a family of other divergence minimisation approaches to fine-tuning LMs which are not equivalent to RL. Take Generative Distributional Control (GDC), an approach to fine-tuning LMs that obtains results comparable with KL-regularised RL but minimises a slightly different divergence:

$$J_{\text{GDC}}(\theta) = -\text{KL}(p_{\text{GDC}}, \pi_\theta),$$

where $p_{\text{GDC}}$ is an exponential family distribution similar to $\pi^*_{\text{KL-RL}}$. The difference between $J_{\text{GDC}}$ and $J_{\text{KL-RL}}$ is in the order of arguments (forward vs reverse KL). However, $J_{\text{GDC}}$ is no longer equivalent to RL because the expectation in the forward KL divergence is with respect to the target distribution $p_{\text{GDC}}$, not $\pi_\theta$. Similarly, the standard supervised training objective can be seen as minimising $\text{KL}(\pi^*_{\text{MLE}}, \pi_\theta)$, a divergence from the empirical distribution $\pi^*_{\text{MLE}}$ provided by the training set.
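The asymmetry between the two directions is easy to see numerically. The sketch below uses two made-up distributions over four sequences (the names `target` and `pi_theta` and their values are illustrative assumptions). It computes both KL directions and shows that the forward direction decomposes into a cross-entropy term, the only part that depends on the model, minus the entropy of the target, which is why minimising it looks like maximum-likelihood training on samples from the target rather than like RL.

```python
import numpy as np

def reverse_kl(pi_theta, target):
    """KL(pi_theta, target): the expectation is taken under pi_theta."""
    return float(np.sum(pi_theta * (np.log(pi_theta) - np.log(target))))

def forward_kl(target, pi_theta):
    """KL(target, pi_theta): the expectation is taken under the target."""
    return float(np.sum(target * (np.log(target) - np.log(pi_theta))))

target = np.array([0.5, 0.3, 0.15, 0.05])  # hypothetical target distribution
pi_theta = np.array([0.7, 0.1, 0.1, 0.1])  # hypothetical current LM

rev = reverse_kl(pi_theta, target)  # direction used by KL-regularised RL
fwd = forward_kl(target, pi_theta)  # direction used by GDC and supervised training

# Forward KL = cross-entropy(target, pi_theta) - entropy(target).
cross_entropy = float(-np.sum(target * np.log(pi_theta)))
target_entropy = float(-np.sum(target * np.log(target)))

print("reverse KL:", rev)
print("forward KL:", fwd, "=", cross_entropy, "-", target_entropy)
```

Estimating the reverse direction requires samples from the model itself (as in RL), while estimating the forward direction requires samples, or importance-weighted samples, from the target.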

One can therefore mount a double dissociation argument in favour of the divergence minimisation perspective on KL-regularised RL: RL without distribution matching fails, divergence minimisation without RL works. Therefore, it’s the divergence minimisation aspect of KL-regularised RL that accounts for its success. In consequence, calling it RL is just a redescription of it that happens to be correct under a particular choice of reward function $r'_\theta$, but does not provide motivation for this choice of $r'_\theta$ and does not hold for alternative divergence minimisation approaches to fine-tuning LMs such as GDC.

The divergence minimisation perspective on KL-regularised RL we presented stems from a general framework known as control as inference. Control as inference provides a formalisation of intelligent decision making as inference on a probabilistic graphical model representing the agent, its preferences and environmental dynamics. While control as inference is typically considered with graphical models parameterised to make it equivalent to RL, it does not have to be. Moreover, there are frameworks such as active inference or APD that further generalise control as inference to a general principle of minimising the KL divergence from a probability distribution representing desired behaviour of the agent. In contrast with RL, they conceptualise the agent as a generative model, not as a decision rule represented as a probability distribution out of convenience. Therefore, they naturally avoid the distribution collapse problem and preserve the distributional properties of the agent. What if RL simply isn’t an adequate formal framework for problems such as aligning LMs?

Mathematical appendix

This section is just a step-by-step derivation of the equivalence between the optimal policy of KL-regularised RL and the Bayesian posterior, and of the equivalence between KL-regularised RL’s objective and variational inference’s ELBO.

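Let’s assume we have a prior distribution over sequences of tokens $\pi_0(x)$ and a reward function $r$ which is (for technical reasons) always negative (from $-\infty$ to 0). We can also represent $r$ as a binary random variable $\mathcal{O}$ (the optimality variable). $\mathcal{O} = 1$ if a certain LM $\pi$ is optimal. We can define $\mathcal{O}$ in terms of $r$ as

$$p(\mathcal{O} = 1 | x) = \exp(r(x)),$$

which is a valid probability (it lies between 0 and 1) because $r(x)$ is never positive. For instance, if $r(x)$ is a log probability that a sequence $x$ is non-offensive, $p(\mathcal{O} = 1 | x)$ is a probability that $x$ is non-offensive and the marginal $p(\mathcal{O} = 1)$ is the average non-offensiveness score of $\pi_0$ (or a probability that a random sample from $\pi_0$ is non-offensive). The problem of aligning LMs can be seen as inferring $p(x | \mathcal{O} = 1)$, a distribution over sequences of tokens conditioned on being non-offensive. This can be computed by applying Bayes’ rule as

$$p(x | \mathcal{O} = 1) = \frac{p(\mathcal{O} = 1 | x)\, p(x)}{p(\mathcal{O} = 1)} = \frac{1}{Z}\, \pi_0(x) \exp(r(x)/\beta),$$

where we chose the prior $p(x) = \pi_0(x)$, redefined the marginal $p(\mathcal{O} = 1)$ as the normalising constant $Z$, used the definition of $p(\mathcal{O} = 1 | x)$ and chose $\beta = 1$. $p(x | \mathcal{O} = 1)$ here is equivalent to $\pi^*_{\text{KL-RL}}$, the optimal policy under $J_{\text{KL-RL}}$ (up to the choice of $\beta$, which can be absorbed into $r$ anyway).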

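$p(x | \mathcal{O} = 1)$ is a non-parametric distribution: it doesn’t have to lie in the family of distributions representable by a parametric model. In general, we’d like to find a parametric model $\pi_\theta$ closest to $\pi^*_{\text{KL-RL}}$. This can be formalised as finding $\pi_\theta$ that minimises $\text{KL}(\pi_\theta, \pi^*_{\text{KL-RL}})$. Here, however, we will derive this objective from a yet more general perspective: inferring a random latent variable $x$ that best explains the assumption that a certain LM $\pi$ is optimal given a prior $\pi_0(x)$. This can be seen as maximising the log-likelihood of $\mathcal{O} = 1$ via variational inference:

$$\begin{aligned}
\log p(\mathcal{O} = 1) &= \log \sum_x p(\mathcal{O} = 1, x) \\
&= \log \Big[ \sum_x p(\mathcal{O} = 1 | x)\, \pi_0(x) \Big] \\
&= \log \Big[ \sum_x \pi_\theta(x)\, p(\mathcal{O} = 1 | x)\, \frac{\pi_0(x)}{\pi_\theta(x)} \Big] \\
&\geq \sum_x \pi_\theta(x) \log \Big[ p(\mathcal{O} = 1 | x)\, \frac{\pi_0(x)}{\pi_\theta(x)} \Big] \\
&= \mathbb{E}_{x \sim \pi_\theta} \log \Big[ \exp(r(x))\, \frac{\pi_0(x)}{\pi_\theta(x)} \Big].
\end{aligned}$$

In this derivation, we first introduce a latent variable $x$ using the sum rule of probability, factorise a joint distribution, introduce a variational distribution $\pi_\theta$ over that latent variable, use Jensen’s inequality to obtain a bound (ELBO) and, finally, use the definition of $p(\mathcal{O} = 1 | x)$.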

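This new bound can be alternatively expressed in two different ways. One is just the KL-regularised RL objective $J_{\text{KL-RL}}(\theta)$ with $\beta = 1$:

$$\mathbb{E}_{x \sim \pi_\theta} \log \Big[ \exp(r(x))\, \frac{\pi_0(x)}{\pi_\theta(x)} \Big] = \mathbb{E}_{x \sim \pi_\theta}[r(x)] - \text{KL}(\pi_\theta, \pi_0).$$

The second one is equal (up to an additive constant $\log Z$) to the negative $\text{KL}(\pi_\theta, \pi^*_{\text{KL-RL}})$:

$$\mathbb{E}_{x \sim \pi_\theta} \log \Big[ \exp(r(x))\, \frac{\pi_0(x)}{\pi_\theta(x)} \Big] = -\mathbb{E}_{x \sim \pi_\theta} \log \frac{\pi_\theta(x)}{\pi_0(x) \exp(r(x))} = -\text{KL}(\pi_\theta, \pi^*_{\text{KL-RL}}) + \log Z,$$

where $\pi^*_{\text{KL-RL}}(x) = \frac{1}{Z}\pi_0(x)\exp(r(x))$ is the target distribution (or optimal policy for $J_{\text{KL-RL}}$). Their equivalence proves that KL-regularised reward maximisation is equivalent to minimising divergence from $\pi^*_{\text{KL-RL}}$.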
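The chain of equalities can be sanity-checked on a toy example. The sketch below assumes a made-up prior and made-up non-positive rewards over four sequences, with the temperature fixed to 1 as in the appendix. It evaluates the bound for an arbitrary variational distribution and for the exact posterior.

```python
import numpy as np

def kl(p, q):
    """KL(p, q) = E_{x~p} log(p(x) / q(x)) for strictly positive p and q."""
    return float(np.sum(p * (np.log(p) - np.log(q))))

pi_0 = np.array([0.4, 0.3, 0.2, 0.1])   # hypothetical prior
r = np.array([-2.0, -1.0, -0.5, -0.1])  # non-positive rewards, so exp(r) <= 1
likelihood = np.exp(r)                  # p(O=1 | x)
z = float(pi_0 @ likelihood)            # marginal p(O=1), the constant Z
pi_star = pi_0 * likelihood / z         # exact posterior p(x | O=1)
log_evidence = float(np.log(z))

def elbo(pi_theta):
    """E_{x~pi_theta} log[exp(r(x)) * pi_0(x) / pi_theta(x)]."""
    return float(pi_theta @ (r + np.log(pi_0) - np.log(pi_theta)))

pi_theta = np.array([0.25, 0.25, 0.25, 0.25])  # arbitrary variational distribution

bound = elbo(pi_theta)
as_kl_rl = float(pi_theta @ r) - kl(pi_theta, pi_0)    # first form (beta = 1)
as_divergence = -kl(pi_theta, pi_star) + log_evidence  # second form

print("log p(O=1):", log_evidence)
print("ELBO:", bound, as_kl_rl, as_divergence)
print("ELBO at the posterior:", elbo(pi_star))
```

The three expressions for the bound agree, the bound never exceeds the log evidence, and it becomes tight exactly when the variational distribution equals the posterior, where the divergence term vanishes.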

This blog post is largely based on a workshop paper with Ethan Perez and Chris Buckley. It also benefited from discussions with and comments from Hady Elsahar, Germán Kruszewski, Marc Dymetman and Jérémy Scheurer.

12 comments

I broadly agree with this perspective, and think the modeling vs inference distinction is a valuable one to make.

That said, it seems to me that in practice you often should be trying to converge to a zero entropy policy. The optimal action is not random; we want the model to be picking the output that looks best in light of its beliefs. (Old Eliezer post.)

For some applications you are sampling from your model multiple times and ensembling. In this case randomness can help if you have no memory. But in the same setting, optimizing the correct reward function ("solve the problem given that none of your previous tries worked") wouldn't involve converging to zero entropy, because the quality of an output decreases naturally as it becomes more similar to your previous outputs. (Though I think this is a minority of language model sampling in practice.)

It seems like the ideal way forward is to more accurately capture what you actually care about, then optimize that; staying close to the original distribution feels like more of a hack to me. It seems like you view the original distribution of webtext as more principled or fundamental than I do, but I'm not sure what accounts for that difference. For example, if you want to have a model that creates 10 diverse samples when run 10 times in parallel, I don't see why "stay close to webtext" would be a particularly good way to do that. It seems like incorporating diversity directly into the reward function is the way you'd like to go in the long run, it poses some challenges in the short run, and "match the webtext distribution" is a shortcut you might use before solving the fundamental problem.

I'm personally mostly interested in KL minimization in order to ensure you continue exploring and to ensure the policy changes slowly enough for the reward model to keep up (avoiding Goodhart destroying the approximation quality of the preference model, and perhaps more speculatively avoiding it destroying the approximation quality of a limited overseer human). But as Jacob Hilton points out, right now it seems like a wide range of options will work out OK for that.

Thanks for sharing your thoughts, I found these remarks extremely insightful!

“It seems like the ideal way forward is to more accurately capture what you actually care about, then optimize that; staying close to the original distribution feels like more of a hack to me. It seems like you view the original distribution of webtext as more principled or fundamental than I do, but I'm not sure what accounts for that difference.”

A reply that comes to mind is that maybe being grounded in human knowledge, reasoning rules and values represented in web text has inherent value? Maybe web text is already approximately aligned with human preferences and you only want to tweak that distribution a bit to match true human preferences? Assume that's the case. Then, we can decompose LM alignment into (i) learning the web text distribution and (ii) learning how to warp the web text distribution. It seems that (ii) is easier than just learning aligned behaviour from scratch: your reward model doesn't have to work well on arbitrary text but only on text from distributions similar to webtext.

Another way of phrasing that point: maybe the assumption that you can have a perfect reward model is unrealistic and we can offload some of the complexity of learning a reward model to a prior given by web text? Or more philosophically, if you're a Bayesian, you shouldn't trust your reward model blindly, you should still have some prior.

Thanks for this post, it's clear and insightful about RLHF.

From an alignment perspective, would you say that your work gives evidence that we should focus most of the energy on finding guarantees about the distribution that we're aiming for and debugging problems there, rather than thinking about the guarantees of the inference?

(I still expect that we want to understand the inference better and how it can break, but your post seems to push towards a lesser focus on that part)

I'm glad you found our post insightful!

I'm not sure what is the best energy allocation between modelling and inference here. I think, however, that the modelling part is more neglected (the target distribution is rarely even considered as something that can be written down and analysed). Moreover, designing good target distributions can be quite alignment-specific whereas designing algorithms for inference in probabilistic graphical models is an extremely generic research problem so we can expect progress here anyway.

Great post! This seems like a useful perspective to keep in mind.

Somewhat orthogonally to the theoretical picture, I expect that in the current regime (only optimizing the policy a small amount), any method that does a reasonable job of maximizing reward while controlling how much the policy changes can be made to work in practice. For example, if PPO is tuned appropriately, the KL penalty term can be removed from the reward entirely - instead, PPO's implicit "local" KL penalty controls the rate of policy change.

If we were in the regime of optimizing the policy significantly more, experience from traditional RL suggests that there would be an exploration-exploitation trade-off, which is something that the RL perspective may again offer insight into.

“I expect that in the current regime (only optimizing the policy a small amount), any method that does a reasonable job of maximizing reward while controlling how much the policy changes can be made to work in practice”

Yes, that seems plausible. Though as you said, most methods that only change the policy a bit (early stopping, clipping in PPO) do that via implicit KL penalties and still can be seen as updating a prior.

“there would be an exploration-exploitation trade-off, which is something that the RL perspective may again offer insight into.”

Definitely exploration-exploitation issues could make the distribution collapse more severe and traditional RL tricks could help with that. But I still believe distribution collapse does not reduce to insufficient exploration and good exploration alone won't solve it. In this specific instance, failing to find the optimal policy is not the problem, the optimal policy itself is the problem.

You mention converging to a deterministic policy is bad because of repetition, but did I miss you addressing that it's also bad because we want diversity? (Edit: now that I reread that sentence, it makes no sense. Sorry!) In some sense we don't want RL in the limit, we want something a little more aware that we want to sample from a distribution and get lots of different continuations that are all pretty good.

Is a(x) in the formulas supposed to be pi_0(x)?

good catch, yes, thanks!

Do you think these insights would generalise to the case where the language model may be interacting with some system during this fine-tuning phase? For example, if it generates queries to an external search engine or API, or has dialogue with a human, then the optimal policy is no longer equivalent to just generating the correct output distribution, as it now also involves environment observations. This setting makes the distinction between a generative model and a policy clearer, and maybe changes the relevance of this statement:

“The problem with the RL objective is that it treats the LM as a policy, not as a generative model. While a generative model is supposed to capture a diverse distribution of samples, a policy is supposed to choose the optimal action.”

That is, in these settings we do want an optimal policy and not a generative model. This is quite similar to what Paul was saying as well.

Further, if we see aligning language models as a proxy for aligning future strong systems, it seems likely that these systems will be taking multiple steps of interaction in some environment, rather than just generating one (or several) sequences without feedback.

“Do you think these insights would generalise to the case where the language model may be interacting with some system during this fine-tuning phase? For example, if it generates queries to an external search engine or API, or has dialogue with a human, then the optimal policy is no longer equivalent to just generating the correct output distribution, as it now also involves environment observations.”

That's a good point and helps to make a distinction between generative models and policies. In the interactive case, your policy pi(a|s) is a conditional distribution. You can equivalently view it as a collection of unconditional distributions {pi_s(a)}, one for each s, and for each of these you are likely to also have distribution collapse (single best action for a given state). Arguably, that's what you want in RL.

So I think it mostly comes down to a philosophical difference. Do you want your LM to be a decision-maker acting in a world or a model of some probability distribution over texts? If you want a decision-maker and training on language is just a scaffolding to get you there, maybe indeed staying close to the original distribution only has instrumental value?

But what if what you want is just an oracle-type conversational AI: a knowledge base and a common-sense reasoner. Maybe in this case staying close to human knowledge and inference rules represented in language is of inherent value?

“So I think it mostly comes down to a philosophical difference. Do you want your LM to be a decision-maker acting in a world or a model of some probability distribution over texts? If you want a decision-maker and training on language is just a scaffolding to get you there, maybe indeed staying close to the original distribution only has instrumental value?

“But what if what you want is just an oracle-type conversational AI: a knowledge base and a common-sense reasoner. Maybe in this case staying close to human knowledge and inference rules represented in language is of inherent value?”

I feel like this wasn't exactly made clear in the post/paper that this was the motivation. You state that distribution collapse is bad without really justifying it (in my reading). Clearly distributional collapse to a degenerate bad output is bad, and also will often stall learning so is bad from an optimisation perspective as well (as it makes exploration much less likely), but this seems different from distributional collapse to the optimal output. For example, when you say

“if $x^*$ is truly the best thing, we still wouldn’t want the LM to generate only $x^*$”

I think I just disagree, in the case where we're considering LLMs as a model for future agent-like systems that we will have to align, which to me is the reason they're useful for alignment research. If there's a normative claim that diversity is important, then you should just have that in your objective/algorithm.

I think the reason KL divergence is included in RLHF is an optimisation hack to make sure it does well. Maybe that's revealed that for alignment we actually wanted the Bayesian posterior distribution you describe rather than just the optimal distribution according to the reward function (i.e. a hardmax rather than a softmax on the reward over trajectories), although that seems to be an empirical question whether that was our preference all along or it's just useful in the current regime.