Note: These results have been sitting in drafts for a year; see the discussion at the end for how our thinking on these questions has since moved on.
Our team at AI Safety Camp has been working on a project to model the trajectories of language model outputs. We're interested in predicting not just the next token, but the broader path an LLM's generation might take. This post summarizes our key experiments and findings so far.
How accessible are longer-scale concepts in the latent space? How can we compress it?
TL;DR: We tried some simple probing experiments to identify "text type." Results seem promising, but this approach is likely not the best way forward for our goals.
Experiment 1: Probing Mean Latent Representations for "Genre"
We trained probes to classify the "genre" of a text chunk based on the mean activation of that chunk.
Preliminaries:
We examined the Mistral 7B Instruct model and a randomly initialized version of Mistral 7B Instruct.
We used two datasets:
Synthetic: Generated outputs from Mistral prompted with phrases such as "Write a fable for children" or "Explain how to implement quicksort." We passed these outputs to GPT-4 Turbo Preview for automatic segmentation and labelling into classes (narrative, list, speech, code, explanation), and then manually reviewed the labels.
CORE: an existing corpus of texts annotated with register ("genre") labels.
Split texts into chunks (initially using GPT-4, later with an algorithmic approach)
Computed the mean activation for each chunk
Trained various classifier probes on these mean representations to infer each chunk's genre (a minimal sketch of this setup appears after this list)
Compared the performance of probes trained on the trained model vs probes trained on the randomly initialized model
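As a concrete illustration, here is a minimal sketch of the probing setup with random stand-in data; in the real experiment the feature vectors are mean chunk activations from a chosen layer of Mistral 7B Instruct (or its randomly initialized counterpart), and we tried several classifier families beyond the logistic-regression probe shown here.

```python
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import f1_score
from sklearn.model_selection import train_test_split

rng = np.random.default_rng(0)
hidden_dim = 4096  # Mistral 7B hidden size

# Stand-in data: in practice, X[i] is the mean activation over the tokens of chunk i
# at a chosen layer, and y[i] is that chunk's genre label.
X = rng.normal(size=(500, hidden_dim))
y = rng.choice(["narrative", "list", "speech", "code", "explanation"], size=500)

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=0)

# A simple linear probe trained on the mean representations.
probe = LogisticRegression(max_iter=1000)
probe.fit(X_train, y_train)
print("macro F1:", f1_score(y_test, probe.predict(X_test), average="macro"))
```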
Results:
Genre can be extracted with F1 scores of up to 98% on the CORE dataset and 71% on the synthetic dataset
Probes using activations from the randomly initialized model performed much worse but were still above chance
There was a surprisingly small gap between probe performance on the trained model and on the randomly initialized model
We are able to extract text genre, but the probes seem to draw more on spurious correlations than on meaningful semantic features.
More information about this experiment is available in our paper.
Can we see different 'sections' of texts appearing distinctly through attention patterns? Can we automate splitting texts along these sections?
Experiment 2: Analyzing Attention Patterns
We examined attention patterns within and between paragraphs to understand how information flows through the model.
Methodology:
Generated multi-paragraph texts (e.g., recipes with distinct sections)
Extracted attention weights from various layers
Visualized the attention patterns using heatmaps (see the sketch after this list)
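For reference, the sketch below shows one way to pull per-layer attention maps out of a Hugging Face transformers model and render them as a heatmap; the checkpoint name, layer index, and example text are placeholders rather than our exact configuration.

```python
import torch
import matplotlib.pyplot as plt
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "mistralai/Mistral-7B-Instruct-v0.2"  # placeholder checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float16,
    device_map="auto",
    attn_implementation="eager",  # so attention weights are actually returned
)

text = "First paragraph about preparing the dough.\n\nSecond paragraph about baking it."
inputs = tokenizer(text, return_tensors="pt").to(model.device)
with torch.no_grad():
    out = model(**inputs, output_attentions=True)

layer = 16  # which layer to inspect
attn = out.attentions[layer][0].mean(dim=0)  # average over heads -> (seq_len, seq_len)
plt.imshow(attn.float().cpu().numpy(), cmap="viridis")
plt.title(f"Mean attention pattern, layer {layer}")
plt.savefig("attention_heatmap.png")
```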
Results:
Each heatmap shows the similarity between pairs of token representations at a given layer of the model. The coordinates correspond to tokens, and lighter areas indicate groups of tokens with greater similarity; these may correspond to coherent text segments or paragraphs.
Observed distinct attention patterns corresponding to paragraph boundaries
Some layers showed clear cross-paragraph attention, while others focused more locally
Note: This analysis was limited to a small number of texts (~2) and would benefit from a larger-scale study
Experiment 3: Chunking Algorithm Based on Attention Patterns
We developed an algorithm to automatically chunk texts based on changes in attention patterns.
Methodology:
Computed cosine similarity between token activations at a chosen layer
Identified "breaks" where similarity dropped below a threshold
Used these breaks to define chunk boundaries (a simplified sketch follows this list)
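A simplified sketch of the idea (not our exact implementation): a new chunk starts whenever a token's activation is, on average, less cosine-similar than a threshold to the activations of the last few tokens in the current chunk. The `window` and `threshold` values are illustrative.

```python
import numpy as np

def chunk_by_similarity(token_acts: np.ndarray, threshold: float = 0.5, window: int = 5):
    """Return the start indices of chunks for a (num_tokens, hidden_dim) activation array."""
    normed = token_acts / (np.linalg.norm(token_acts, axis=1, keepdims=True) + 1e-8)
    boundaries = [0]
    for i in range(1, len(normed)):
        # Compare the current token with the most recent tokens of the current chunk.
        start = max(boundaries[-1], i - window)
        similarity = normed[start:i] @ normed[i]
        if similarity.mean() < threshold:
            boundaries.append(i)  # similarity dropped: start a new chunk here
    return boundaries
```

Because only past activations are used, this can run online during generation; the threshold and window are the knobs that set how "general" the resulting chunks are.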
Results:
The blue lines are imposed by the clustering algorithm, which detects transitions to the next chunk; it computes the clustering on the fly from the activations. The squares around the diagonal are the regions that the clustering algorithm considers to belong to the same chunk.
The algorithm successfully identified semantically meaningful chunks in an online fashion
Performance depends on the desired level of "generality" for chunks
While the current implementation works reasonably well, there's room for improvement in accuracy and efficiency
Can we predict future chunks of text at all, using a simple, naive method?
Experiment 4: Training a Predictor for Future Latent Vectors using Activations
We used the first 100 token activations from a text to predict future activations. Specifically, we attempted to predict either a single future token activation (the 120th token) or an aggregate of activations across tokens 100–120. The activations were extracted from the Mistral model.
Methodology:
Created a dataset:
Took 100 tokens from "The Pile"
Generated 20 additional tokens using Mistral
We trained various models (simple neural networks, transformers, and LLMs fine-tuned with LoRA), either on the newly created dataset or on the full CORPUS dataset
We tried different aggregation/compression methods for the 100 input tokens: mean, max, sum, PCA, and a learned linear projection. These aggregation steps were applied either at the input or at the output of the predictor
We evaluated predictions using cosine similarity and MSE between predicted and actual mean vectors (a simplified training sketch follows this list)
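The sketch below corresponds to one of the simpler configurations: a small feed-forward network trained with MSE to map a mean-pooled context vector to the mean of the future activations, then evaluated with cosine similarity. The tensors are random stand-ins; in the real runs they are Mistral activations over the Pile-derived texts, and the architectures, aggregation methods, and dataset sizes varied.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

hidden_dim = 4096  # Mistral 7B hidden size
n_examples = 2048

# Stand-in tensors: in practice, `context` is an aggregate (here, the mean) of the
# first 100 token activations and `target` is the mean activation of tokens 100-120.
context = torch.randn(n_examples, hidden_dim)
target = torch.randn(n_examples, hidden_dim)

predictor = nn.Sequential(
    nn.Linear(hidden_dim, 1024),
    nn.ReLU(),
    nn.Linear(1024, hidden_dim),
)
optimizer = torch.optim.Adam(predictor.parameters(), lr=1e-4)

for step in range(1000):
    idx = torch.randint(0, n_examples, (64,))
    loss = F.mse_loss(predictor(context[idx]), target[idx])
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

# Evaluate with cosine similarity between predicted and actual mean vectors.
with torch.no_grad():
    cos = F.cosine_similarity(predictor(context), target, dim=-1)
print("mean cosine similarity:", cos.mean().item())
```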
Results:
The top-row heatmaps compare model outputs against actual target outputs using cosine similarity and L2 distance; the bottom row compares actual target outputs with one another. The closer the top-row heatmaps are to the bottom-row heatmaps, the closer the model outputs are to the actual results (a sketch of how these heatmaps are constructed appears after this list).
Performance was disappointing across various configurations
The output heatmaps consistently showed more prominent horizontal lines than vertical lines, indicating that model outputs varied little across different inputs
These results highlight the challenge of predicting future latent states, even with relatively sophisticated methods. Further research is needed to develop more effective prediction strategies.
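For reference, the heatmaps can be constructed roughly as follows (with random stand-in vectors): one panel of pairwise cosine similarities between predictions and targets, and one panel comparing the targets with one another.

```python
import numpy as np
import matplotlib.pyplot as plt

def pairwise_cosine(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    """Cosine similarity between every row of a and every row of b."""
    a = a / np.linalg.norm(a, axis=1, keepdims=True)
    b = b / np.linalg.norm(b, axis=1, keepdims=True)
    return a @ b.T

# Stand-in vectors: in practice these are the predicted and actual future mean activations.
rng = np.random.default_rng(0)
preds = rng.normal(size=(32, 4096))
targets = rng.normal(size=(32, 4096))

fig, axes = plt.subplots(1, 2, figsize=(9, 4))
axes[0].imshow(pairwise_cosine(preds, targets))    # top row in our figures: outputs vs targets
axes[1].imshow(pairwise_cosine(targets, targets))  # bottom row: targets vs targets
axes[0].set_title("predictions vs targets")
axes[1].set_title("targets vs targets")
plt.savefig("prediction_heatmaps.png")
```

A predictor that largely ignores its input produces near-identical prediction vectors, which shows up as uniform stripes along one axis of the predictions-vs-targets panel; this is the failure mode described above.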
Experiment 5: Training a Predictor for Future Latent Vectors using Sentence Embeddings
Similar to the previous experiment, but instead of using token activations we embedded the text using the jasper_en_vision_language_v1 sentence embedding model.
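A sketch of this variant, assuming the embedding model can be loaded through the sentence-transformers library (the model id below is a placeholder for the full Hugging Face id, and the loading options and example texts are assumptions rather than our exact setup); the predictor is the single-layer network discussed in the results.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from sentence_transformers import SentenceTransformer

# Placeholder: substitute the full Hugging Face id of jasper_en_vision_language_v1.
embedder = SentenceTransformer("jasper_en_vision_language_v1", trust_remote_code=True)

# Stand-in texts: in practice, `contexts` are the ~100-token prefixes and
# `continuations` are the 20 tokens generated by Mistral.
contexts = ["The recipe begins by listing the ingredients you will need..."]
continuations = ["Next, preheat the oven and prepare the baking tray..."]

ctx_emb = torch.tensor(embedder.encode(contexts))
tgt_emb = torch.tensor(embedder.encode(continuations))

# Single linear layer mapping context embeddings to continuation embeddings.
predictor = nn.Linear(ctx_emb.shape[1], tgt_emb.shape[1])
optimizer = torch.optim.Adam(predictor.parameters(), lr=1e-3)

for step in range(500):
    loss = F.mse_loss(predictor(ctx_emb), tgt_emb)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

print("cosine:", F.cosine_similarity(predictor(ctx_emb), tgt_emb, dim=-1).mean().item())
```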
Results:
Results (single-layer neural network): the first-row heatmaps show a faint diagonal pattern, indicating weak alignment between the model outputs and the actual targets.
Strangely, the single-layer neural network produced cleaner heatmaps than a more complex model, even when the more complex model had a lower loss. The single-layer model displayed diagonal and vertical structure; when layers were added, only horizontal structure remained.
As in the previous experiment, the output heatmaps consistently showed more prominent horizontal than vertical structure, indicating that model outputs varied little across different inputs
Although models with lower loss values should, in theory, perform better, this improvement was not reflected in the heatmaps. This suggests that heatmaps may not be a reliable evaluation method in this context.
Conclusion and Reflections
We can think of predicting "paragraph-scale" objects for language models as having two components:
How do we represent these "paragraph-scale" things at all?
How can we map from some present token activations to some future paragraphs?
I think that in our work so far we possibly focused too much on (1) and not enough on (2), but I am not sure. I also think we may have leaned too far toward running more experiments quickly rather than spending more time deciding which experiments to run. Overall, we have learned some valuable lessons, and we later continued this line of work with a cohort of SPAR mentees, resulting in this research.