This is the first in a series of posts on the question:
I'm defining 'input embedding space' as the token embeddings prior to positional encoding.
The basic procedure for obtaining input space gradients is as follows:
The result is a tensor of the same shape as the input embeddings that points in the direction of minimizing the difference between the predicted and target distribution.
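As a rough illustration of the procedure, here is a minimal sketch that uses a tiny toy model in place of a real language model. `ToyLM`, its sizes, the token ids, and the random target distribution are all assumptions made for the example; only the overall pattern (embed, detach, forward, differentiate) mirrors the description above.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

class ToyLM(nn.Module):
    """Toy stand-in for a language model (an assumption, not ModernBERT)."""
    def __init__(self, vocab_size=50, hidden=16):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden)
        self.proj = nn.Linear(hidden, hidden)
        self.head = nn.Linear(hidden, vocab_size)

    def forward(self, inputs_embeds):
        # (seq_len, hidden) -> logits over the vocabulary, shape (vocab_size,)
        pooled = torch.tanh(self.proj(inputs_embeds)).mean(dim=0)
        return self.head(pooled)

model = ToyLM()
token_ids = torch.tensor([3, 7, 11, 2])             # hypothetical token ids
target_probs = F.softmax(torch.randn(50), dim=-1)   # hypothetical target distribution

# Look up the embeddings, then detach so they become a leaf tensor we can differentiate.
input_embeds = model.embed(token_ids).detach().requires_grad_(True)
log_probs = F.log_softmax(model(input_embeds), dim=-1)
loss = -(target_probs * log_probs).sum()            # cross-entropy against a soft target
(grad,) = torch.autograd.grad(loss, input_embeds)

# Stepping against this gradient (input_embeds - lr * grad) lowers the loss.
print(grad.shape)   # same shape as input_embeds
```

The gradient has one row per input token, so it can be inspected token by token.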
Implementation
These experiments were performed with HuggingFace's transformers library and the ModernBERT-large model (Dec 2024). ModernBERT-large was chosen because:
Obtaining input embeddings prior to the application of positional embeddings requires a little bit of reaching into model internals:
We can pass input_embeds directly into the model's forward pass:
Finally, we can use torch's built-in autograd capabilities to get our gradient:
Case Study: Horses and Dogs, Neighs and Barks
To make things more concrete, let's start with two prompts:
The token distributions as predicted by ModernBERT-large are, respectively:
Representing the left distribution as 🐶 and the right distribution as 🐴, we are computing the gradient of cross_entropy(🐶,🐴) with respect to the input embeddings. Which means:
As a gut-check, let's measure the L2 norm of the gradients for each token to give us a rough sense of the "impulse" given by cross entropy on each token:
The tokens with the top 3 gradient L2 norms are "says", "dog" and "animal".
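The per-token norm check can be sketched in a few lines. The gradient below is random and the token labels are placeholders (both assumptions), so the ranking it prints says nothing about the real prompts.

```python
import torch

torch.manual_seed(0)
tokens = [f"tok{i}" for i in range(6)]      # placeholder token labels
grad = torch.randn(len(tokens), 16)         # stand-in for the (seq_len, hidden) input-space gradient

per_token_norm = grad.norm(dim=-1)          # one L2 norm per token
top = torch.topk(per_token_norm, k=3)       # the three tokens with the largest "impulse"
for idx, value in zip(top.indices.tolist(), top.values.tolist()):
    print(f"{tokens[idx]}: {value:.3f}")
```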
This is encouraging. But are the gradient directions meaningful?
Let's see if any of the gradients point in a neigh-like direction by finding the vocab token with the largest cosine similarity to our gradient:
argmax(cosine_sim(gradient, vocabulary))
However, perhaps this is the wrong question to ask. We want to understand if the gradient is heading towards any vocab token starting from the initial embedding:
argmax(vocab, cosine_sim(gradient, vocab - bark))
Sadly, this yields the same set of tokens because the gradient vectors are mostly orthogonal to the original embedding (indeed, they all have a cosine similarity of about -0.01):
ADAM on Input Embeddings
Although the early indications are mixed, it would be interesting to try optimizing the input embeddings directly with ADAM.
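A minimal sketch of that kind of optimization loop, assuming a toy model (`proj`, `head`), a random starting sequence, and a target distribution produced by another random sequence. None of these are the actual experimental setup. The key point is that the embeddings are the only tensor handed to the optimizer.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

vocab_size, hidden, seq_len = 50, 16, 5
# Toy "model" (an assumption): per-token projection, mean-pool, vocab head.
# Its weights are never passed to the optimizer, so they stay fixed.
proj, head = nn.Linear(hidden, hidden), nn.Linear(hidden, vocab_size)

def predict_log_probs(embeds):
    return F.log_softmax(head(torch.tanh(proj(embeds)).mean(dim=0)), dim=-1)

with torch.no_grad():
    target_log_probs = predict_log_probs(torch.randn(seq_len, hidden))  # distribution to match

start = torch.randn(seq_len, hidden)
embeds = start.clone().requires_grad_(True)      # the only tensor being optimized
opt = torch.optim.Adam([embeds], lr=1e-2)

losses = []
for step in range(500):
    opt.zero_grad()
    # KL(target || predicted); both arguments are log-probabilities here.
    loss = F.kl_div(predict_log_probs(embeds), target_log_probs,
                    log_target=True, reduction="sum")
    loss.backward()
    opt.step()
    losses.append(loss.item())

traveled = (embeds.detach() - start).norm().item()
print(f"KL: {losses[0]:.4f} -> {losses[-1]:.6f}, distance traveled: {traveled:.3f}")
```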
It does converge (quite rapidly):
Animating the top token probabilities illustrates the convergence quite nicely:
And most encouragingly, " bark" seems to be on the move!
While " bark" is moving, I should point out that the new embedding (we can call it bark') is still firmly in " bark" territory. No other vocab token is closer by cosine similarity or Euclidean distance.
The Euclidean distance between " neigh" and " bark" is around 2.5, and after 500 training steps we have barely traveled 0.8. An extended training run of 10,000 steps still lands bark' firmly in bark world.
But has " bark" traveled towards anything in particular?
Indeed - "bark" has traveled more towards neigh than any other token in the vocabulary.
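The "which token are we heading towards" check can be sketched as follows. The vocabulary matrix is random, and `start_id` / `toward_id` are hypothetical ids standing in for " bark" and " neigh". The optimized embedding is constructed by hand to move towards `toward_id`, so the example only demonstrates the measurement, not the result.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
vocab = torch.randn(100, 16)            # stand-in vocabulary embedding matrix
start_id, toward_id = 3, 42             # hypothetical ids in the roles of " bark" and " neigh"
start = vocab[start_id]
# Hypothetical optimized embedding: a partial step towards `toward_id` plus a little noise.
moved = start + 0.3 * (vocab[toward_id] - start) + 0.02 * torch.randn(16)

displacement = moved - start            # direction actually traveled
headings = vocab - start                # direction from the start to every vocab token
sims = F.cosine_similarity(headings, displacement.unsqueeze(0), dim=-1)
sims[start_id] = float("-inf")          # the zero vector has no direction, so exclude it
best = int(sims.argmax())
print(best, round(float(sims[best]), 3))
```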
While this is encouraging, the cosine similarity of the heading towards neigh is nothing astonishing: about 0.3.
Repeating this exercise over 64 examples, we can see that 'bark' is a bit of an outlier (it was a contrived example). The total L2 token embedding distances per sequence typically level off, while the KL-divergence approaches zero.
Is there any kind of structure about which dimensions are affected? By inspecting histograms and cumulative density plots of per-dimension movement in input embedding space, it doesn't appear that any particular token was "favored" - all tokens had a roughly equal distribution of embedding dimension displacement. The following histogram from our 64 test examples is typical.
Some Hypotheses
I conjecture that performing gradient descent on input space embeddings is in the "overparameterized regime".
This has some implications for where and how we minimize to nearly zero loss.
Specifically:
The first point is uncontroversial - it is a well known property of high dimensional Euclidean space that all points become "close".
The second point helps explain why loss in the overparameterized regime almost always converges to nearly zero.
The third point explains why we should have no expectation that the point we converge to is in any way interpretable: The global minima manifold is itself quite high dimensional, and only a tiny fraction of the points on it have sensible back-projections.
TL;DR: our consistent ability to converge to zero loss, the lack of interpretability of the results, and the relatively short distance our embeddings travel all lend support to the claim that we are seeing a classic overparameterized loss landscape.
More Validation - Randomized Input Embeddings
But, to further validate our hypotheses about a vast and everywhere-close global minima manifold, we will conduct a final experiment:
ModernBERT-large input embeddings.
If loss converges and we again observe that the input embeddings do not move "very far" and "level off", this is good evidence for our hypothesis.
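One plausible way to draw random starting embeddings that still resemble the vocabulary is to fit a mean and covariance to the embedding matrix and sample from the resulting Gaussian. This is an assumption about the sampling scheme, not a record of what was run, and the vocabulary matrix below is synthetic.

```python
import torch

torch.manual_seed(0)
hidden = 16
# Synthetic stand-in for a vocabulary embedding matrix with unequal per-dimension scales.
vocab = torch.randn(500, hidden) * torch.linspace(0.5, 2.0, hidden)

mean = vocab.mean(dim=0)
# torch.cov treats rows as variables, hence the transpose; jitter keeps it positive definite.
cov = torch.cov(vocab.T) + 1e-6 * torch.eye(hidden)
sampler = torch.distributions.MultivariateNormal(mean, covariance_matrix=cov)

seq_len = 7
random_embeds = sampler.sample((seq_len,))   # one random "sentence" of shape (seq_len, hidden)
print(random_embeds.shape)
```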
Here are the results:
Again - we consistently converge, and not a single token moved enough to back-project to a new token.
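Back-projection here means mapping an embedding to its nearest vocabulary token. A minimal sketch, assuming a random vocabulary matrix and hand-made "optimized" embeddings that have moved only slightly from their original tokens:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
vocab = torch.randn(100, 16)                        # stand-in vocabulary embedding matrix
original_ids = torch.tensor([5, 17, 63])            # hypothetical starting tokens
optimized = vocab[original_ids] + 0.1 * torch.randn(3, 16)   # hypothetical small movement

def back_project(embeds, vocab):
    """Return nearest vocab ids by cosine similarity and by Euclidean distance."""
    cos = F.normalize(embeds, dim=-1) @ F.normalize(vocab, dim=-1).T   # (n, vocab_size)
    dist = torch.cdist(embeds, vocab)                                   # (n, vocab_size)
    return cos.argmax(dim=-1), dist.argmin(dim=-1)

by_cos, by_dist = back_project(optimized, vocab)
changed = (by_dist != original_ids).sum().item()    # tokens that now back-project elsewhere
print(by_cos.tolist(), by_dist.tolist(), changed)
```

With movement this small relative to typical inter-token distance, every embedding maps back to the token it started from.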
This is strong evidence, in my opinion, that optimizing input embeddings is in the overparameterized regime.
Update: Follow-up with Meta-Llama/Llama-3.2-1B
The architecture of ModernBERT is divergent enough from "real" LLMs that it is worth seeing if these observations hold for a model in the Llama family. Llama-3.2-1B was the obvious choice for me to keep things local.
Even so, this model is 7x the size of ModernBERT. Without any effort put into optimization, a batch of 64 took something like 11 seconds per optimization step on my MacBook Air.
Optimization 1: No Hand-Rolled ADAM
I rolled my own ADAM optimizer - this was a mistake. I had good reasons at the time, but I should have found a way to make it work with the out-of-the-box optimizer.
The goal was to only apply updates to non-special tokens. The solution ended up being simple - selectively zero out the gradient after loss.backward() but before opt.step():
Optimization 2: Ensuring All Tensors Live on MPS
Macs have a high performance backend for PyTorch called MPS.
My assumption was that PyTorch was automatically placing tensors on this "device". This assumption was incorrect.
This was a straightforward fix - define a device and ensure every tensor is assigned to it.
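A minimal sketch of that pattern. The CUDA and CPU fallbacks are an addition for portability, and the tensor names are placeholders.

```python
import torch

# Prefer Apple's MPS backend when present, otherwise fall back to CUDA or CPU.
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

embeds = torch.randn(4, 16, device=device, requires_grad=True)   # created directly on the device
mask = torch.ones(4, 1).to(device)                               # moved explicitly
print(device, embeds.device.type, mask.device.type)
```

Printing the `.device` of intermediate tensors is a quick way to catch anything that silently stayed on the CPU.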
Optimization 3: Using torch.amp.GradScaler and FP16 Precision
We can quantize gradient calculation, but FP16 has more severe underflow issues than FP32. PyTorch provides a nice wrapper around optimizers to automatically scale up loss to avoid gradient underflow:
Optimization 4: Turning off Gradients / Model Freezing
Since we are optimizing input embeddings, there's no need to store gradient calculations for the model weights. We can "freeze" the model after constructing the graph. This had the largest impact on training time.
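A minimal sketch of freezing, using a toy model and a hypothetical update mask (both assumptions). It also applies the gradient-zeroing trick from Optimization 1, so the masked position is left untouched by the stock Adam optimizer.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
model = nn.Sequential(nn.Linear(16, 16), nn.Tanh(), nn.Linear(16, 50))  # toy stand-in
model.requires_grad_(False)          # freeze: no gradients are computed or stored for the weights
model.eval()

embeds = torch.randn(5, 16, requires_grad=True)
start = embeds.detach().clone()
# Hypothetical mask: 0 for a special-token position (here the first), 1 elsewhere.
update_mask = torch.tensor([0.0, 1.0, 1.0, 1.0, 1.0]).unsqueeze(-1)
target = F.softmax(torch.randn(50), dim=-1)      # hypothetical target distribution

opt = torch.optim.Adam([embeds], lr=1e-2)
for _ in range(10):
    opt.zero_grad()
    log_probs = F.log_softmax(model(embeds).mean(dim=0), dim=-1)
    loss = F.kl_div(log_probs, target, reduction="sum")
    loss.backward()
    embeds.grad.mul_(update_mask)    # zero the gradient at masked positions before the step
    opt.step()

print(torch.equal(embeds.detach()[0], start[0]))   # masked row did not move
```

Because Adam's update is proportional to its running gradient average, a row whose gradient is always zero never moves (assuming no weight decay).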
Setup
I sampled random input embeddings from a hyper-ellipse fitted to the vocabulary embeddings and assembled 64 "sentences" of varying lengths. I set up an ADAM optimizer to optimize token input embeddings to minimize KL divergence from the distribution produced by completing the sentence "The animal that says 'neigh' is a ". I masked out special tokens to prevent gradient updates on tokens like <|begin_of_text|>.
Results (Random Input Embeddings)
Just like with ModernBERT, training over 64 examples shows mean KL divergence approaching zero quickly - under 1e-2 in fewer than 1,000 iterations.
As with ModernBERT, the input embeddings only traveled a "short" distance. Here is a smoothed histogram of L2 distance traveled per token:
Compare this with a smoothed histogram of distances between 1000 pairs of randomly selected token input embeddings from the vocabulary:
The histogram of per-embedding-dimension displacement demonstrates just how little the individual dimensions "wiggled" per token: nearly all token dimensions moved less than 0.02 units.
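The three quantities compared above (distance traveled per token, distance between random vocabulary pairs, and per-dimension displacement) can be computed along these lines. All tensors are synthetic stand-ins, so the printed numbers are illustrative only.

```python
import torch

torch.manual_seed(0)
vocab = torch.randn(1000, 16)                     # stand-in vocabulary embeddings
start = torch.randn(64, 16)                       # stand-in starting token embeddings
end = start + 0.01 * torch.randn(64, 16)          # hypothetical optimized embeddings

traveled = (end - start).norm(dim=-1)             # L2 distance traveled per token

# Distances between 1000 randomly chosen pairs of vocabulary embeddings.
i = torch.randint(0, vocab.shape[0], (1000,))
j = torch.randint(0, vocab.shape[0], (1000,))
pair_dist = (vocab[i] - vocab[j]).norm(dim=-1)

# Fraction of individual embedding dimensions that moved less than 0.02 units.
small = ((end - start).abs() < 0.02).float().mean()

print(f"median traveled:      {traveled.median().item():.3f}")
print(f"median pair distance: {pair_dist.median().item():.3f}")
print(f"fraction of dims moved < 0.02: {small.item():.2f}")
```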
However, unlike our ModernBERT experiments, this was enough to move our random input embeddings to back-project to new tokens, as nonsensical as they were:
The reason for the back-projections changing is unclear - we did not see this behavior with ModernBERT, and the distance our embeddings travel here is demonstrably "small" compared to typical inter-token distance. Maybe tokens are distributed differently in the regions where "junk" tokens live?
Conclusions
The evidence we collected seems to support the hypothesis that the global loss minima manifold for Meta-Llama/Llama-3.2-1B is "close to everywhere" and easy to find through gradient descent.
My current belief is that the intelligence that arises in large-scale models through gradient descent is a combination of three factors:
Our experiment only has property (2), and hence it is unsurprising that the input embeddings do not morph in "intelligent" ways - say, from "bark" to "neigh".
What we did recover was additional evidence in line with current thinking on loss landscapes, which is in itself valuable.
Thank you for coming on this journey with me!