This post is an informal preliminary writeup of a project that I've been working on with friends and collaborators. Some of the theory was developed jointly with Zohar Ringel, and we hope to write a more formal paper on it this year. Experiments are joint with Lucas Teixeira (with extensive use of LLM assistants). This work is part of the research agenda we have been running with Lucas Teixeira and Lauren Greenspan at my organization Principles of Intelligence (formerly PIBBSS), and Lauren helped in writing this post.
Introduction
Pre-introduction
I have had trouble writing an introduction to this post. It combines three aspects of interpretability that I like and have thought about in the last two years:
The mean field approach to understanding neural nets. This is a way of describing neural nets as multi-particle statistical theories, and it has been applied with significant success in both Bayesian and SGD learning contexts in the last few years. For example, the only currently known exact theoretical prediction for the neuron distribution and the grokking phase transition in a modular addition instance is obtained via this theory.
This connects to the agenda we are running at PIBBSS on multi-scale structure in neural nets, and to physical theories of renormalization that study when phenomena at one scale can be decoupled from another. The model here exhibits noisy interactions between first-order-independent components[1] mediated by a notion of frustration noise - a kind of "frozen noise" inherited from a distinct scale that we can understand theoretically in a renormalization-adjacent way.
My inclination is to include all of these in this draft. I am tempted to write a very general introduction which tries to introduce all three phenomena and explain how they are linked together via the experimental context I'll present here. However, this would be confusing and hard to read – people tend to like my technical writing much better when I take to heart the advice: "you get about five words".
So for this post I will focus on one aspect of the work, which is related to the physical notion of frustration and the effect this interesting physical structure has on the rest of the system via a source of irreducible noise. Thus if I were to summarize a five-word takeaway I want to explain in this post, it is this:
Loss landscapes have irreducible noise.
In some ways this is obvious: real data is noisy and has lots of randomness. But the present work shows that noise is even more fundamental than randomness from data. Even in settings where the data distribution is highly symmetric and structured and the neural net is trained on infinite data, interference in the network itself (caused by a complex system responding to conflicting incentives) leads to unavoidable fluctuations in the loss landscape. In some sense this is good news for interpretability: physical systems with irreducible noise will often tend towards having more independence between structures at different scales and can be more amenable to causal decoupling and renormalization-based analyses. Moreover, if we can understand the structure and source of the noise, we can incorporate it into interpretability analyses.
In this case the small-scale noise in the landscape is induced by an emergent property of the system at a larger scale. Namely, early in training the weights of this system develop a coarse large-scale structure similar to the discrete up/down spins of a magnet. For reasons related to interference in data with superposition, the discrete structure is frustrated and therefore in some sense forced to be asymmetric and random. The frustrated structures on a large scale interact with small fluctuations of the system on a small scale, and lead to microscopic random ripples that follow a similar physics to impurities in a semiconductor.
These two graphs represent the weights of two fully trained neural nets, trained on similar tasks and with the same number of parameters (each dot in the graph records two parameters). On the left we have a system without frustration or impurities, whereas on the right the learning problem develops frustration (due to superposition) and the weights at the loss minimum get perturbed by asymmetric and essentially random impurity corrections inherited from frustrated structure.
The task
In this work we analyze a task that originally comes from our paper on Computation in Superposition (CiS).
Sparse data with superposition was originally studied in the fields of compressed sensing and dictionary learning. Analyzing modern transformer networks using techniques from this field (especially sparse autoencoders) has led to a paradigm shift in interpretability over the last 3 years. Researchers have found that the hidden-layer activation data in transformers is mathematically well-described by sparse superpositional combinations of so-called "dictionary features" or SAE features; moreover these features have nice interpretability properties, and variants of SAEs have provided some of the best unsupervised interpretability techniques used in the last few years[2].
In our CiS paper we theoretically probe the question of whether the efficient linear encoding of sparse data provided by superposition can extend to an efficient compression of a computation involving sparse data (and we find that indeed this is possible, at least in theory). A key part of our analysis hinges on managing a certain error source called interference. It is equivalent to asking whether a neural net can adequately reconstruct a noisy sparse input in Rd (think of a 2-hot vector like (0,0,1,0,0,0,1,0) with added Gaussian noise) while passing it through a narrow hidden layer Rh with h<d. The narrowness h<d forces the neural net to use superposition in the hidden layer (here the input dimension d is the feature dimension; in our CiS paper we use different terminology and the feature dimension is called m). In the presence of superposition, it is actually impossible to get perfect reconstruction, so training with this target is equivalent to asking "how close to the theoretically minimal reconstruction error will a real neural net get?"
In the CiS paper we show that a reconstruction that is "close enough" to optimal for our purposes is possible. We do this by "manually" constructing a 2-layer neural net (with one layer of nonlinearity) that achieves a reasonable interference error.
This raises the interesting question of whether an actual 2-layer network trained on this task would learn to manage interference, and what algorithm it would learn if it did; we do not treat this in our CiS paper. It is this sparse denoising task that I am interested in for the present work. The task is similar in spirit to Chris Olah's Toy Models of Superposition (TMS) setting.
Below, I explain the details of the task and the architecture. For convenience of readers who are less interested in details and want to get to the general loss-landscape and physical aspects of the work, the rest of the section is collapsible.
The details (reconstructing sparse inputs via superposition in a hidden layer)
Our model differs from Chris Olah's TMS task in two essential ways. First, the nonlinearity is applied in the hidden layer (as is more conventional) rather than in the last layer as in that work. Second and crucially, while the TMS task's hidden layer (where superposition occurs) is 2-dimensional, we are interested in contexts where the hidden layer is itself high-dimensional (though not as large as the input and output layers), as this is the context where the compressibility benefit from compressed sensing really shines. Less essentially, our data has noise (rather than being a straight autoencoder as in TMS), and the loss is modified to avoid a degeneracy of MSE loss in this setting.
The task: details
At the end of the day the task we want our neural net to try to learn is as follows.
Input: our input in Rd is a sparse boolean vector plus noise, x∗+ξ (here in experiments, x∗ is a 2-hot or 3-hot vector, so something like x∗=(0,0,1,0,0,0,1,0), and ξ is Gaussian noise with standard deviation ≪1 on a coordinatewise level). We think of Rd as the feature basis.
Output: the "denoised" vector y(x)=x∗. As a function, this is equivalent to learning the coordinatewise "round to nearest integer" function.
Architecture: we use an architecture that forces superposition, with a single narrow hidden layer Rh, where h≪d. Otherwise this is a conventional neural net architecture with a coordinatewise nonlinearity at the middle layer, so the layers are:
Linear(d→h) ∣ Nonlinearity_h ∣ Linear(h→d).
Fine print: There are subtleties that I don't want to spend too much time on, about what loss and nonlinearity I'm using in the task, and what specific range of values I'm using for the hidden dimension h and the feature dimension d. In experiments I use a specific class of choices that make the theory easier and that let us see very nice theory-experiment agreement. In particular, for the loss, a usual MSE loss has bad denoising properties since it encourages returning zero in the regime of interest (this is because sparsity means that most of my "denoised coordinates" should be 0, and removing all noise from these by returning 0 is essentially optimal from a "straight MSE loss" perspective, though it's useless from the perspective of sparse computation). The loss I actually use is re-balanced to discourage this. Also for theory reasons, things are cleaner if we don't include biases and the nonlinearity function has some specific exotic properties (specifically, I use a function which is analytic, odd, has f′(0)=0, and is bounded; the specific function I use is ϕ(x)=tanh(x)³). The theory is also nicer in a specific (though extensive) range of values for the feature dimension d and the hidden dimension h<d where superposition occurs.
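To make the setup concrete, here is a minimal sketch of the data and architecture in PyTorch. The tanh(x)³ activation and the absence of biases are as described above; the specific dimensions, the noise scale, and the particular re-balanced loss weighting are illustrative assumptions of mine rather than the exact experimental choices.

```python
import torch
import torch.nn as nn

d, h, k = 3072, 1024, 2   # feature dim, hidden dim, sparsity (k-hot inputs); illustrative values
noise_std = 0.05          # coordinatewise noise scale; the post only specifies "much less than 1"

def sample_batch(batch_size):
    """Sparse k-hot targets x* plus Gaussian noise xi; returns (x* + xi, x*)."""
    x_star = torch.zeros(batch_size, d)
    idx = torch.stack([torch.randperm(d)[:k] for _ in range(batch_size)])
    x_star.scatter_(1, idx, 1.0)
    return x_star + noise_std * torch.randn(batch_size, d), x_star

class SparseDenoiser(nn.Module):
    """Linear(d -> h) | coordinatewise nonlinearity | Linear(h -> d), with no biases."""
    def __init__(self):
        super().__init__()
        self.encoder = nn.Linear(d, h, bias=False)
        self.decoder = nn.Linear(h, d, bias=False)

    def forward(self, x):
        return self.decoder(torch.tanh(self.encoder(x)) ** 3)

def rebalanced_mse(pred, x_star, active_weight=d / k):
    """MSE with the active (nonzero) coordinates upweighted, so that "always output 0"
    is no longer near-optimal. The exact re-balancing used in the post is not specified;
    this particular weighting is only a guess."""
    w = 1.0 + (active_weight - 1.0) * (x_star > 0.5).float()
    return (w * (pred - x_star) ** 2).mean()
```

Training this with a standard optimizer on freshly sampled batches plays the role of the (effectively) infinite-data setting discussed below.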
Much of the theory can be extended to apply in a more general setting (a big difference will be that in general, the main theoretical tool we use would shift from single-particle mean field theory to so-called multi-particle mean field theory).
The setting that would be perhaps most natural for modeling realistic NNs in practice is to use a ReLU (or another "standard" activation function) and a softmax error loss for the reconstruction. While theoretically less tractable and experimentally messier, I expect that some of the key heuristic behaviors would be the same in this case. In particular, both theory and experiment recover in this setting the discrete "frustration" phenomenon in the signs of the weights.
Results
The discussion here will orient around the following figure. Both sides represent a trained denoising model as above, but only the model on the right has superposition. The difference in the resulting weights is a visual representation of certain phenomena associated with frustrated physical systems.
Here each image represents the set of weight coordinates in an instance of the model trained to completion on (effectively) infinite data. On the left we have trained a model without superposition: the hidden dimension is larger than the input dimension. On the right, we are training a model with superposition: the model has input dimension larger than the hidden dimension (by about a factor of 3). The two models are chosen in such a way that the total number of weights (i.e., the total number of dots in the diagram) is the same in both.
Both images represent a (single) local minimum in the weight landscape, and they have a similar coarse structure, comprising a set of several clusters of similar shape (if you zoom in, you can see that the decoder weights end up scaled differently and the encoder weights are distributed differently between the bumps – we can predict these differences from a theoretical model). But we also see a very stark visual difference between the two sets of weights. The model on the left seems to converge to a highly symmetric set of weights with just 5 allowable coefficient values in the encoder and decoder (up to small noise from numerical imperfections). But despite having the same symmetries a priori, the weights of the model on the right have only a coarse regularity: locally within clusters they look essentially random. In fact in a theoretical model that is guiding my experiments here, we can predict that in an appropriate asymptotic limit, the perturbations around the cluster centers will converge to Gaussian noise.
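For reference, here is roughly how such a weight-pair scatter plot could be produced from the SparseDenoiser sketch above; the pairing of encoder[i,j] with the matching decoder[j,i] is my assumption about what each dot records.

```python
import matplotlib.pyplot as plt

def plot_weight_pairs(model):
    enc = model.encoder.weight.detach()      # shape (h, d): rows are neurons, columns are features
    dec = model.decoder.weight.detach().T    # transpose to (h, d) so dec[i, j] = decoder[j, i]
    plt.scatter(enc.flatten(), dec.flatten(), s=1, alpha=0.3)
    plt.xlabel("encoder[i, j]")
    plt.ylabel("decoder[j, i]")
    plt.show()
```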
This noisy structure is a neural-net consequence of a frustrated equilibrium associated with superposition. I'll explain this in the next section.
The physics of frustration
Frustration is one of several mechanisms in statistical physics that can lead to quenched randomness. By its very nature, statistical physics deals with phenomena with randomness and noise. The sources of randomness can roughly be separated into two kinds: on the one hand, fluctuation phenomena experienced by systems at positive temperature (think of the random distribution of gas particles in a box); and on the other hand, quenched or frozen phenomena, where the physical forces defining the system of interest are subject to random perturbations or fields that are stable at the relevant time or energy scales. As an example, think of the frozen random nature of the ocean floor. While it changes over millennia, physicists model the peaks and troughs of the bottom as a quenched random structure which is fixed when studying its effect on wave statistics.
In this work I'm most interested in a special kind of quenched randomness. Quenched structure typically comes about in two ways.
As an externally imposed source of randomness that is frozen "at the outset" when defining the problem. For example, the random peaks and valleys of the ocean floor in my example of understanding waves are of this type. Another classic field with external quenched randomness is the study of semiconductors. Here impurities in the material affect the electron conduction properties in locally random ways that can be well understood macroscopically using the theory of quenched disorder.
As an emergent or self-organizing phenomenon, where the system has no frozen random structure at the outset, but parts of the system spontaneously self-organize into frozen structures that can be treated as fixed and random. A classic example is the self-organization of the magnetic field of a material (like iron) into magnetic domains when subjected to an external magnetic field.
Since learning combines random and structured phenomena, it is frequently studied as a statistical system. In the next sections I'll explain how the three sources of thermodynamic randomness we discussed (tempering, externally imposed quenched randomness, and emergent frozen disorder) relate to the theory of neural nets, and how this is related to the notion of frustration and the present experiment.
Tempering and quenched disorder in ML and interpretability
Existing interpretability theory deals frequently with both fluctuation randomness associated with heat (tempering) and with externally imposed frozen randomness from data, and the two are frequently studied together. However the spontaneously emerging/ self-organizing form of frozen randomness is less frequently studied (at least to the best of my knowledge).
Fluctuations typically appear in Bayesian learning theory. In many theoretical contexts (especially ones concerned with the heuristic bias of NN learning), it is useful to replace the exact loss-minimization goal of a learning algorithm by a tempered setting where we model the learner as converging to a random distribution on the loss landscape, where weights with low loss are more likely than weights with higher loss according to a Boltzmann statistical law. (This is heuristically similar to randomly sampling a "low-loss valley" in the landscape, i.e. randomly selecting a weight with loss less than ϵ above the minimum.) The amount of "randomness" here is controlled by a parameter called the temperature (it is analogous to the loss cutoff ϵ in the loss valley picture). Tempered learners can be realized empirically: there is a learning algorithm known to converge to this tempered distribution, namely stochastic gradient Langevin dynamics (SGLD). Here the model learns via a biased random walk with a bias towards low-loss regions[3].
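As a minimal sketch (assuming the standard SGLD update rule, not necessarily the exact variant used in the works cited below), one step is gradient descent plus suitably scaled Gaussian noise:

```python
import torch

def sgld_step(params, loss, lr, temperature):
    """One SGLD step: gradient descent plus Gaussian noise scaled so that, in the
    idealized long-time limit, the stationary distribution is the Boltzmann
    distribution proportional to exp(-loss / temperature)."""
    grads = torch.autograd.grad(loss, params)
    with torch.no_grad():
        for p, g in zip(params, grads):
            p.add_(-lr * g + (2 * lr * temperature) ** 0.5 * torch.randn_like(p))
```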
Externally imposed frozen disorder typically enters neural nets through the data distribution. When training a model on finite data, the specific datapoints used are assumed to be drawn randomly from a large data distribution. Thus the difference between what a model would learn on infinite data and what it learns on a small random training set bakes in a source of noise called data sample noise.
These two sources of noise are known to be related. In particular in contexts where both are present (Bayesian / Langevin learning on finite data), it is often known roughly which source of randomness dominates. For example in work by Watanabe, it is shown that above a certain scale of tempering (very roughly proportional to the inverse number of datapoints), data randomness becomes small compared to the random noise from tempering, and so the statistics becomes independent of the size of the training set.
In a similar setting, Howard et al. use a standard physical tool for studying quenched disorder, called the replica method, to more carefully study a learning system with both data noise and tempered randomness. (The work is done in a very simple linear regression setting, where exact theoretical predictions can be made using renormalization theory. The result finds a similar "transitional scale" to Watanabe's critical temperature range, with a more precise study of exactly how this transition happens.)
The random cluster structure seen in our experiments is a consequence of a new form of frozen disorder, this time self-organizing (analogous to the self-organizing domains in magnetic iron). In our case the self-organizing structure is a consequence of a particularly interesting and unusual type of chaotic behavior of physical systems, which is called frustration.
Frustration in physics
We have been talking about thermodynamics and physical systems without defining terms. Let's fix this. For me, a (statistical) physical system is a high-dimensional data distribution with a notion of energy. A choice of the data is called a state, and its energy is a number (that "wants to be small" in a statistical sense). Formally, the probability of a state is determined by a Boltzmann distribution at some temperature T. Here (without getting too in the weeds), the temperature determines "how much" the probability distribution is biased to prefer low-energy states. Most importantly for us, at temperature 0 this bias is maximized, and the probability distribution is entirely concentrated on states that minimize the energy.
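Concretely, in standard notation, the Boltzmann distribution over states x with energy E(x) at temperature T is

$$P_T(x) \;=\; \frac{e^{-E(x)/T}}{Z(T)}, \qquad Z(T) \;=\; \sum_{x'} e^{-E(x')/T},$$

and as T→0 the probability mass concentrates on the energy minimizers.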
A "prototypical" example is the ferromagnetic Ising model, where the state is a collection of spins ±1 indexed by a lattice. We can think of the state as an N×N matrix x all of whose coordinates are ±1. Energy should be a function of such vectors. The energy E(x) of an Ising state is defined to be a certain negative integer. It is minus the count of the number of "neighbor pairs" (edges in the lattice) where both neighbors point the same direction (so are either both up or both down). Since models want to minimize energy, the lowest-energy (and thus highest-probability) state is the one where all spins align.
A related model, called the antiferromagnetic Ising model, is defined similarly but counts the number of anti-aligned pairs (up to a constant, this is just minus the ferromagnetic energy). A minimizer here is the checkerboard pattern of spins:
In the antiferromagnetic model this state is energy-minimizing and energetically favored with energy -4 (4 "opposite" connections). It has energy 0 in the ferromagnetic model.
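Here is a short numpy sketch of both energy functions as defined above (with free, non-periodic boundaries, matching the 2×2 example); the checkerboard check reproduces the energies just quoted.

```python
import numpy as np

def ising_energy(x, ferromagnetic=True):
    """Energy of an N x N grid of +1/-1 spins: minus the number of aligned
    nearest-neighbor pairs (ferromagnetic), or minus the number of anti-aligned
    pairs (antiferromagnetic)."""
    right = x[:, :-1] * x[:, 1:]   # +1 where horizontal neighbors agree, -1 otherwise
    down = x[:-1, :] * x[1:, :]    # same for vertical neighbors
    agree = int((right == 1).sum() + (down == 1).sum())
    disagree = int((right == -1).sum() + (down == -1).sum())
    return -agree if ferromagnetic else -disagree

# The 2x2 checkerboard: energy 0 ferromagnetically, -4 antiferromagnetically.
checkerboard = np.array([[1, -1], [-1, 1]])
assert ising_energy(checkerboard, ferromagnetic=True) == 0
assert ising_energy(checkerboard, ferromagnetic=False) == -4
```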
Energy minimizers of a thermodynamic system are called ground states. At temperature T = 0, the probability distribution is concentrated on ground states only. For the ferromagnetic Ising model, this is a probability distribution with 50% probability on the "spin up" ground state (all spins +1) and 50% on the spin down state. The antiferromagnetic model similarly has two ground states related by a sign flip. The "default" expectation is that[4] there are not very many ground states: that in some sense they form a low-dimensional space of structured distributions that is either unique or at worst controlled by a low-dimensional "macroscopic" parameter.
This expectation of a low-dimensional set of highly structured ground states is sometimes frustrated[5] in an interesting way. In frustrated physical systems, discrete parameters have conflicting energetic preferences that cannot be simultaneously satisfied and cause interesting emergent disorder. The classic example is an anti-ferromagnet with triangular bonds, i.e., a triangular lattice of spins where each pair wants to anti-align.
A piece of a triangular antiferromagnetic model where neighboring spins want to anti-align. When this lattice is extended to infinity, the zero-temperature static theory develops disordered phenomena and behaves simultaneously like ordered "cold" and chaotic "hot" systems, depending on what you measure. Image from Wikipedia.
There is no way to satisfy all local energetic preferences simultaneously, and in fact there is no nice "structured" ground state. Instead the distribution of ground states looks chaotic and high-dimensional. Furthermore it has significant entropy, similar to a "hot" system that is pushed away from its ground state by positive temperature.
In our sparse denoising model, a similar frustrated energy landscape emerges in the presence of superposition. Here the analog of a physical state is a weight configuration, i.e. a pair of matrices encoder ∈ Mat(h,d) and decoder ∈ Mat(d,h). The analog of energy is the loss function (always understood in the context of infinite data). In the presence of superposition, part of what the loss is trying to do can be summarized as "trying to orthogonally embed many vectors into a lower-dimensional space" (in our case, it's embedding d "feature" vectors into an h-dimensional hidden layer; superposition means that d>h). Since exact orthogonality is impossible for more than basis-many vectors, there is a potential for frustrated structure. However a priori there is a mismatch: frustrated structures are discrete statements about ground states or minima, but our system is continuous. It may seem hard to say nontrivial things about the discrete structure on the ground states, i.e. exact minima.
But wait: the model we're studying is sort of discrete, and the evidence is staring us in the face. Here is a snapshot of the model's training (I chose a pretty snapshot of a smaller model, but you can also look at the late-training images from before).
We see that the weight pairs naturally decompose into three discrete clusters. In particular, if we just look at the embedding weights encoder[i,j], we see that each one is either in a larger cluster around 0 or in one of the smaller clusters around ±1.5 (more generally, the "outside" clusters of embedding weights will be centered at values of magnitude slightly larger than 1). (Here we see this experimentally, but there is also a theoretical model for predicting the weight behavior, based on mean field theory methods, that I have in the background when choosing regimes and experiments; it has good agreement with the experiments in this case.)
We can therefore understand the "a posteriori" statistics governing the trained weights w_ij = encoder[i,j] to be given by the following process[6].
Choose a discrete matrix of "frozen combinatorial signs" S ∈ Mat(h,d) with coordinates s_ij ∈ {−1,0,1} and write w*_ij ≈ 1.5 s_ij, the x-coordinate of the corresponding cluster centroid.
Find a local loss minimum in the basin near this discrete value.
In this specific system, there usually is a unique local minimum in each such basin (note that this will not be the case in similar models with more complicated geometric structure). Thus the local minima of the system that we care about are basin minima associated to a choice of signs s_ij. It remains to understand which choices of sign matrices are ground states, i.e. approximate global minima. We now have two cases:
There is no superposition, i.e. h>d. In this case, there are enough hidden coordinates (usually called neurons) to have each feature processed by its own designated neuron. We can check that indeed, the winning configuration is to have each neuron (hidden coordinate) process at most one feature. This setting doesn't have frustration. Indeed, once we fix some "macroscopic" information, namely how many features are processed by 1 neuron, how many are processed by 2 neurons, etc., all that remains are sign flips and permutation symmetries, which in this case are global symmetries of the system and thus don't affect the loss, geometry, etc.
There is superposition, i.e. h<d. Here (so long as we are in an appropriate regime), a simplified theoretical model based on mean field theory suggests that essentially any iid random configuration is optimal with high probability[7]. More precisely, we fix a probability p and choose each sign s_ij independently[8]: zero with probability 1−p, and ±1 each with probability p/2. Such a configuration is optimal up to small errors (for some value of p that can be theoretically predicted; a small sampling sketch follows below). In other words, since the superposition is explicitly encouraging the matrix coefficients to "not be correlated", the optimal choice is to assume exactly no structure, i.e. random structure. In particular the resulting system is deeply frustrated: the set of vacua is very high-dimensional and has nontrivial entropy.
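A minimal sketch of this sampling process (following the per-row sign convention from the footnote; the value of p used here is a placeholder, not the theoretically predicted one):

```python
import numpy as np

def sample_sign_matrix(h, d, p, seed=0):
    """Sample the frozen sign pattern S in Mat(h, d): each entry is nonzero with
    probability p, and all nonzero entries in a given row share one random sign."""
    rng = np.random.default_rng(seed)
    support = rng.random((h, d)) < p                # which entries are nonzero
    row_signs = rng.choice([-1, 1], size=(h, 1))    # one sign per row
    return support * row_signs                      # entries in {-1, 0, +1}

# Approximate encoder cluster centroids: w*_ij ~ 1.5 * s_ij.
S = sample_sign_matrix(h=1024, d=3072, p=0.01)
w_star = 1.5 * S
```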
Remark. The "random ground state" model for superpositional systems is simplistic: true vacua likely have some sneaky higher-level structure that we are failing to track that makes some minima slightly better than others, though on small scales that as we'll see are below anything we can control in this context. At the end of the day, the structure we really care about is what gets learned by the learning algorithm. Here we can do various empirical analyses on the weights (e.g. use random graph measurements, or compare the loss in a random basin to the learned loss) to see that an iid random configuration explains the learned minima well.
Continuous fluctuations from frozen structure
At the end of the day, we have a pair of predictions for the discrete sign structure associated to the context with superposition (h<d) and without superposition (h>d). Now we are actually interested in the continuous weight values, which we only know are in a basin associated to a discrete centroid. We can think of the resulting system in two equivalent ways: either as a renormalization setting where the small-scale continuous structure perturbs the energy of the discrete configuration, or as a coupled system where the discrete "frozen" structure couples to a smaller-scale "microscopic" system as a background field. The second setting is easier to think about, and puts us back into the context we had discussed in a physics setting, where some predetermined random structure (in this case appearing in an emergent way from the same model) is treated as a "frozen" source of randomness in the system.
Unpacking this, we see cleanly the difference between the frustrated / superpositional and the un-frustrated / single-vacuum settings. In the un-frustrated case, the discrete structure is unique and symmetric. This implies that the continuous perturbations are also symmetric: once we factor in the "macroscopic" information of how many neurons process each feature, the weight coordinates are determined uniquely and symmetrically, and the "cluster structure" simplifies to a set of discrete centroids.
In the frustrated case, the random structure is not symmetric. Frustration means that already on the level of discrete sign choices, we can find no exact regularities, and some pairs of rows/columns overlap more than others. This asymmetry couples with the microscopic system of perturbations away from the cluster centroids, and implies that the same kind of asymmetric structure will be observed there. More precisely, since we can model the sign randomness as iid discrete noise, the effect on the microscopic fluctuations of local minima can again be nicely modeled via the central limit theorem as continuous but "frozen" ripple phenomena that randomly perturb the coordinates of the local minimum of a basin in a predictable fashion. This leads to a phenomenon that appears again and again in physics: frustrated or random discrete structure at a large scale generates random continuous perturbations to local minima (and any other geometric structures) on some related smaller scale. The resulting ripples in the energy landscape are sometimes called "pinning fields".
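As a toy illustration of the general mechanism (not the actual perturbation formula for this model), here is how overlaps between rows of a random sign matrix behave as frozen, approximately Gaussian noise; for simplicity this sketch uses fully iid entry-wise signs rather than the per-row convention above.

```python
import numpy as np

rng = np.random.default_rng(0)
h, d, p = 1024, 3072, 0.05
S = (rng.random((h, d)) < p) * rng.choice([-1, 1], size=(h, d))

# Overlaps between distinct rows, sum_j s_ij * s_kj: each is a sum of many iid terms,
# so by the central limit theorem the collection looks like frozen Gaussian noise
# with mean ~ 0 and standard deviation ~ sqrt(d) * p.
overlaps = (S @ S.T)[np.triu_indices(h, k=1)]
print(overlaps.mean(), overlaps.std())
```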
The random pinning field perturbations explain the surprising noisy structure in the above plot of the weights of the local minimum of our system (analogous to a vacuum in physics). Since we have flattened a pair of matrices into pairs of real numbers, the ordered and symmetric-looking cluster structure here is hiding a frustrated and asymmetric frozen sign pattern. The quenched disorder from this pattern then produces the ripples, or pinning fields, which perturb the local minima away from exact idealized values.
The theory model I've been carrying around in the background (that I won't explain in this post) actually predicts that the random perturbations in each cluster are Gaussian. In the picture, you can see that this isn't quite the case empirically. While the central cluster does look like a non-uniform Gaussian, the two corner clusters look like they have some more structure – perhaps a further division into two types of points. Likely this is a combination of some subtle structure beyond randomness that is missed in the naive theory model, and the fact that the model is quite small, probably at the very tail end of applicability of the nice high-dimensional compression properties given by compressed sensing theory. (The model here has hidden dimension h=1024 and input, i.e. feature, dimension d=3072, just a factor of 3 bigger.)
Miscellanea: pretty pictures and future questions
The interesting non-Gaussian structure in the clusters hints at a depth of phenomena in this simple context whose surface we are only beginning to probe.
As I mentioned, the specific architecture I am looking at is designed theory-first: there is a mean-field-inspired picture for getting predictions about the weights of these models which is much simpler in some settings and regimes. Here I am taking a regime designed to be particularly amenable to theory.
One can ask what would happen if we were to train a sparse denoising model as experimentalists, or at least in some "generic ML way". To explore this I ran some experiments where we take a GELU activation and train on a softmax loss (the most natural choice, since as I explained, a straight MSE loss finds degenerate solutions here).
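A sketch of this variant, with the caveat that "softmax loss" is my reading of it as cross-entropy against the softmax of the output, with the k-hot target normalized to a probability distribution; the exact loss used in these experiments may differ.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class GeluDenoiser(nn.Module):
    """Same two-layer shape as before, but with a "generic ML" GELU activation."""
    def __init__(self, d, h):
        super().__init__()
        self.encoder = nn.Linear(d, h, bias=False)
        self.decoder = nn.Linear(h, d, bias=False)

    def forward(self, x):
        return self.decoder(F.gelu(self.encoder(x)))

def softmax_reconstruction_loss(pred, x_star):
    # Cross-entropy between the softmax of the output and the normalized k-hot target.
    target = x_star / x_star.sum(dim=1, keepdim=True)
    return -(target * F.log_softmax(pred, dim=1)).sum(dim=1).mean()
```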
The model ends up learning the following weight configuration:
Here there is still some heuristic theory of what is happening, which models it as a mix of Gaussian blobs. However, in this case it is no longer reasonable to expect the weight pairs to be chosen independently from the blobs in this two-dimensional picture. The neuron pairs in the small partially-obscured blob in the lower left are actually coupled, or correlated, with the "cloud" blob in the upper right corner, so the "true blobs" are 4-dimensional. The fancy way of saying this is that the mean field theory here becomes a 2-particle mean field theory. Of course the idealized mean field picture here again is at best a heuristic approximation. For example, we see extra emergent stratification structure in the different components, which probably reflects some additional structure that the mean field approximation fails to take into account. Similarly to the setting from the rest of the post, I expect the empirical outputs to agree more closely with the idealized 2-particle mean field in the limit of higher dimensions.
Aesthetically, I found it fun to see how the model gets trained: to me it looks like a cannon shooting at a cloud. You can see this in this imgur link.
You get interesting diagrams when you run the experiments in settings that are just on the boundary between superposition and no superposition. Here it seems that while there may be a small amount of frustration, there is also some sophisticated regular structure that I don't know how to model. For example, here is an image with embedding dimension 768 and hidden dimension 512, also with GELU activation and softmax loss:
As next steps for this project, I am interested in looking at settings with features of different importance, size, or noise, or with additional correlations. I think this can be an interesting source of experimental settings with more than two relevant scales. I've run a couple of experiments with different feature types like the one below, but haven't come up with a setting that is both experimentally and theoretically nice.
I am also very interested in looking at more general sparse computation settings and settings with more than two layers, where, apart from denoising, a model learns a boolean task in the presence of superposition. Some promising experiments in this direction have been done in Adler-Shavit.
My team at PIBBSS and I are always looking for collaborators and informal collaborations. If any of these directions appeal to you and you are interested in exploring them, please feel free to reach out to me on lesswrong or at dmitry@pibbss.ai.
For a possible empirical example of such incomplete decoupling in LLMs, see Cloud et al. on subliminal learning, where aspects of finetuning on one type of data get "frozen in" and can be recovered from behavior on seemingly unrelated data.
As is often the case in applied fields without an established unified theory, one must stress that this is a descriptive rather than a prescriptive statement: the sparse structure of data is an observation that the ultimate interpretation of transformers must explain, but it does not claim or imply that sparsity is a complete or sufficient explanation of the data. In fact there is increasing consensus that techniques flowing out of sparsity are insufficient to interpret neural nets by themselves – more structures and theoretical primitives are needed, and sparsity could be either one of the fundamental structures, or alternatively purely a consequence of more fundamental phenomena.
This algorithm is guaranteed to converge to the tempered distribution eventually, but in general this is only guaranteed after exponential time. Nevertheless the algorithm works reasonably well in toy models (for example, it converges to the correct distribution in the case of modular addition with MSE loss, where the Bayesian limit is known via mean field theory). In empirical work by the Timaeus group, Langevin descent away from a local minimum is used very successfully in performing thermodynamic analyses of the local heuristic bias of various models.
There is a theoretically well-understood way to deduce the optimal decoder values from the matrix of optimal encoder values in our context, so it is enough to focus on the statistical behavior of encoder weights here, especially on this level of informal presentation.
This is slightly simplified; in fact, we also need to condition on the property that all the nonzero values in a single row of the embedding matrix have the same sign; equivalently, we choose "zero or nonzero" independently for each matrix coordinate, and choose signs independently for each row.
This post is an informal preliminary writeup of a project that I've been working on with friends and collaborators. Some of the theory was developed jointly with Zohar Ringel, and we hope to write a more formal paper on it this year. Experiments are joint with Lucas Teixeira (and also an extensive use of llm assistants). This work is part of the research agenda we have been running with Lucas Teixeira and Lauren Greenspan at my organization Principles of Intelligence (formerly PIBBSS), and Lauren helped in writing this post.
Introduction
Pre-introduction
I have had trouble writing an introduction to this post. It combines three aspects of interpretability that I like and have thought about in the last two years:
My inclination is to include all of these in this draft. I am tempted to write a very general introduction which tries to introduce all three phenomena and explain how they are linked together via the experimental context I'll present here. However, this would be confusing and hard to read – people tend to like my technical writing much better when I take to heart the advice: "you get about five words".
So for this post I will focus on one aspect of the work, which is related to the physical notion of frustration and the effect this interesting physical structure has on the rest of the system via a source of irreducible noise. Thus if I were to summarize a five-work takeaway I want to explain in this post, it is this:
In some ways this is obvious: real data is noisy and has lots of randomness. But the present work shows that noise is even more fundamental than randomness from data. Even in settings where the data distribution is highly symmetric and structured and the neural net is trained on infinite data, interference in the network itself (caused by a complex system responding to conflicting incentives) leads to unavoidable fluctuations in the loss landscape. In some sense this is good news for interpretability: physical systems with irreducible noise will often tend towards having more independence between structures at different scales and can be more amenable to causal decoupling and renormalization-based analyses. Moreover, if we can understand the structure and source of the noise, we can incorporate it into interpretability analyses.
In this case the small-scale noise in the landscape is induced by an emergent property of the system at a larger scale. Namely, early in training the weights of this system develop a coarse large-scale structure similar to the discrete up/down spins of a magnet. For reasons related to interference in data with superposition, the discrete structure is frustrated and therefore in some sense forced to be asymmetric and random. The frustrated structures on a large scale interact with small fluctuations of the system on a small scale, and lead to microscopic random ripples that follow a similar physics to the impurities in a semiconducting metal.
The task
In this work we analyze a task that originally comes from our paper on Computation in Superposition (CiS).
Sparse data with superposition was originally studied in the field of compressed sensing and dictionary learning. Analyzing modern transformer networks using techniques from this field (especially sparse autoencoders) has led to a paradigm shift in interpretability over the last 3 years. Researchers have found that the hidden-layer activation data in transformers is mathematically well-described by sparse superpositional combinations of so-called "dictionary features" or SAE features; moreover these features have nice interpretability properties, and variants of SAE's have provided some of the best unsupervised interpretability techniques used in the last few years[2].
In our CiS paper we theoretically probe the question of whether the efficient linear encoding of sparse data provided by superposition can extend to an efficient compression of a computation involving sparse data (and we find that indeed this is possible, at least in theory). A key part of our analysis hinges on managing a certain error source called interference. It is equivalent to asking whether a neural net can adequately reconstruct a noisy sparse input in Rd (think of a 2-hot vector like (0,0,1,0,0,0,1,0), with added Gaussian noise while passing it through a narrow hidden layer Rh with h<d. The narrowness h<d forces the neural net to use superposition in the hidden layer (here the input dimension d is the feature dimension. In our CiS paper we use different terminology and the feature dimension is called m). In the presence of superposition, it is actually impossible to get perfect reconstruction, so training with this target is equivalent to asking "how close to the theoretically minimal reconstruction error will be learned by a real neural net".
In the CiS paper we show that a reconstruction that is "close enough" to optimal for our purposes is possible. We do this by "manually" constructing a 2-layer neural net (with one layer of nonlinearity) that achieves a reasonable interference error.
This raises the interesting question of whether an actual 2-layer network trained on this task would learn to manage interference, and what algorithm it would learn if it did; we do not treat this in our CiS paper. It is this sparse denoising task that I am interested in for the present work. The task is similar in spirit to Christ Olah's Toy Models of Superposition (TMS) setting.
Below, I explain the details of the task and the architecture. For convenience of readers who are less interested in details and want to get to the general loss-landscape and physical aspects of the work, the rest of the section is collapsible.
The details (reconstructing sparse inputs via superposition in a hidden layer)
Our model differs from Chris Olah's TMS task in two essential ways. First, the nonlinearity is applied in the hidden layer (as is more conventional) rather than in the last layer as in that work. Second and crucially, while the TMS task's hidden layer (where superposition occurs) is 2-dimensional, we are interested in contexts where the hidden layer is itself high-dimensional (though not as large as the input and output layers), as this is the context where the compressibility benefit from compressed sensing really shine. Less essentially, our data has noise (rather than being a straight autoencoder as in TMS), and the loss is modified to avoid a degeneracy of MSE loss in this setting.
The task: details
At the end of the day the task we want our neural net to try to learn is as follows.
Input: our input in Rd is a sparse boolean vector plus noise, x∗+ξ (here in experiments, x∗ is a 2-hot or 3-hot vector, so something like x=(0,0,1,0,0,0,1,0) and ξ is Gaussian noise which is standard deviation ≪1 on a coordinatewise level). We think of Rd as the feature basis.
Output: the "denoised" vector y(x)=x∗. As a function, this is equivalent to learning the coordinatewise "round to nearest integer" function.
Architecture: we use an architecture that forces superposition, with a single narrow hidden layer Rh, where h≪d. Otherwise the architecture is a conventional neural net architecture with a coordinatewise nonlinearity at the middle layer, so the layers are:
Linear(d→h)∣Nonlinearityh∣Linear(h→d).
Fine print: There are subtleties that I don't want to spend too much time on, about what loss and nonlinearity I'm using in the task, and what specific range of values I'm using for the hidden dimension h and the feature dimension d. In experiments I use a specific class of choices that make the theory easier and that let us see very nice theory-experimental agreement. In particular for loss, a usual MSE loss has bad denoising properties since it encourages returning zero in the regime of interest (this is because sparsity means that most of my "denoised coordinates" should be 0, and removing all noise from these by returning 0 is essentially optimal from a "straight MSE loss" perspective, though it's useless from the perspective of sparse computation). The loss I actually use is re-balanced to discourage this. Also for theory reasons, things are cleaner if we don't include biases and the nonlinearity function has some specific exotic properties (specifically, I use a function which is analytic, odd, has f′(0)=0, and is bounded. The specific function I use is ϕ(x)=tanh(x)3). The theory is also nicer in a specific (though extensive) range of values for the feature dimension d and the hidden dimension h<d where superposition occur.
Much of the theory can be extended to apply in a more general setting (a big difference will be that in general, the main theoretical tool we use would shift from single-particle mean field theory to so-called multi-particle mean field theory).
The setting that would be perhaps most natural for modeling realistic NNs in practice is to use a ReLu (or another "standard" activation function) and a softmax error loss for the reconstruction. While theoretically less tractable and experimentally messier, I expect that some of the key heuristic behaviors would be the same in this case. In particular both theory and experiment recovers in this setting the discrete "frustration" phenomenon in the signs of the weights.
Results
The discussion here will orient around the following figure. Both sides represent a trained denoising model as above, but only the model on the right has superposition. The difference in the resulting weights is a visual representation of certain phenomena associated with frustrated physical systems.
Here each image represents the set of weight coordinates in an instance of the model trained to completion on (effectively) infinite data. On the left we have trained a model without superposition: the hidden dimension is larger than the input dimension. On the right, we are training a model with superposition: the model has input dimension larger than the hidden dimension (by about a factor of 3). The two models are chosen in such a way that the total number of weights (i.e., the total number of dots in the diagram) is the same in both.
Both images represent a (single) local minimum in the weight landscape, and they have a similar coarse structure, comprising a set of several clusters of similar shape (if you zoom in, you can see that the decoder weights end up scaled differently and the encoder weights are distributed differently between the bumps – we can predict these differences from a theoretical model). But we also see a very stark visual difference between the two sets of weights. The model on the left seems to converge to a highly symmetric set of weights with just 5 allowable coefficient values in the encoder and decoder (up to small noise from numerical imperfections). But despite having the same symmetries a priori, the weights of the model on the right have only a coarse regularity: locally within clusters they look essentially random. In fact in a theoretical model that is guiding my experiments here, we can predict that in an appropriate asymptotic limit, the perturbations around the cluster centers will converge to Gaussian noise.
This noisy structure is a neural-net consequence of a frustrated equilibrium associated with superposition. I'll explain this in the next section.
The physics of frustration
Frustration is one of several ways in statistical physics that can lead to quenched randomness. By its very nature, statistical physics deals with phenomena with randomness and noise. The sources of randomness can roughly be separated into: on the one hand, fluctuation phenomena experienced by systems at positive temperature (think of the random distribution of gas particles in a box). And on the other hand, quenched or frozen phenomena, which subject the physical forces defining the system of interest to random perturbations or fields that are stable at relevant time or energy scales. As an example, think of the frozen random nature of the ocean floor. While it changes over millennia, physicists model the peaks and troughs of the bottom as a quenched random structure which is fixed when studying its effect on wave statistics.
In this work I'm most interested in a special kind of quenched randomness. Quenched structure typically comes about in two ways.
Since learning combines random and structured phenomena, it is frequently studied as a statistical system. In the next sections I'll explain how the three sources of thermodynamic randomness we discussed: tempering, externally imposed quenched randomness, and emergent frozen disorder, relate to the theory of neural nets and how this is related to the notion of frustration and the present experiment.
Tempering and quenched disorder in ML and interpretability
Existing interpretability theory deals frequently with both fluctuation randomness associated with heat (tempering) and with externally imposed frozen randomness from data, and the two are frequently studied together. However the spontaneously emerging/ self-organizing form of frozen randomness is less frequently studied (at least to the best of my knowledge).
Fluctuations typically appear in Bayesian learning theory. In many theoretical contexts (especially ones concerned with heuristic bias of NN learning), it is useful to replace the exact loss-minimization goal of a learning algorithm by a tempered setting where we model the learner as converging to a random distribution on the loss landscape where weights with low loss are more likely than weights with higher loss, according to a Boltzmann statistical law. (This is heuristically similar to randomly sampling a "low-loss valley" in the landscape, i.e. randomly selecting a weight with loss <ϵ away from minimum.) The amount of "randomness" here is controlled by a parameter called the temperature (it is analogous to the loss cutoff ϵ in the loss valley picture). Tempered learners can be designed empirically. There is a learning algorithm that is known to converge to this tempered distribution, which is (stochastic) Langevin descent, or SGLD. Here the model learns via a biased random walk with a bias towards low-loss regions[3].
Externally imposed frozen disorder typically appears in neural net work through the data distribution. When training a model on finite data, the specific datapoints used are assumed to be drawn randomly from a large data distribution. Thus difference between what a model would learn at infinite data vs. what is learned at a small random training set bakes in a source of noise called the data sample noise.
These two sources of noise are known to be related. In particular in contexts where both are present (Bayesian / Langevin learning on finite data), it is often known roughly which source of randomness dominates. For example in work by Watanabe, it is shown that above a certain scale of tempering (very roughly proportional to the inverse number of datapoints), data randomness becomes small compared to the random noise from tempering, and so the statistics becomes independent of the size of the training set.
In a similar setting, Howard et al. use a standard physical tool for studying quenched disorder, called the replica method, to more carefully study a learning system with both data noise and tempered randomness. (The work is done in a very simple linear regression setting, where exact theoretical predictions can be made using renormalization theory. The result finds a similar "transitional scale" to Watanabe's critical temperature range, with a more precise study of exactly how this transition happens.)
The random cluster structure seen in our experiments is a consequence of a new form of frozen disorder, this time self-organizing (analogous to the self-organizing domains in magnetic iron). In our case the self-organizing structure is a consequence of a particularly interesting and unusual type of chaotic behavior of physical systems, which is called frustration.
Frustration in physics
We have been talking about thermodynamics and physical systems without defining terms. Let's fix this. For me, a (statistical) physical system is a high-dimensional data distribution with a notion of energy. A choice of the data is called a state, and its temperature is a number (that "wants to be small" in a statistical sense). Formally, the probability of a state is determined by a Boltzmann distribution at some temperature T. Here (without getting too in the weeds), the temperature determines "how much" the probability distribution is biased to prefer low-energy states. Most importantly for us, at temperature 0 this bias is maximized, and all the probability distribution is concentrated on states that minimize the energy.
A "prototypical" example is the ferromagnetic Ising model, where the state is a collection of spins ±1 indexed by a lattice. We can think of the state as an N×N matrix x all of whose coordinates are ±1. Energy should be a function of such vectors. The energy E(x) of an Ising state is defined to be a certain negative integer. It is minus the count of the number of "neighbor pairs" (edges in the lattice) where both neighbors point the same direction (so are either both up or both down). Since models want to minimize energy, the lowest-energy (and thus highest-probability) state is the one where all spins align.
A related model, called the antiferromagnetic Ising model, is defined similarly but counts the number of anti-aligned pairs (up to a constant, this is just minus the ferromagnetic energy). A minimizer here is the checkerboard pattern of spins:
Energy minimizers of a thermodynamic system are called ground states. At temperature T = 0, the probability distribution is concentrated on ground states only. For the ferromagnetic Ising model, this is a probability distribution with 50% probability on the "spin up" ground state (all spins +1) and 50% on the spin down state. The antiferromagnetic model similarly has two ground states related by a sign flip. The "default" expectation is that[4] there are not very many ground states: that in some sense they form a low-dimensional space of structured distributions that is either unique or at worst controlled by a low-dimensional "macroscopic" parameter.
This expectation of a low-dimensional set of highly structured ground states is sometimes frustrated[5] in an interesting way. In frustrated physical systems, discrete parameters have conflicting energetic preferences that cannot be simultaneously satisfied and cause interesting emergent disorder. The classic example is an anti-ferromagnet with triangular bonds, i.e., a triangular lattice of spins where each pair wants to anti-align.
There is no way to satisfy all local energetic preferences simultaneously, and in fact there is no nice "structured" ground state. Instead the distribution of ground states looks chaotic and high-dimensional. Furthermore it has significant entropy, similar to a "hot" system that is pushed away from its ground state by positive temperature.
In our sparse denoising model, a similar frustrated energy landscape emerges in the presence of superposition. Here the analog of a physical state is a weight configuration, i.e. a pair of matrices encoderij∈Mat(h,d) and decoderji∈Mat(d,h). The analog of energy is the loss function (always understood in the context of infinite data). In the presence of superposition, part of what the loss is trying to do can be summarized as "trying to orthogonally embed many vectors into a lower-dimensional space" (in our case, it's embedding d "feature" vectors into an h-dimensional hidden layer. Superposition means that d>h). Since exact orthogonality is impossible for more than basis-many vectors, there is a potential for frustrated structure. However a priori there is a mismatch: frustrated structures are discrete statements about ground states or minima, but our system is continuous. It may seem hard to say nontrivial things about the discrete structure on the ground states, i.e. exact minima.
But wait: the model we're studying is sort of discrete, and the evidence is staring us in the eyes. Here is a snapshot of the model's training (I chose a pretty snapshot of a smaller model, but you can also look at the late-training images from before).
We see that the weight pairs naturally decompose into three discrete clusters. In particular if we just look at the embedding weights encoder[i,j], we see that each one is either in a larger cluster around 0 or in a smaller cluster around 1.5 (more generally, the "outside" clusters of embedding weights will be centered at some value slightly larger than 1). (Here we see this experimentally, but there is also a theoretical model for predicting the weight behavior based on mean field theory methods that I have in the background when choosing regimes and experiments, that has good agreement with the experiments in this case.)
We can therefore understand the "a posteriori" statistics governing the trained weights wij=encoder[i,j] to be given by the following process[6].
In this specific system, there usually is a unique local minimum in each such basin (note that this will not be the case in similar models with more complicated geometric structure). Thus the local minima of the system that we care about are basin minima associated to a choice of signs s_{ij}. It remains to understand which choices of sign matrices are ground states, i.e. approximate global minima. We now have two cases:
Remark. The "random ground state" model for superpositional systems is simplistic: the true vacua likely have some sneaky higher-level structure that we are failing to track, which makes some minima slightly better than others, though on scales that, as we'll see, are below anything we can control in this context. At the end of the day, the structure we really care about is what gets learned by the learning algorithm. Here we can do various empirical analyses on the weights (e.g. use random-graph measurements, or compare the loss in a random basin to the learned loss) to see that an iid random configuration explains the learned minima well.
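As an illustration, here is a hedged sketch of one such check: compare simple support-pattern ("random graph") statistics of the trained encoder against an iid mask of the same density. The file path, the 0.5 cut between the zero and outer clusters, and the particular statistics are my choices, not necessarily the ones used in the actual analysis.

```python
# Sketch: does an iid random support pattern explain the learned encoder?
import numpy as np

encoder = np.load("encoder.npy")   # hypothetical path to saved trained weights, shape (h, d)

def support_stats(mask):
    """Row degrees and pairwise row overlaps of a 0/1 support pattern."""
    degrees = mask.sum(axis=1)
    overlaps = mask @ mask.T
    off_diag = overlaps[~np.eye(len(mask), dtype=bool)]
    return degrees.mean(), degrees.std(), off_diag.mean(), off_diag.std()

learned_mask = (np.abs(encoder) > 0.5).astype(float)   # 0.5 separates the clusters
p = learned_mask.mean()                                # empirical density of nonzeros

rng = np.random.default_rng(0)
iid_mask = (rng.random(learned_mask.shape) < p).astype(float)

for name, m in [("learned", learned_mask), ("iid", iid_mask)]:
    print(name, support_stats(m))
# If the "random ground state" picture is right, the two rows of numbers
# should agree up to sampling noise.
```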
Continuous fluctuations from frozen structure
To summarize, we have a pair of predictions for the discrete sign structure associated to the context with superposition (h < d) and without superposition (h > d). Now we are actually interested in the continuous weight values, which we only know lie in a basin associated to a discrete centroid. We can think of the resulting system in two equivalent ways: either as a renormalization setting where the small-scale continuous structure perturbs the energy of the discrete configuration, or as a coupled system where the discrete "frozen" structure couples to a smaller-scale "microscopic" system as a background field. The second point of view is easier to think about, and puts us back in the context we discussed in the physics setting, where some predetermined random structure (in this case appearing in an emergent way from the same model) is treated as a "frozen" source of randomness in the system.
Unpacking this, we can cleanly see the difference between the frustrated / superpositional and the un-frustrated / single-vacuum settings. In the un-frustrated case, the discrete structure is unique and symmetric. This implies that the continuous perturbations are also symmetric: once we factor in the "macroscopic" information of how many neurons process each feature, the weight coordinates are determined uniquely and symmetrically, and the "cluster structure" simplifies to a set of discrete centroids.
In the frustrated case, the random structure is not symmetric. Frustration means that already at the level of discrete sign choices we can find no exact regularities, and some pairs of rows/columns overlap more than others. This asymmetry couples to the microscopic system of perturbations away from the cluster centroids, and implies that the same kind of asymmetric structure will be observed there. More precisely, since we can model the sign randomness as iid discrete noise, its effect on the microscopic fluctuations of local minima can again be nicely modeled via the central limit theorem: continuous but "frozen" ripples that randomly perturb the coordinates of a basin's local minimum in a statistically predictable fashion. This is a phenomenon that appears again and again in physics: frustrated or random discrete structure at a large scale generates random continuous perturbations to local minima (and any other geometric structures) at some related smaller scale. The resulting ripples in the energy landscape are sometimes called "pinning fields".
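Here is a toy numpy illustration of this central-limit mechanism (the counts and couplings are arbitrary choices of mine; the point is only that a frozen iid sign pattern produces a quenched, approximately Gaussian ripple):

```python
# Toy model of the "frozen ripple": each coordinate's shift away from its
# cluster centroid is a sum of many small interference terms whose signs
# come from the frozen discrete structure.
import numpy as np

n_coords, n_interferers = 2000, 500
coupling = 1.0 / np.sqrt(n_interferers)

rng = np.random.default_rng(1)
frozen_signs = rng.choice([-1.0, +1.0], size=(n_coords, n_interferers))

# The ripple is deterministic given the frozen sign pattern ("quenched"),
# but looks like iid Gaussian noise across coordinates by the CLT.
ripple = coupling * frozen_signs.sum(axis=1)

print("mean, std of ripple:", ripple.mean(), ripple.std())             # ~0, ~1
print("excess kurtosis:",
      ((ripple - ripple.mean()) ** 4).mean() / ripple.var() ** 2 - 3)  # ~0 for Gaussian
# Re-computing the ripple from the same frozen_signs reproduces it exactly;
# only re-drawing the discrete structure changes it.
```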
The random pinning field perturbations explain the surprising noisy structure in the above plot of the weights of the local minimum of our system (analogous to a vacuum in physics). Since we have flattened a pair of matrices into pairs of real numbers, the ordered and symmetric-looking cluster structure here is hiding a frustrated and asymmetric frozen sign pattern. The quenched disorder from this pattern then produces the ripples, or pinning fields, which perturb the local minima away from exact idealized values.
The theory model I've been carrying around in the background (that I won't explain in this post) actually predicts that the random perturbations in each cluster are Gaussian. In the picture, you can see that this isn't empirically the case. While the central cluster does look like a non-uniform Gaussian, the two corner clusters look like they have some more structure – perhaps a further division into two types of points. Likely this is a combination of some subtle structure beyond randomness that the naive theory model misses, and the fact that the model is quite small, probably at the very tail end of applicability of the nice high-dimensional compression properties given by compressed sensing theory. (The model here has hidden dimension h = 1024 and input, i.e. feature, dimension d = 3072, just a factor of 3 larger.)
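For readers who want to poke at this themselves, here is a rough sketch of how one might quantify the visual impression of (non-)Gaussianity per cluster; the weight file path and the 0.5 cluster cut are again hypothetical choices of mine.

```python
# Sketch: test each weight cluster of the trained encoder for Gaussianity.
import numpy as np
from scipy import stats

encoder = np.load("encoder.npy")    # hypothetical path to saved trained weights
w = encoder.ravel()
clusters = {
    "central": w[np.abs(w) <= 0.5],
    # Fold signs so rows with negative outer centroids line up with positive ones.
    "outer": np.abs(w[np.abs(w) > 0.5]),
}

for name, vals in clusters.items():
    # Skewness and excess kurtosis should both be ~0 for a Gaussian cluster.
    print(name,
          "skew = %.3f" % stats.skew(vals),
          "excess kurtosis = %.3f" % stats.kurtosis(vals),
          "normaltest p = %.2g" % stats.normaltest(vals).pvalue)
```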
Miscellanea: pretty pictures and future questions
The interesting non-Gaussian structure in the clusters hints at a depth of phenomena in this simple context whose surface we are only beginning to probe.
As I mentioned, the specific architecture I am looking at is designed theory-first: there is a mean field-inspired picture for getting predictions about the weights of these models which is much simpler in some settings and regimes, and here I am working in a regime designed to be particularly amenable to theory.
One can ask what would happen if we were to train a sparse denoising model as experimentalists, or at least in some "generic ML way". To explore this, I ran some experiments where we take a GeLU activation and train on a softmax loss (the most natural choice, since, as I explained, a straight MSE loss finds degenerate solutions here).
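For concreteness, here is a minimal PyTorch sketch of what such a "generic ML" training setup could look like. The sparsity level, the noise model, the batch size, and the exact placement of the GeLU and of the softmax loss are assumptions of mine rather than a description of the actual experiments.

```python
# A "generic ML" sparse denoiser: linear encoder, GeLU, linear decoder,
# trained with a softmax (cross-entropy) loss on freshly sampled sparse data.
import torch
import torch.nn.functional as F

d, h = 3072, 1024                  # feature and hidden dimensions (d > h: superposition)
sparsity, noise_scale = 0.01, 0.1  # assumed data parameters

encoder = torch.nn.Linear(d, h, bias=False)
decoder = torch.nn.Linear(h, d, bias=False)
opt = torch.optim.Adam(list(encoder.parameters()) + list(decoder.parameters()), lr=1e-3)

for step in range(10_000):
    # Fresh sparse data every step stands in for the infinite-data limit.
    clean = (torch.rand(512, d) < sparsity).float()
    noisy = clean + noise_scale * torch.randn_like(clean)

    logits = decoder(F.gelu(encoder(noisy)))
    # "Softmax loss": treat the active features of each sample as a
    # (soft) classification target over the d feature slots.
    targets = clean / clean.sum(dim=1, keepdim=True).clamp(min=1.0)
    loss = F.cross_entropy(logits, targets)

    opt.zero_grad()
    loss.backward()
    opt.step()
```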
The model ends up learning the following weight configuration:
Here there is still some heuristic theory of what is happening, which models it as a mixture of Gaussian blobs. However, in this case it is no longer reasonable to expect the weight pairs to be chosen independently from the blobs in this two-dimensional picture. The neuron pairs in the small, partially obscured blob in the lower left are actually coupled to, i.e. correlated with, the "cloud" blob in the upper right corner, so the "true blobs" are 4-dimensional. The fancy way of saying this is that the mean field theory here becomes a 2-particle mean field theory. Of course the idealized mean field picture is again at best a heuristic approximation. For example, we see extra emergent stratification in the different components, which probably reflects additional structure that the mean field approximation fails to take into account. As in the setting from the rest of the post, I expect the empirical outputs to agree more closely with the idealized 2-particle mean field in the limit of higher dimensions.
Aesthetically, I found it fun to see how the model gets trained: to me it looks like a cannon shooting at a cloud. You can see this in this imgur link.
You get interesting diagrams when you run the experiments in settings that are just on the boundary between superposition and no superposition. Here it seems that while there may be a small amount of frustration, there is also some sophisticated regular structure that I don't know how to model. For example, here is an image with input (embedding) dimension d = 768 and hidden dimension h = 512, also with GeLU activation and softmax loss:
You can see these experiments in https://github.com/mvaintrob/clean-denoising-experiments.
As next steps for this project, I am interested in looking at settings with features of different importance, size, or noise, or with additional correlations. I think this can be an interesting source of experimental settings with more than two relevant scales. I've run a couple of experiments with different feature types like the one below, but haven't come up with a setting that is both experimentally and theoretically nice.
I am also very interested in looking at more general sparse computation settings and settings with more than two layers, where, apart from denoising, a model with superposition learns a boolean task. Some promising experiments in this direction have been done in Adler-Shavit.
My team at PIBBSS and I are always looking for collaborators and informal collaborations. If any of these directions appeal to you and you are interested in exploring them, please feel free to reach out to me on LessWrong or at dmitry@pibbss.ai.
For a possible empirical example of such incomplete decoupling in LLMs, see Cloud et al. on subliminal learning, where aspects of finetuning on one type of data get "frozen in" and can be recovered from behavior on seemingly unrelated data.
As is often the case in applied fields without an established unified theory, one must stress that this is a descriptive rather than a prescriptive statement: the sparse structure of data is an observation that the ultimate interpretation of transformers must explain, but it does not claim or imply that sparsity is a complete or sufficient explanation of the data. In fact there is increasing consensus that techniques flowing out of sparsity are insufficient to interpret neural nets by themselves – more structures and theoretical primitives are needed, and sparsity could either be one of the fundamental structures or, alternatively, purely a consequence of more fundamental phenomena.
This algorithm is guaranteed to converge to the tempered distribution eventually, but in general only after exponential time. Nevertheless, the algorithm works reasonably well in toy models (for example, it converges to the correct distribution in the case of modular addition with MSE loss, where the Bayesian limit is known via mean field theory). In empirical work by the Timaeus group, Langevin descent away from a local minimum is used very successfully to perform thermodynamic analyses of the local heuristic bias of various models.
Up to symmetries of the system
see what I did there
There is a theoretically well-understood way to deduce the optimal decoder values from the matrix of optimal encoder values in our context, so it is enough to focus on the statistical behavior of encoder weights here, especially on this level of informal presentation.
Up to a small error that is much smaller than the other scales of interest, and which can therefore be ignored.
This is slightly simplified; in fact, we also need to condition on the property that all the nonzero values in a single row of the embedding matrix have the same sign; equivalently, we choose "zero or nonzero" independently for each matrix coordinate, and choose signs independently for each row.