This idea appears very similar to the paper "Reinforcement Learning Teachers of Test Time Scaling": https://arxiv.org/abs/2506.08388
Just read through it - I agree there are clear similarities in the data generation and preparation process, but how it's then actually used to train the model seems quite different imo. They employ RL, even GRPO specifically, which I want to avoid at all costs. Because they don't use any kind of logit aggregation, they face this issue of "impossible knowledge", or in other words, a terrible teacher. They reconcile this with their KL and log-prob reward function, but since at that point they are already dealing with sampled traces, they lose out on the rich logit representation.
Also, they choose teacher-forcing instead of student-forcing; teacher-forcing is definitely more popular in the literature, but I think what works like R1 showed is that staying on-policy is really key. Here even more so: the teacher has what we could call a "teaching bias", since it has already seen the solution, and if we allow this bias to carry over throughout generation, all traces will be terrible and the reward function won't particularly fix this.
still definitely nice work, thanks for the link, decent read.
An on-policy, sample-efficient NLP post-training approach not requiring verification
The current state
I once tried teaching a simulated 3D humanoid how to dunk[1] using RL - truly the most effective use of my time. If I took a single thing away from that, it was that designing the best reward function is equivalent to believing your agent is conscious and has the single goal of destroying your project.
My point is, RL is already terrible in practice. Then additionally throwing intermediate rewards out the window and overly relying on the most inefficient part of modern LLMs, their autoregressive inference[2], doesn't exactly seem like the play - somehow it is though.
The first attempts did try exactly this - basically an additional model that takes in one 'reasoning step' and spits out a number - the reward. The problem is that we simply don't have any pretraining data for such a reward model. Generating your own data is expensive[3] and not comparable to any pretraining scale.
There's also a whole different problem - expecting such a model to be feasible in the first place[4]: even humans very much struggle with rating whether a particular step shows promise - in fact, it would make the life of a researcher much easier if at any point in time they could simply make such an objective and accurate assessment without actually having to work through the idea.[5]
Nevertheless, disregarding GRPO, the average pipeline was just copying the pretrained LLM, slicing off the prediction head, doing some basic finetuning and calling it a brand-new reward model. This works out "okay"-ish for around 100 training steps[6], but once a significant enough distribution shift occurs in the actor, the shallow understanding of the reward model is revealed.
In contrast to all these RL approaches stands plain finetuning - yet this seems to only lead to very shallow changes like formatting & vocabulary and, at best, added knowledge; not anything we would normally describe as proper learning. It seems that the on-policy attribute of approaches like GRPO reduces these superficial changes and therefore focuses on more subtle differences.
Distillation seems to perform slightly better in that regard, even though teacher-forcing is basically always used - it might be the case that being off-policy can be compensated for if the data at least incorporates distributions similar to those seen at inference time, i.e. making mistakes and then fixing them rather than presenting a perfect solution.
This leaves us in an awkward position:
We would like an algorithm that is sample-efficient, on-policy and doesn't require any additional models - and while we are at it, why not ask for native support of non-verifiable output formats[7] as well?
Approach
If reward models in NLP fail because we simply try to adapt the base model to an ill-suited task with little data, why not just stick to what these models are actually good at: predicting the next token. Distillation uses this, trying to pass down prediction capabilities, often even forming the loss between the logits rather than just the sampled tokens, which provides an incredibly rich signal as a result. But if we don't have such a bigger model, where would the teacher model get its nuance from?
Well, if the model weights are the same, the only difference could be the input - we would need to supply the teacher model with additional information that would reduce the complexity of the task. In its most extreme version, this would simply be a perfect solution to the problem.
To remain in the distributions seen at inference, we additionally need something like student-forcing. Lastly, we need a mechanism that stops "impossible knowledge" from being transmitted into the gradient - the teacher model directly knows the perfect solution, but magically stumbling on this solution before even properly analyzing the problem won't lead to better performance once this knowledge is gone.
It's time to put this into more concrete terms:
You have two prompts, p_CoT and p_Teacher - p_CoT is a normal Chain-of-Thought prompt containing the problem, while p_Teacher supplies both the problem and a solution, asking the model to attempt the problem normally and only use the solution as a hint.
You do inference with p_CoT, generating gen_CoT. This results in c_CoT = [p_CoT, gen_CoT] and c_Teacher = [p_Teacher, gen_CoT][8]. You now do distillation over gen_CoT[9], with the teacher computing logits using c_Teacher and the student computing logits using c_CoT; call the logits l_Teacher and l_CoT respectively.
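To make the bookkeeping concrete, here is a minimal sketch of this data flow, assuming a Hugging Face-style causal LM; the model name and the prompt strings are placeholders, and batching, attention masks and chat templating are omitted:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("some-20B-model")  # placeholder
tok = AutoTokenizer.from_pretrained("some-20B-model")

def logits_over_generation(prompt_ids, gen_ids):
    """One forward pass over [prompt, generation]; return only the logits
    that predict each generated token (everything else is masked out)."""
    input_ids = torch.cat([prompt_ids, gen_ids]).unsqueeze(0)
    logits = model(input_ids).logits[0]
    start = prompt_ids.shape[0] - 1  # logits[t] predicts token t+1
    return logits[start : start + gen_ids.shape[0]]

p_cot_ids = tok(p_cot_text, return_tensors="pt").input_ids[0]          # problem only
p_teacher_ids = tok(p_teacher_text, return_tensors="pt").input_ids[0]  # problem + solution hint

# gen_CoT is sampled from the student itself under p_CoT (student-forcing)
gen_ids = model.generate(p_cot_ids.unsqueeze(0), do_sample=True,
                         max_new_tokens=1024)[0, p_cot_ids.shape[0]:]

with torch.no_grad():  # teacher pass: same weights, no gradient
    l_teacher = logits_over_generation(p_teacher_ids, gen_ids)
l_cot = logits_over_generation(p_cot_ids, gen_ids)  # student pass: keeps the gradient
```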
Finally, to block "impossible knowledge", we choose an aggregation of both l_Teacher and l_CoT as the actual target for the student. This could for example be:
$$l_{target} := \frac{0.75 \cdot l_{Teacher} + l_{CoT}}{k}$$

where $k$ is a constant controlling the temperature - it makes sense to choose it such that roughly $H(\mathrm{Softmax}(l_{target})) \approx H(\mathrm{Softmax}(l_{CoT}))$.
This aggregation basically turns the teacher into a helping hand, only reaching out and intervening when the student is getting severely off track and never giving the student a solution it isn't already seriously considering itself.
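As a minimal sketch of the aggregation and the resulting distillation loss: the bisection search below is just one simple way to pick k so that the entropy condition holds, not a prescription.

```python
import torch
import torch.nn.functional as F

def mean_entropy(logits):
    """Mean per-token entropy of Softmax(logits); logits has shape (T, V)."""
    logp = F.log_softmax(logits, dim=-1)
    return -(logp.exp() * logp).sum(dim=-1).mean()

def aggregate(l_teacher, l_cot, w=0.75, iters=30):
    """l_target = (w * l_teacher + l_cot) / k, with k chosen by bisection so
    that H(Softmax(l_target)) roughly matches H(Softmax(l_cot))."""
    mixed = w * l_teacher + l_cot
    target_h = mean_entropy(l_cot)
    lo, hi = 0.05, 20.0
    for _ in range(iters):
        k = (lo + hi) / 2
        if mean_entropy(mixed / k) < target_h:
            lo = k  # still too sharp -> raise the temperature
        else:
            hi = k
    return mixed / k

def distill_loss(l_cot, l_teacher):
    """KL distillation over gen_CoT only; the target is detached, so the
    gradient flows solely through the student's logits l_cot."""
    l_target = aggregate(l_teacher.detach(), l_cot.detach())
    return F.kl_div(F.log_softmax(l_cot, dim=-1),
                    F.log_softmax(l_target, dim=-1),
                    log_target=True, reduction="batchmean")
```

Whether to match the entropy per token or on average over the trace is a free design choice; the sketch matches the average with a single global k.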
[Note: Interactive visualizations were here in the original post but cannot be embedded in LessWrong]
Notes
There is one major problem with this approach - it requires a model that is already powerful, i.e. something upwards of 20B params. Anything below that can't be expected to follow the teacher prompt to a reasonable degree and leverage the solution intelligently, as opposed to just blatantly copying it or completely forgetting about it 1k tokens in. This might not sound like a problem directly, but it does once you understand that I currently have a total of $0 in funding.
If anybody with access to a research cluster would be interested in trying this approach on a sufficient scale, I would be more than happy to give it a go - I even have the code already written from some toy tests for this.
On another note, you can apply this aggregation during inference for gen_CoT already - this is useful for very hard problems[10], as it keeps gen_CoT close to a reasonable approach so that actual learning can happen afterwards. To be precise, during inference you would do two forward passes per step, compute l_target directly and sample the next token from it - essentially a mix between teacher-forcing and student-forcing.
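A rough sketch of this inference-time variant - two forward passes per step, mix the last-position logits, sample from the mix and append the token to both contexts. For brevity it uses a fixed temperature k, no KV cache and no EOS handling; the entropy-matched aggregation from above would slot in where the mix is formed:

```python
import torch

@torch.no_grad()
def generate_with_teacher_hint(model, p_cot_ids, p_teacher_ids,
                               max_new_tokens=512, w=0.75, k=1.5):
    cot_ctx, teacher_ctx = p_cot_ids.clone(), p_teacher_ids.clone()
    generated = []
    for _ in range(max_new_tokens):
        l_cot = model(cot_ctx.unsqueeze(0)).logits[0, -1]
        l_teacher = model(teacher_ctx.unsqueeze(0)).logits[0, -1]
        l_target = (w * l_teacher + l_cot) / k            # per-step aggregation
        next_tok = torch.multinomial(torch.softmax(l_target, dim=-1), 1)
        generated.append(next_tok)
        cot_ctx = torch.cat([cot_ctx, next_tok])          # the same token goes
        teacher_ctx = torch.cat([teacher_ctx, next_tok])  # into both contexts
    return torch.cat(generated)  # this becomes gen_CoT for the training step
```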
Another question is the data - one of the advantages of GRPO is that it requires no stepwise solutions anymore, only a final verifiable answer. We could of course just generate a bunch of solutions and, using the verifiable answer, filter them into our own stepwise solutions[11] - this would still have a significantly higher sample efficiency than GRPO, since the signal we get from one trace consists of token-wise, logit-based targets, unimaginably denser than a single coefficient indicating whether to increase or decrease the probability of the whole trace.
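A hypothetical sketch of that data-generation step - sample_solution and extract_answer are placeholders for ordinary CoT sampling and answer extraction, and the prompt format is just an example:

```python
def make_teacher_prompt(problem, known_answer, n_samples=8):
    """Sample attempts, keep one whose final answer matches the verifiable
    answer, and slot it into p_Teacher as the solution hint."""
    for _ in range(n_samples):
        attempt = sample_solution(problem)           # placeholder: CoT sampling
        if extract_answer(attempt) == known_answer:  # the verifiable check
            return (f"Problem: {problem}\n"
                    f"Reference solution (use only as a hint): {attempt}\n"
                    f"Now attempt the problem yourself, step by step.")
    return None  # no verified stepwise solution found; skip this problem
```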
But I think this approach especially shines in settings with no verifiable answers - which is practically everything if we zoom out. One could imagine a company like OpenAI having tons and tons of chats where users iterate on one initial demand with the chatbot; RL approaches or plain finetuning can't make any use of this at all. This approach, on the other hand, can simply take the final output the user seemed content with as the solution and start this self-distillation from the beginning of the conversation, while the logit rebalancing takes care of information not yet revealed. And the best thing - all the autoregressive inference has already been done; this training will be stupidly fast.
Footnotes
yes, the basketball kind ↩︎
parallelizing this inference, as GRPO does, alleviates the damage but doesn't erase it ↩︎
even when attempting novel schemas like incorporating binary search https://arxiv.org/pdf/2406.06592 - very cool paper ↩︎
given the unspoken constraint that the compute for the reward model is approximately the compute for the LLM ↩︎
This does seem to manifest in experts to some degree through intuition, which can be very powerful, but it's just as common that two experts' intuitions completely oppose each other ↩︎
if the finetuning data is good enough, which always goes hand in hand with an absurd amount of compute spent on it ↩︎
non-verifiable in this context doesn't speak to some impossibility of determining correctness but simply to the feasibility of it - checking whether a multiple-choice answer is correct is trivial, but what about a deep research report? ↩︎
[X,Y] simply means Y appended to X here, basically just think of pasting the generated tokens under both prompts, respectively ↩︎
by this I mean masking out everything other than gen_CoT for the gradient ↩︎
where something like GRPO would utterly fail ↩︎
which seem to perform a lot better than human-written ones ↩︎