Tomek Korbak

I'm a PhD student at the University of Sussex working on aligning language models with human preferences.

RL with KL penalties is better seen as Bayesian inference

good catch, yes, thanks!

RL with KL penalties is better seen as Bayesian inference

Thanks for sharing your thoughts, I found these remarks extremely insightful!

It seems like the ideal way forward is to more accurately capture what you actually care about, then optimize that; staying close to the original distribution feels like more of a hack to me. It seems like you view the original distribution of webtext as more principled or fundamental than I do, but I'm not sure what accounts for that difference.

A reply that comes to mind is that maybe being grounded in the human knowledge, reasoning rules and values represented in web text has inherent value? Maybe web text is already approximately aligned with human preferences, and you only want to tweak that distribution a bit to match true human preferences? Assume that's the case. Then we can decompose LM alignment into (i) learning the web text distribution and (ii) learning how to warp the web text distribution. It seems that (ii) is easier than learning aligned behaviour from scratch: your reward model doesn't have to work well on arbitrary text, only on text from distributions similar to web text.

Another way of phrasing that point: maybe the assumption that you can have a perfect reward model is unrealistic, and we can offload some of the complexity of learning a reward model to a prior given by web text? Or, more philosophically: if you're a Bayesian, you shouldn't trust your reward model blindly; you should still have some prior.
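A minimal sketch of the point about the reward model only needing to work near web text, assuming the standard result that the optimum of a KL-penalized reward objective is the prior reweighted by exp(r(x)/β). The four "texts", their prior probabilities and the reward values are all hypothetical. Anything the prior assigns zero probability stays at zero, however much a (possibly mistaken) reward model likes it.

```python
import math

def tilt(prior, reward, beta):
    """Return the distribution proportional to prior(x) * exp(reward(x) / beta).

    prior: dict mapping each text x to its prior probability (e.g. a web text LM)
    reward: dict mapping each text x to its reward r(x)
    beta: KL coefficient (> 0); smaller beta trusts the reward model more.
    No max-subtraction is done, so this is only meant for small r / beta.
    """
    weights = {x: p * math.exp(reward[x] / beta) for x, p in prior.items()}
    z = sum(weights.values())  # normalising constant Z
    return {x: w / z for x, w in weights.items()}

# Hypothetical toy "web text" prior; "gibberish" has zero prior mass.
prior = {"polite": 0.4, "neutral": 0.4, "rude": 0.2, "gibberish": 0.0}
# Hypothetical reward model that (wrongly) loves gibberish.
reward = {"polite": 1.0, "neutral": 0.0, "rude": -1.0, "gibberish": 100.0}

posterior = tilt(prior, reward, beta=1.0)
```

The reward model's error on "gibberish" never matters here, because the prior already rules that text out; only its judgements on text the prior supports shape the result.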

RL with KL penalties is better seen as Bayesian inference

Do you think these insights would generalise to the case where the language model may be interacting with some system during this fine-tuning phase? For example, if it generates queries to an external search engine or API, or has dialogue with a human, then the optimal policy is no longer equivalent to just generating the correct output distribution, as it now also involves environment observations.

That's a good point, and it helps draw a distinction between generative models and policies. In the interactive case, your policy pi(a|s) is a conditional distribution. You can equivalently view it as a collection of unconditional distributions {pi_s(a)}, one for each state s, and each of these is likely to undergo distribution collapse as well (a single best action for a given state). Arguably, that's what you want in RL.

So I think it mostly comes down to a philosophical difference. Do you want your LM to be a decision-maker acting in the world, or a model of some probability distribution over texts? If you want a decision-maker and training on language is just scaffolding to get you there, maybe staying close to the original distribution indeed has only instrumental value?

But what if what you want is just an oracle-type conversational AI: a knowledge base and a common-sense reasoner? Maybe in this case staying close to the human knowledge and inference rules represented in language has inherent value?
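To make the per-state view of pi(a|s) concrete, here is a sketch assuming the KL-regularized optimum for each fixed state s is pi*_s(a) ∝ pi0(a|s) exp(r(s, a)/β). The state, its three actions and their rewards are hypothetical. As β shrinks (the KL penalty is switched off), each pi_s collapses onto its single best action.

```python
import math

def kl_optimal_policy(prior_probs, rewards, beta):
    """pi*_s(a) proportional to pi0(a|s) * exp(r(s, a) / beta) for one fixed state s."""
    # Subtract the largest scaled reward before exponentiating, for numerical stability.
    m = max(r / beta for r in rewards)
    w = [p * math.exp(r / beta - m) for p, r in zip(prior_probs, rewards)]
    z = sum(w)
    return [x / z for x in w]

# Hypothetical state s with three actions: uniform prior policy, distinct rewards.
pi0_s = [1 / 3, 1 / 3, 1 / 3]
r_s = [1.0, 0.9, 0.0]

for beta in [10.0, 1.0, 0.1, 0.01]:
    print(beta, [round(p, 3) for p in kl_optimal_policy(pi0_s, r_s, beta)])
```

With a large β the policy stays close to the uniform prior; with a small β nearly all mass goes to the first action.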

RL with KL penalties is better seen as Bayesian inference

I'm glad you found our post insightful!

I'm not sure what the best allocation of effort between modelling and inference is here. I think, however, that the modelling part is more neglected (the target distribution is rarely even considered as something that can be written down and analysed). Moreover, designing good target distributions can be quite alignment-specific, whereas designing algorithms for inference in probabilistic graphical models is an extremely generic research problem, so we can expect progress there anyway.
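As an illustration of what "writing down the target distribution" can look like, here is the standard derivation in notation assumed for this sketch: original LM π_0, reward r, KL coefficient β > 0.

```latex
J(\pi) = \mathbb{E}_{x \sim \pi}\big[r(x)\big] - \beta\,\mathrm{KL}(\pi \,\|\, \pi_0)

\pi^*(x) = \frac{1}{Z}\,\pi_0(x)\,\exp\!\big(r(x)/\beta\big),
\qquad Z = \sum_x \pi_0(x)\,\exp\!\big(r(x)/\beta\big)

J(\pi) = -\beta\,\mathrm{KL}(\pi \,\|\, \pi^*) + \beta \log Z
```

Since β log Z does not depend on π, maximizing J is the same as minimizing KL(π‖π*). The KL-penalized optimum is therefore exactly π*: the prior π_0 updated by the evidence exp(r/β).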

RL with KL penalties is better seen as Bayesian inference

I expect that in the current regime (only optimizing the policy a small amount), any method that does a reasonable job of maximizing reward while controlling how much the policy changes can be made to work in practice

Yes, that seems plausible. Though, as you said, most methods that change the policy only a bit (early stopping, clipping in PPO) do so via implicit KL penalties and can still be seen as updating a prior.
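For comparison with an explicit KL penalty, here is a minimal sketch of the per-sample PPO-Clip surrogate (Schulman et al., 2017), one of the implicit mechanisms mentioned above. Once the probability ratio leaves [1 − ε, 1 + ε] in the direction the advantage favours, the objective stops rewarding further movement, which keeps each update close to the old policy. The ε value and example numbers are assumptions.

```python
def ppo_clip_objective(ratio, advantage, eps=0.2):
    """Per-sample PPO-Clip surrogate: min(ratio * A, clip(ratio, 1 - eps, 1 + eps) * A).

    ratio: pi_new(a|s) / pi_old(a|s) for the sampled action
    advantage: advantage estimate A for that action
    eps: clipping range (0.2 is assumed here as a typical value)
    """
    clipped_ratio = min(max(ratio, 1.0 - eps), 1.0 + eps)
    return min(ratio * advantage, clipped_ratio * advantage)

# Pushing a good action's probability ratio from 1.2 to 2.0 earns nothing extra:
gain_at_1_2 = ppo_clip_objective(1.2, advantage=1.0)
gain_at_2_0 = ppo_clip_objective(2.0, advantage=1.0)
```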

there would be an exploration-exploitation trade-off, which is something that the RL perspective may again offer insight into.

Exploration-exploitation issues could definitely make distribution collapse more severe, and traditional RL tricks could help with that. But I still believe distribution collapse does not reduce to insufficient exploration, and good exploration alone won't solve it. In this specific instance, failing to find the optimal policy is not the problem; the optimal policy itself is the problem.
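One way to see that collapse is not an exploration failure: compute the policy gradient exactly by summing over every outcome, so nothing is ever left unexplored. The sketch below (toy prior and rewards are hypothetical; the prior is assumed to give every outcome positive probability) runs exact gradient ascent for a softmax policy. With β = 0 it still piles nearly all mass on the single best outcome; with β > 0 it converges to the prior tilted by exp(r/β).

```python
import math

def softmax(logits):
    m = max(logits)
    exps = [math.exp(l - m) for l in logits]
    z = sum(exps)
    return [e / z for e in exps]

def exact_ascent(rewards, prior, beta, steps=2000, lr=0.5):
    """Gradient ascent on J(pi) = E_pi[r] - beta * KL(pi || prior) for a softmax policy.

    The gradient is computed exactly by summing over every outcome, so there is no
    sampling and hence no exploration problem. Assumes prior[x] > 0 for all x.
    """
    logits = [math.log(q) for q in prior]  # start at the prior
    for _ in range(steps):
        pi = softmax(logits)
        # Per-outcome shaped reward f(x) = r(x) - beta * log(pi(x) / prior(x)).
        shaped = []
        for r, p, q in zip(rewards, pi, prior):
            penalty = beta * math.log(p / q) if beta > 0 else 0.0
            shaped.append(r - penalty)
        baseline = sum(p * f for p, f in zip(pi, shaped))
        # dJ/dlogit_i = pi_i * (f_i - baseline)
        logits = [l + lr * p * (f - baseline) for l, p, f in zip(logits, pi, shaped)]
    return softmax(logits)

prior = [0.5, 0.3, 0.2]    # hypothetical original policy pi_0
rewards = [1.0, 0.9, 0.0]  # hypothetical rewards; outcome 0 is best

collapsed = exact_ascent(rewards, prior, beta=0.0)    # plain expected-reward maximization
regularized = exact_ascent(rewards, prior, beta=1.0)  # KL-penalized objective
```

The difference between the two runs comes entirely from the objective, not from how well it is optimized.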

I really liked the post, and the agenda of improving safety through generative modelling is close to my heart.

But you still need online access to the MDP (i.e. the reward function and transition function), don't you? And it's access to the MDP that drives novelty and improvement. If you were just sampling whole trajectories from the model (asking the model itself to simulate the reward function and transition model) and feeding them back into the model, you shouldn't expect any change (on average). Your gradient updates will cancel out; that's a consequence of the expected-grad-log-prob lemma ($\mathbb{E}_{x \sim \pi_\theta} \nabla_\theta \log \pi_\theta(x) = 0$).
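A small numerical check of the expected-grad-log-prob lemma, assuming a one-step categorical softmax policy as a stand-in for a full trajectory distribution (the logits are hypothetical). The exact expectation of ∇ log π over the policy's own samples is zero. Maximum-likelihood updates on self-generated samples therefore average out to zero-mean noise.

```python
import math
import random

def softmax(logits):
    m = max(logits)
    exps = [math.exp(l - m) for l in logits]
    z = sum(exps)
    return [e / z for e in exps]

def grad_log_prob(logits, x):
    """Gradient of log softmax(logits)[x] w.r.t. the logits: onehot(x) - pi."""
    pi = softmax(logits)
    return [(1.0 if i == x else 0.0) - p for i, p in enumerate(pi)]

def expected_grad_log_prob(logits):
    """Exact E_{x ~ pi}[grad log pi(x)], summing over every outcome x."""
    pi = softmax(logits)
    total = [0.0] * len(logits)
    for x, p in enumerate(pi):
        for i, g in enumerate(grad_log_prob(logits, x)):
            total[i] += p * g
    return total

logits = [2.0, -1.0, 0.5, 0.0]  # hypothetical policy parameters
exact = expected_grad_log_prob(logits)  # zero up to floating-point error

# Maximum-likelihood updates on the model's own samples: the average gradient is
# (empirical frequency - pi), i.e. zero-mean noise that shrinks as samples grow.
random.seed(0)
pi = softmax(logits)
samples = random.choices(range(len(pi)), weights=pi, k=100_000)
mc_grad = [samples.count(i) / len(samples) - p for i, p in enumerate(pi)]
```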

It gets more nuanced when you account for ancestral sampling, but that adds problems rather than solving them:

https://arxiv.org/abs/2110.10819

On the other hand, in their follow-up work on instruction following, OpenAI claimed they used little online data (from fine-tuned policies):

https://arxiv.org/abs/2203.02155

Levine derives that in his control-as-inference tutorial paper (section 2.3). Your expected exponential total reward is pretty close. Note that it acts a bit like an (exponentiated) Q function for your policy: it gives you the exp-reward expected after taking action $\tau_t$ at state $\tau_{<t}$ and following $\pi$ thereafter. The exponential works like a soft argmax, so it gives you something like soft Q-learning, but not quite: the argmax is also over environment dynamics, not only over the policy. This causes an optimism bias: your agent effectively assumes an optimal next state will be sampled for it every time, however unlikely that would be. The rest of Levine's paper deals with that.
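A sketch of where the optimism bias comes from, assuming a single backup step over stochastic next states with hypothetical values and transition probabilities. The soft backup log E[exp(V(s'))] is never below the risk-neutral E[V(s')] (by Jensen's inequality). As values grow it approaches the best next state's value, even when that state has only a 1% chance of occurring.

```python
import math

def soft_backup(next_values, next_probs):
    """log E_{s' ~ p}[exp(V(s'))]: the 'soft' backup over environment transitions
    that appears when the dynamics are treated as part of the inference problem."""
    m = max(next_values)
    return m + math.log(sum(p * math.exp(v - m) for v, p in zip(next_values, next_probs)))

def expected_backup(next_values, next_probs):
    """E_{s' ~ p}[V(s')]: the risk-neutral backup under the true dynamics."""
    return sum(p * v for v, p in zip(next_values, next_probs))

# Hypothetical action: 1% chance of a great next state, 99% chance of a bad one.
probs = [0.01, 0.99]
for scale in [1.0, 10.0, 100.0]:
    values = [scale * 1.0, 0.0]
    print(scale, expected_backup(values, probs), soft_backup(values, probs))
```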