An on-policy, sample-efficient NLP post-training approach not requiring verification
I once tried teaching a simulated 3D humanoid how to dunk[1] using RL - truly the most effective use of my time. If I took a single thing away from that, it was that designing a good reward function amounts to assuming your agent is conscious and has the single goal of destroying your project.
My point is, RL is already terrible in practice. Then additionally throwing intermediate rewards out the window and overly relying on the most inefficient part of modern LLMs, their autoregressive inference[2], doesn't exactly seem like the play - somehow it is though.
The first attempts did try exactly this - basically an additional model that takes in one 'reasoning step' and spits out a number, the reward. The problem is that we simply don't have any pretraining data for such a reward model. Generating your own data is expensive[3] and nowhere near comparable to pretraining scale.
There's also a whole different problem - expecting such a model to be feasible in the first place[4]: even humans very much struggle to rate whether a particular step shows promise - in fact, a researcher's life would be much easier if at any point in time they could simply make such an objective and accurate assessment without actually having to work through the idea.[5]
Nevertheless, disregarding GRPO, the average pipeline was just copying the pretrained LLM, slicing off the prediction head, doing some basic finetuning and calling it a brand-new reward model. This works out "okay"-ish for around 100 training steps[6], but once a large enough distribution shift occurs in the actor, the shallow understanding of the reward model is revealed.
The alternative to all these RL approaches would be plain finetuning - yet this seems to only lead to very shallow changes like formatting & vocabulary and, at best, some added knowledge; not anything we would normally describe as proper learning. It seems that the on-policy property of approaches like GRPO reduces these superficial changes and therefore focuses the updates on more subtle differences.
Distillation seems to perform slightly better in that regard even though teacher-forcing is basically always used - it might be that being off-policy can be compensated for if the data at least covers distributions similar to those seen at inference time, i.e. traces that make mistakes and then fix them rather than presenting one perfect solution.
This leaves us in an awkward position:
We would like an algorithm that is sample-efficient, on-policy and doesn't require any additional models - and while we are at it, why not ask for native support of non-verifiable output formats[7] as well?
If reward models in NLP fail because we simply try to adapt the base model to an ill-suited task with little data, why not just stick to what these models are actually good at: predicting the next token. Distillation uses exactly this, trying to pass down prediction capabilities, often even forming the loss between the logits rather than just the sampled tokens and providing an incredibly rich signal as a result. But if we don't have a bigger model, where would the teacher get its nuance from?
Well, if the model weights are the same, the only difference could be the input - we would need to supply the teacher model with additional information that would reduce the complexity of the task. In its most extreme version, this would simply be a perfect solution to the problem.
To remain in the distributions seen at inference, we additionally need something like student-forcing. Lastly, we need a mechanism that stops "impossible knowledge" from being transmitted into the gradient - the teacher directly sees the perfect solution, but magically stumbling onto that solution before even properly analyzing the problem won't lead to better performance once this privileged knowledge is gone.
It's time to put this into more concrete terms:
You have prompts $p_{student}$ and $p_{teacher}$ - $p_{student}$ is a normal Chain-of-Thought prompt with the problem, while $p_{teacher}$ supplies both the problem and a solution, asking to attempt the problem normally and only use the solution as hint/help.
You do inference with $p_{student}$, generating $g_{CoT}$. This results in $[p_{student}, g_{CoT}]$ and $[p_{teacher}, g_{CoT}]$[8]. You now do distillation over $g_{CoT}$[9], with the teacher computing logits using $[p_{teacher}, g_{CoT}]$ and the student computing logits using $[p_{student}, g_{CoT}]$; call the logits $z_{teacher}$ and $z_{student}$ respectively.
Finally, to block "impossible knowledge", we choose an aggregation of both $z_{teacher}$ and $z_{student}$ as the actual target for the student. This for example could be:
$$p_{target} = \mathrm{softmax}\left(\frac{z_{student} + z_{teacher}}{\tau}\right)$$
where $\tau$ is a constant controlling the temperature - it makes sense to choose it such that roughly $p_{target} \approx \mathrm{softmax}(z_{student})$ whenever teacher and student already agree, i.e. $\tau \approx 2$.
This aggregation basically turns the teacher into a helping hand, only reaching out and intervening when the student is getting severely off track and never giving the student a solution it isn't already seriously considering itself.
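To make the training step concrete, here is a minimal sketch assuming a PyTorch / Hugging Face-style causal LM; the function name, the slicing-based masking and the default $\tau$ are illustrative choices of mine, not a reference implementation.

```python
import torch
import torch.nn.functional as F

def self_distillation_step(model, student_ids, teacher_ids, gen_len, tau=2.0):
    """student_ids = tokens of [p_student, g_CoT], teacher_ids = tokens of [p_teacher, g_CoT];
    gen_len = number of generated g_CoT tokens at the end of both sequences."""
    # Teacher pass: same weights, but conditioned on the prompt that contains the solution.
    with torch.no_grad():
        z_teacher = model(teacher_ids).logits[:, -gen_len - 1:-1, :]

    # Student pass: plain CoT prompt; this is the pass we backprop through.
    z_student = model(student_ids).logits[:, -gen_len - 1:-1, :]

    # Aggregated target p_target = softmax((z_student + z_teacher) / tau),
    # detached so the student chases the target instead of dragging it along.
    p_target = F.softmax((z_student.detach() + z_teacher) / tau, dim=-1)

    # Distillation loss over the g_CoT positions only (everything before them is
    # masked out implicitly by slicing to the last gen_len positions).
    loss = F.kl_div(F.log_softmax(z_student, dim=-1), p_target, reduction="batchmean")
    return loss
```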
[Note: Interactive visualizations were here in the original post but cannot be embedded in LessWrong]
There is one major problem with this approach - it requires a model that is already powerful, i.e. something upwards of 20B params. Anything below that can't be expected to properly follow the teacher-prompt to a reasonable degree and leverage the solution intelligently, as opposed to just blatantly copying it or completely forgetting about it 1k tokens in. This might not sound like a problem directly, but it becomes one once you understand that I have a total of $0 in funding right now.
If anybody with access to a research cluster would be interested in trying this approach on a sufficient scale, I would be more than happy to give it a go - I even have the code already written from some toy tests for this.
On another note, you can apply this aggregation during inference for $g_{CoT}$ already - this is useful for very hard problems[10], as it keeps $g_{CoT}$ close to a reasonable approach so that actual learning can happen afterwards. To be precise, during inference you would do two forward passes, compute $p_{target}$ right away and sample the next token from it - essentially a mix between teacher-forcing and student-forcing.
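For completeness, a naive sketch of that inference-time mixing - same assumptions as the training sketch above, with no KV caching and no stopping criterion, purely illustrative:

```python
import torch
import torch.nn.functional as F

@torch.no_grad()
def aggregated_generate(model, student_ids, teacher_ids, max_new_tokens, tau=2.0):
    for _ in range(max_new_tokens):
        # One forward pass per prompt; both share the same generated suffix.
        z_student = model(student_ids).logits[:, -1, :]
        z_teacher = model(teacher_ids).logits[:, -1, :]
        # Sample directly from the aggregated distribution p_target.
        p_target = F.softmax((z_student + z_teacher) / tau, dim=-1)
        next_token = torch.multinomial(p_target, num_samples=1)
        # Append the sampled token to *both* contexts - a mix of teacher- and student-forcing.
        student_ids = torch.cat([student_ids, next_token], dim=-1)
        teacher_ids = torch.cat([teacher_ids, next_token], dim=-1)
    return student_ids
```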
Another question is the data - one of the advantages of GRPO is that it requires no stepwise solutions anymore, only a final verifiable answer. We could of course just generate a bunch of solutions and, using the verifiable answer, create our own stepwise solutions[11] (see the sketch below) - this would still have a significantly higher sample efficiency than GRPO, since the signal we get from one trace is token-wise logit-based targets, unimaginably denser than a single coefficient indicating whether to increase or decrease the probability of the whole trace.
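A minimal sketch of that bootstrapping step; `generate_trace` and `extract_answer` are placeholder hooks I am assuming, not part of the original setup:

```python
def collect_solutions(generate_trace, extract_answer, problem, reference_answer, n_samples=8):
    """generate_trace(problem) -> sampled CoT string; extract_answer(trace) -> final answer.
    Keep only traces whose final answer verifies, to use as the solution inside p_teacher."""
    solutions = []
    for _ in range(n_samples):
        trace = generate_trace(problem)
        if extract_answer(trace) == reference_answer:
            solutions.append(trace)
    return solutions
```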
But I think this approach especially shines in settings with no verifiable answers - which is practically everything if we zoom out. One could imagine a company like OpenAI having tons and tons of chats where users iterate on one initial demand with the chatbot; RL approaches or finetuning can't make any use of this at all. This approach, on the other hand, can simply take the final output the user seemed content with as the solution and start this self-distillation from the beginning of the conversation, while the logit rebalancing takes care of information not yet revealed. And the best thing - all autoregressive inference has already been done; this training will be stupidly fast.
yes, the basketball kind ↩︎
parallelizing this inference, as GRPO does, alleviates the damage but doesn't erase it ↩︎
even when attempting novel schemas like incorporating binary search https://arxiv.org/pdf/2406.06592 - very cool paper ↩︎
given the unspoken constraint of compute for the reward model ≈ compute for the LLM ↩︎
This does seem to manifest in experts to some degree through intuition, which can be very powerful, but it's just as common that two experts' intuitions completely oppose each other ↩︎
if the finetuning data is good enough, which always goes hand in hand with an absurd amount of compute spent on it ↩︎
non-verifiable in this context doesn't speak to some impossibility of determining correctness but simply to the feasibility of it - checking whether a multiple choice answer is correct is trivial, but what about a deep research report? ↩︎
[X,Y] simply means Y appended to X here, basically just think of pasting the generated tokens under both prompts, respectively ↩︎
by this I mean masking out everything other than $g_{CoT}$ for the gradient ↩︎
where something like GRPO would utterly fail ↩︎
which seem to perform a lot better than human-written ones ↩︎