Author: Justin Y. Chen
This blog post covers material for the second lecture of Harvard CS 2881r: AI Safety taught by Boaz Barak. This is an outline of the lecture; see the video and slides for more details.
Hi! I’m Justin, a PhD student at MIT in the Theory of Computation Group. I work on research problems at the intersection of algorithm design and machine learning. Recently, I have been crafting algorithms which interact with powerful but untrustworthy AI models (learning-augmented algorithms and LLM evaluation) or are stable to small changes in their inputs (differential privacy and replicability).
The methods which have succeeded in training modern LLMs have been those which can scale with more data and more compute. The successful approaches are not necessarily designed to make optimal use of available resources, but they can easily be extended to more and more resources without plateauing in performance. In brief, the predominant techniques are those which are simple, not stupid, and scale. This blog post explores modern machine learning architectures and training for text generation, focusing on why high-level design choices have been made rather than on technical details.
The key task which LLMs are based on is next-token prediction (NTP): given a sequence of tokens from a corpus of text, predict the token that comes next. This paradigm for text generation is amenable to scaling. The data available for next-token prediction is ubiquitous: any existing piece of text can be turned into training data simply by hiding each token and asking the model to predict it from the tokens that precede it. Furthermore, next-token prediction across such a wide variety of text is challenging, so models have room to keep improving with more data and model capacity.
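To make the data-generation step concrete, here is a minimal Python sketch (assuming NumPy) of how a single piece of text yields many next-token training examples, together with the cross-entropy loss that next-token prediction minimizes. The word-level "tokenizer", the toy sentence, and the function names are hypothetical simplifications; real LLMs use learned subword tokenizers.

```python
import numpy as np

def next_token_pairs(token_ids):
    """Turn one token sequence into (prefix, next token) training examples."""
    return [(token_ids[:i], token_ids[i]) for i in range(1, len(token_ids))]

def next_token_loss(logits, targets):
    """Mean cross-entropy, i.e. the mean of -log p(correct next token).

    logits:  (num_positions, vocab_size) model scores at each position
    targets: (num_positions,) index of the true next token at each position
    """
    shifted = logits - logits.max(axis=1, keepdims=True)  # numerical stability
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    return -log_probs[np.arange(len(targets)), targets].mean()

# Toy word-level "tokenizer" (hypothetical; real models use subword tokens).
text = "the cat sat on the mat"
vocab = {word: i for i, word in enumerate(dict.fromkeys(text.split()))}
token_ids = [vocab[word] for word in text.split()]   # [0, 1, 2, 3, 0, 4]

pairs = next_token_pairs(token_ids)   # 5 examples from one 6-token sentence
targets = np.array([next_id for _, next_id in pairs])

# A model with all-zero logits is uniform over the vocabulary.
uniform_loss = next_token_loss(np.zeros((len(targets), len(vocab))), targets)
```

With all-zero logits the model is uniform over the 5-word vocabulary, so the loss equals ln 5 ≈ 1.609. In practice a transformer scores every prefix of a sequence in a single forward pass (using a causal mask), which is one reason this objective parallelizes well across tokens.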
The key to modern machine learning computation is parallelism across multiple axes and scales. Modern LLMs are powered by GPUs, massively parallel processors that can perform staggering amounts of computation per second. Computation on GPUs is generally bottlenecked by memory requirements and data movement. Given the scale of LLMs, clusters of up to hundreds of thousands of GPUs are used in tandem during training. Parallelism is leveraged across tokens, training batches, weight layers, and even shards of large matrices.
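One of these axes, data parallelism, rests on a simple fact: the gradient of a loss summed over a batch equals the sum of the gradients computed on any split of that batch. The NumPy sketch below checks this, with a toy linear model and squared-error loss standing in for a transformer and list entries standing in for GPUs; the final sum plays the role of an all-reduce across devices. All names and sizes are illustrative assumptions.

```python
import numpy as np

def squared_error_grad(w, X, y):
    """Gradient of the summed loss 0.5 * ||X w - y||^2 with respect to w."""
    return X.T @ (X @ w - y)

rng = np.random.default_rng(0)
X = rng.normal(size=(64, 8))   # a batch of 64 examples with 8 features each
y = rng.normal(size=64)
w = rng.normal(size=8)

# Reference: the gradient computed on the whole batch by a single device.
full_grad = squared_error_grad(w, X, y)

# Data parallelism: split the batch across 4 simulated "GPUs"; each computes a
# gradient on its own shard, then an all-reduce sums the partial gradients.
shards = zip(np.array_split(X, 4), np.array_split(y, 4))
partial_grads = [squared_error_grad(w, X_k, y_k) for X_k, y_k in shards]
all_reduced = np.sum(partial_grads, axis=0)
```

Because each device only needs its own shard of the data plus a copy of the weights, adding more devices lets the batch grow without any single device doing more work; the price is the communication needed for the all-reduce.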
Transformers, the standard base architecture of LLMs, support next-token prediction while being exceptionally well suited to parallelism. Many of the operations in a transformer are linear (matrix multiplications), so they can be split into pieces that are computed in parallel and then summed in any order. Furthermore, transformers resemble distributed models of computation in which compute-heavy operations are performed locally on each token, with sequential rounds of limited communication between tokens. (See the video/slides for more details.)
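The "local compute, limited communication" structure is visible in a stripped-down transformer block. This NumPy sketch is a simplification under stated assumptions (a single attention head, no layer normalization, random hypothetical weights): causal attention is the only step that mixes information across positions, while the MLP acts on each token independently.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def causal_attention(H, Wq, Wk, Wv):
    """Communication round: each position reads from itself and earlier positions."""
    Q, K, V = H @ Wq, H @ Wk, H @ Wv
    scores = Q @ K.T / np.sqrt(K.shape[1])
    T = H.shape[0]
    future = np.triu(np.ones((T, T), dtype=bool), k=1)  # True above the diagonal
    scores = np.where(future, -np.inf, scores)          # block attention to the future
    return softmax(scores, axis=-1) @ V

def mlp(H, W1, W2):
    """Local computation: applied to every position independently (no mixing)."""
    return np.maximum(H @ W1, 0.0) @ W2

def block(H, params):
    H = H + causal_attention(H, *params["attn"])   # limited communication between tokens
    return H + mlp(H, *params["mlp"])              # heavy per-token computation

rng = np.random.default_rng(0)
d, T = 16, 6   # hypothetical model width and sequence length
params = {
    "attn": [0.1 * rng.normal(size=(d, d)) for _ in range(3)],
    "mlp": [0.1 * rng.normal(size=(d, 4 * d)), 0.1 * rng.normal(size=(4 * d, d))],
}
H = rng.normal(size=(T, d))   # one vector per token

# Locality check: the MLP on the whole sequence equals the MLP applied one token at a time.
per_token = np.stack([mlp(H[i:i + 1], *params["mlp"])[0] for i in range(T)])
print(np.allclose(per_token, mlp(H, *params["mlp"])))   # True

# Causality check: perturbing the last token does not change earlier outputs.
H_perturbed = H.copy()
H_perturbed[-1] += 1.0
print(np.allclose(block(H, params)[:-1], block(H_perturbed, params)[:-1]))   # True
```

The per-token MLP can be spread across devices with no coordination, and the attention step is where tokens exchange information; stacking many such blocks gives the sequential rounds of communication.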
LLM training can be broken into three stages: pretraining, supervised fine-tuning (SFT), and reinforcement learning (RL).
All three stages use the same underlying mechanism: some form of gradient descent, with gradients computed via backpropagation, updates the weights of the transformer to increase the likelihood of producing certain tokens and decrease the likelihood of others. The stages differ mainly in how these tokens are chosen.
In pretraining, massive datasets scraped from the internet are used to train the model on plain next-token prediction. The goal is to build a powerful interpolator which can accurately fill in missing words in a diverse (and messy) corpus of text.
SFT is the stage of the model’s training where the model builders have a chance to shape behavior by giving the model curated prompt-response pairs which represent ideal behavior. The prompt is fixed and the model is trained to maximize the likelihood of giving the designated response. SFT gives an opportunity to create an instruction-following chatbot by giving the model examples of how a chatbot interacts with a user.
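One common way to implement "maximize the likelihood of the designated response" is to compute the usual next-token loss over the concatenated prompt and response, but count only the positions that predict response tokens. The NumPy sketch below assumes that convention; the function name, token IDs, prompt length, and random logits are all hypothetical.

```python
import numpy as np

def log_softmax(logits):
    shifted = logits - logits.max(axis=-1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))

def sft_loss(logits, token_ids, prompt_len):
    """Negative log-likelihood of the response tokens only, given the prompt.

    logits[i] holds the model's scores for token_ids[i + 1] (next-token
    prediction), so logits has one fewer row than token_ids.
    """
    log_probs = log_softmax(logits)
    targets = np.asarray(token_ids[1:])
    token_nll = -log_probs[np.arange(len(targets)), targets]
    # Position i predicts token i + 1; keep only predictions of response tokens.
    is_response = np.arange(1, len(token_ids)) >= prompt_len
    return token_nll[is_response].mean()

rng = np.random.default_rng(0)
token_ids = [5, 9, 2, 7, 1]   # hypothetical: 3 prompt tokens, then 2 response tokens
prompt_len = 3
vocab_size = 10
logits = rng.normal(size=(len(token_ids) - 1, vocab_size))
loss = sft_loss(logits, token_ids, prompt_len)

# Changing the model's predictions *inside the prompt* leaves the SFT loss unchanged.
logits_prompt_changed = logits.copy()
logits_prompt_changed[: prompt_len - 1] += 10.0 * rng.normal(size=(prompt_len - 1, vocab_size))
print(np.isclose(loss, sft_loss(logits_prompt_changed, token_ids, prompt_len)))  # True
```

Masking the prompt means the model is never trained to imitate the user's side of the conversation, only to produce the curated response given that prompt.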
RL is similar to SFT in both structure and goals, with some key differences. A prompt is supplied to the model, but unlike in SFT, the response is generated by the model itself: RL trains on-policy, while SFT trains off-policy. While SFT optimizes the model by example, in RL the model optimizes on its own completions. After a response is generated, the model receives scalar feedback (a real-valued score or a binary reward) on that response and is updated with an optimization method such as policy gradient. (In the lecture, we worked out how the gradient of the expected reward reduces to the gradient of the log-probabilities of the response multiplied by its empirical reward.)
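The identity from the lecture can be written ∇_θ E_{y∼π_θ}[R(y)] = E_{y∼π_θ}[R(y) ∇_θ log π_θ(y)], often called the REINFORCE or score-function estimator. The sketch below applies it to a toy "policy" that picks one of four responses through a softmax over logits, a hypothetical stand-in for an LLM generating a full response. Sampling on-policy and weighting each sample's log-probability gradient by its reward gives an unbiased estimate of the gradient of the expected reward.

```python
import numpy as np

def softmax(logits):
    z = np.exp(logits - logits.max())
    return z / z.sum()

def grad_log_prob(logits, action):
    """Gradient of log pi(action) with respect to the logits: onehot(action) - pi."""
    g = -softmax(logits)
    g[action] += 1.0
    return g

def exact_policy_gradient(logits, rewards):
    """E_{a ~ pi}[R(a) * grad log pi(a)], summed exactly over all actions."""
    pi = softmax(logits)
    return sum(pi[a] * rewards[a] * grad_log_prob(logits, a) for a in range(len(logits)))

def reinforce_step(logits, rewards, rng, lr=0.5, num_samples=16):
    """One on-policy update: sample actions, weight grad log-probs by reward, ascend."""
    pi = softmax(logits)
    actions = rng.choice(len(logits), size=num_samples, p=pi)
    grad_estimate = np.mean([rewards[a] * grad_log_prob(logits, a) for a in actions], axis=0)
    return logits + lr * grad_estimate   # gradient *ascent* on expected reward

# Hypothetical task: 4 possible responses, and only response 2 earns reward 1.
rewards = np.array([0.0, 0.0, 1.0, 0.0])
logits = np.zeros(4)
rng = np.random.default_rng(0)
for _ in range(200):
    logits = reinforce_step(logits, rewards, rng)
print(softmax(logits).round(3))   # probability mass concentrates on response 2
```

In this sketch, responses with zero reward contribute nothing to the update. Practical implementations usually subtract a baseline (such as the mean reward of a batch of samples) from each reward to reduce the variance of the estimate.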
In RL from Human Feedback (RLHF), the feedback given to the model is a score assigned by a human. As the method developed, RLHF shifted toward having humans compare multiple generations and rank them, rather than score individual responses in isolation. Because human annotations are expensive, some implementations train a reward model to emulate human judgments and then use it to provide feedback on far more responses than humans could label. RLHF can be used to guide the model toward responses that people prefer, or that better conform to certain desirable behaviors. Because feedback is given only through reward, the model builder directs the model through incentives rather than by example.
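Reward models trained on comparisons are commonly fit with a Bradley-Terry style pairwise loss, −log σ(r(chosen) − r(rejected)), which pushes the reward of the preferred response above that of the other. The sketch below assumes a linear reward model over hypothetical feature vectors and simulates annotators with a hidden "true" preference vector; in practice the reward model is usually itself a transformer that scores full prompt-response pairs.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def pairwise_loss_and_grad(w, chosen, rejected):
    """Bradley-Terry loss -log sigmoid(r(chosen) - r(rejected)) with r(x) = w . x."""
    margin = chosen @ w - rejected @ w   # shape: (num_pairs,)
    loss = -np.mean(np.log(sigmoid(margin)))
    # d/dw of -log sigmoid(m) is -(1 - sigmoid(m)) * dm/dw, where dm/dw = chosen - rejected.
    grad = -((1.0 - sigmoid(margin))[:, None] * (chosen - rejected)).mean(axis=0)
    return loss, grad

# Hypothetical synthetic comparisons: each response is a feature vector, and the
# simulated annotator prefers whichever response scores higher under true_w.
rng = np.random.default_rng(0)
true_w = np.array([1.0, -2.0, 0.5])
a = rng.normal(size=(500, 3))
b = rng.normal(size=(500, 3))
a_preferred = (a @ true_w) > (b @ true_w)
chosen = np.where(a_preferred[:, None], a, b)
rejected = np.where(a_preferred[:, None], b, a)

# Fit the reward model by gradient descent on the pairwise loss.
w = np.zeros(3)
initial_loss, _ = pairwise_loss_and_grad(w, chosen, rejected)
for _ in range(500):
    loss, grad = pairwise_loss_and_grad(w, chosen, rejected)
    w -= 0.5 * grad

# Fraction of comparisons where the learned reward agrees with the annotator.
accuracy = np.mean((chosen @ w) > (rejected @ w))
```

Once fit, the learned reward function can score new responses without a human in the loop, which is what lets RLHF scale beyond the number of comparisons people actually labeled.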
RL is also used with verifier feedback (RLVF). In certain domains, such as coding or math problems, hard feedback indicating whether the final answer in a response is correct is available and can often be generated automatically (for example, by checking against a known answer or running test cases). For these limited (but important) domains, RLVF allows training at a much greater scale than RLHF, since human annotators are not required.
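For a math problem with a known answer, a verifier can be as simple as extracting the response's final answer and comparing it to the reference. The sketch below assumes a hypothetical output convention (the response ends with a line such as "Answer: 42") and returns a binary reward; real pipelines typically use more robust answer extraction, or test suites for code.

```python
import re

def extract_final_answer(response: str):
    """Return the last number following "Answer:" (hypothetical format), or None."""
    matches = re.findall(r"Answer:\s*(-?\d+(?:\.\d+)?)", response)
    return matches[-1] if matches else None

def verifier_reward(response: str, reference: str) -> float:
    """Binary reward: 1.0 if the extracted final answer matches the reference."""
    answer = extract_final_answer(response)
    if answer is None:
        return 0.0  # unparseable responses earn no reward
    return 1.0 if float(answer) == float(reference) else 0.0

print(verifier_reward("2 + 2 = 4, so the result is 4.\nAnswer: 4", "4"))   # 1.0
print(verifier_reward("Answer: 5", "4"))                                   # 0.0
print(verifier_reward("I believe the answer is four.", "4"))               # 0.0
```

Giving zero reward to responses that cannot be parsed also gives the model an incentive to follow the expected output format.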
One limitation of next-token prediction with transformers is that the amount of computation per generated token is essentially constant, set by the size of the model rather than by how difficult the token is. If a task requires more computation, more tokens must be produced. Some tokens are simply much harder to predict than others. In post-training (SFT and RL), models may be encouraged to produce intermediate tokens (often referred to as a chain of thought, or CoT), which give the model an opportunity to spend more computation on hard problems.
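A back-of-the-envelope calculation shows why extra tokens buy extra computation. It relies on the common rule of thumb that a forward pass costs roughly 2 FLOPs per parameter per token, ignoring the attention term that grows with context length; the 7-billion-parameter model size and the token counts are hypothetical.

```python
def forward_flops_per_token(num_params: float) -> float:
    # Rule of thumb: ~2 FLOPs (one multiply, one add) per parameter per generated token.
    # This ignores attention's dependence on context length.
    return 2.0 * num_params

def generation_flops(num_params: float, num_generated_tokens: int) -> float:
    # Every generated token costs about the same, so total compute scales with token count.
    return num_generated_tokens * forward_flops_per_token(num_params)

params = 7e9                                  # hypothetical 7B-parameter model
direct = generation_flops(params, 1)          # answer immediately: 1.4e10 FLOPs
with_cot = generation_flops(params, 1000)     # 1,000-token chain of thought first: 1.4e13 FLOPs
print(f"{direct:.1e} vs {with_cot:.1e} FLOPs ({with_cot / direct:.0f}x)")
```

Under this rough model, a fixed network can only spend a thousand times more compute on an answer by generating roughly a thousand times more tokens before committing to it.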
This method underlies “reasoning” models like DeepSeek R1. For R1-Zero, RL with CoT and verifier feedback on math problems dramatically increased performance on math benchmarks. Because these chains were shaped purely on-policy by RL, they mixed languages and were hard to read. The full R1 model combined multiple stages of SFT and RL in order to achieve strong performance with more readable chains.
In the student experiment, Anastasia Ahani, Henry Huang, and Atticus Wang used RL to optimize prompts for producing more coherent or aligned text. See their blog post for details.
SFT and RLHF are the main post-training tools for enforcing safety in LLMs. Originally, this was done by encouraging the model to refuse to answer harmful questions. Through examples in SFT and feedback in RLHF, models were trained to simply refuse whenever their answers could potentially be harmful.
Newer generations of models attempt to generate safe completions rather than outright refusals. In “deliberative alignment”, the model is given a safety spec in its context and trained via SFT on examples of how to apply the spec in various scenarios. In the RL stage, the model is then judged on how well it applies the spec. The result of this training is a model which is primed to deliberate over the spec before answering, reasoning through the spec’s guidelines rather than simply refusing prompts.
Chain-of-thought also offers a fragile opportunity for safety monitoring. In “reasoning” models which produce intermediate tokens, those tokens may reveal when a model is considering behavior outside of the spec. An important subtlety with RL is that if chain-of-thought monitoring is used to assign negative reward to training examples, the model may learn to lie in its chain of thought while still behaving poorly, rather than learning to behave well. The recent consensus across safety teams at large labs has been to avoid applying optimization pressure to chains of thought, in an attempt to keep them faithful descriptions of the model’s own behavior. This allows chain-of-thought monitoring to be used for interventions in production while keeping it out of training.