Crossposted from my personal blog.

Epistemic status: Pretty speculative, but there is a surprising amount of circumstantial evidence.

I have been increasingly thinking about NN representations and slowly coming to the conclusion that they are (almost) completely secretly linear inside[1]. This means that, theoretically, if we can understand their directions, we can very easily exert very powerful control over the internal representations, as well as compose and reason about them in a straightforward way. Finding the linear direction for a given representation would allow us to arbitrarily amplify or remove it and interpolate along it as desired. We could then also directly 'mix' it with other representations. Measuring these directions during inference would let us detect the degree of each feature that the network assigns to a given input. For instance, this might let us create internal 'lie detectors' (which there is some progress towards) that can tell whether the model is telling the truth or being deceptive. While nothing is super definitive (and clearly networks are not 100% linear), I think there is a large amount of fairly compelling circumstantial evidence for this position.

Evidence for this:

  1. All the work from way back when about interpolating through VAE/GAN latent space. I.e. in the latent space of a VAE on CelebA there are natural 'directions' for recognizable semantic features like 'wearing sunglasses' and 'long hair', and linear interpolations along these directions produce highly recognizable images.
  2. Rank 1 or low rank editing techniques such as ROME work so well (not perfectly but pretty well). These are effectively just emphasizing a linear direction in the weights.
  3. You can apparently add and mix LoRAs and it works about how you would expect.
  4. You can merge totally different models. People in the Stable Diffusion community literally merge model weights additively with a weighted sum and it works!
  5. Logit lens works.
  6. SVD directions.
  7. Linear probing definitely gives you a fair amount of signal.
  8. Linear mode connectivity and Git Re-Basin.
  9. Colin Burns' unsupervised linear probing method works even for semantic features like 'truth'.
  10. You can merge together different models finetuned from the same initialization.
  11. You can do a moving average over model checkpoints and this improves performance and is better than any individual checkpoint!
  12. The fact that linear methods work pretty tolerably well in neuroscience.
  13. Various naive diff based editing and control techniques work at all.
  14. General linear transformations between models and modalities.
  15. You can often remove specific concepts from the model by erasing a linear subspace.
  16. Task vectors (and in general things like finetune diffs being composable linearly)
  17. People keep finding linear representations inside of neural networks when doing interpretability or just randomly.
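
Several of the items above (linear probing, diff-based editing, erasing a linear subspace) share one recipe: find a direction, then read, push or project along it. The following is a minimal sketch of that recipe on synthetic data, not any particular paper's method. The "activations" are random vectors with a planted direction, the width `d` and the helper names `measure`, `steer` and `erase` are hypothetical, and the difference-of-means estimator is just one simple choice of direction finder.

```python
import numpy as np

rng = np.random.default_rng(0)
d = 64                                   # hypothetical activation width

# Planted ground-truth feature direction (an assumption of this toy setup).
true_dir = rng.normal(size=d)
true_dir /= np.linalg.norm(true_dir)

# Synthetic "activations": the concept shifts examples along one direction.
neg = rng.normal(size=(500, d))
pos = rng.normal(size=(500, d)) + 3.0 * true_dir

# Find the direction: difference of class means, normalised to unit length.
direction = pos.mean(axis=0) - neg.mean(axis=0)
direction /= np.linalg.norm(direction)

def measure(acts):
    """Read off how strongly each activation expresses the feature."""
    return acts @ direction

def steer(acts, alpha):
    """Amplify (alpha > 0) or suppress (alpha < 0) the feature."""
    return acts + alpha * direction

def erase(acts):
    """Project out the direction (rank-1 linear erasure)."""
    return acts - np.outer(acts @ direction, direction)

print("cosine with planted direction:", float(direction @ true_dir))
print("mean readout, pos vs neg:", measure(pos).mean(), measure(neg).mean())
print("max |readout| after erasure:", np.abs(measure(erase(pos))).max())
```

On real networks the activations would come from a chosen layer, and whether steering or erasure changes behaviour as intended is an empirical question that this toy cannot answer.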

If this is true, then we should be able to achieve quite a high level of control and understanding of NNs solely by straightforward linear methods and interventions. This would mean that deep networks might end up being pretty understandable and controllable artefacts in the near future. At this moment, we just have not yet found the right levers (or rather, lots of existing work does show this but hasn't really been normalized or applied at scale for alignment). Linear-ish network representations are a best-case scenario for both interpretability and control.

For a mechanistic, circuits-level understanding, there is still the problem of superposition of the linear representations. However, if the representations are indeed mostly linear, then once superposition is solved there seem to be few other obstacles in front of a complete mechanistic understanding of the network. Moreover, superposition is not even a problem for black-box linear methods for controlling and manipulating features, where the optimiser handles the superposition for you.

This hypothesis also gets at a set of intuitions I've slowly been developing. Basically, almost all of alignment thinking assumes that NNs are bad -- 'giant inscrutable matrices' -- and success looks like fighting against the NN. This can either be through minimizing the amount of the system that is NN-based, surrounding the NN with monitoring and various other schemes, or by interpreting their internals and trying to find human-understandable circuits inside. I feel like this approach is misguided and makes the problem way more difficult than it needs to be. Instead we should be working with the NNs. Actual NNs appear to be very far from the maximally bad case and appear to possess a number of very convenient properties -- including this seeming linearity -- that we should be exploiting rather than ignoring. Especially if this hypothesis is true, then there is just so much control we can get if we just apply black-boxish methods to the right levers. If there is a prevailing linearity, then this should make a number of interpretability methods much more straightforward as well. Solving superposition might just resolve a large degree of the entire problem of interpretability. We may actually be surprisingly close to success at automated interpretability.

Why might networks actually be linear-ish?

1.  Natural abstractions hypothesis. Most abstractions are naturally linear and compositional in some sense (why?).

2.  NNs or SGD have strong Occam's razor priors towards simplicity, and linear = simple.

3. Linear and compositional representations are very good for generalisation and compression which becomes increasingly important for underfit networks on large and highly varied natural datasets. This is similar in spirit to the way that biology evolves to be modular.

4.  Architectural evolution. Strongly nonlinear functions are extremely hard to learn with SGD due to poor conditioning. Linear functions are naturally easier to learn and find with SGD. Our networks use almost-linear nonlinearities such as ReLU/GeLU, which strongly encourage the formation of nearly-linear representations.

5.  Some NTK-like theory. Specifically, as NNs get larger, they move less from their initial condition to the solution, so we can increasingly approximate them with linear Taylor expansions. If the default 'representations' are linear and Gaussian due to the initialisation of the network, then perhaps SGD just finds solutions very close to the initialisation which preserve most of the properties.

6.  Our brains can only really perceive linear features and so everything we successfully observe in NNs is linear too; we just miss all the massively nonlinear stuff, and therein lies the danger. This is the anthropic argument and would be the failure case. Also, if we are applying any implicit selection pressure to the model -- for instance optimising against interpretability tools -- then this might push dangerous behaviour into nonlinear representations to evade our sensors.


  1. ^

    Of course the actual function the network implements cannot be completely linear otherwise we would just be doing a glorified (and expensive) linear regression.



Some counter evidence:

  • Kernelized Concept Erasure: concept encodings do have nonlinear components. Nonlinear kernels can erase certain parts of those encodings, but they cannot prevent other types of nonlinear kernels from extracting concept info from other parts of the embedding space.
  • Limitations of the NTK for Understanding Generalization in Deep Learning: the neural tangent kernels of realistic neural networks continuously change throughout their training. Further, neither the initial kernels nor any of the empirical kernels from mid-training can reproduce the asymptotic scaling laws of the actual neural network, which are better than predicted by said kernels.
  • Mechanistic Mode Connectivity: LMs often have non-connected solution basins, which correspond to different underlying mechanisms by which they make their classification decisions.

Thanks for these links! This is exactly what I was looking for, as per Cunningham's law. For the mechanistic mode connectivity, I still need to read the paper, but there is definitely a more complex story relating to the symmetries: they render things non-connected by default, but once you account for the symmetries and project things into an isometric space where all the symmetries are collapsed, things become connected and linear again. Is this different to that?


I agree about the NTK. I think this explanation is bad in its specifics although I think the NTK does give useful explanations at a very coarse level of granularity. In general, to put a completely uncalibrated number on it, I feel like NNs are probably '90% linear' in their feature representations. Of course they have to have somewhat nonlinear representations as well. But otoh if we could get 90% of the way to features that would be massive progress and might be relatively easy. 

One other problem of NTK/GP theory is that it isn't able to capture feature learning/transfer learning, and in general starts to break down as models get more complicated. In essence, NTK/GP fails to capture some empirical realities.

From the post "NTK/GP Models of Neural Nets Can't Learn Features":

Since people are talking about the NTK/GP hypothesis of neural nets again, I thought it might be worth bringing up some recent research in the area that casts doubt on their explanatory power. The upshot is: NTK/GP models of neural networks can't learn features. By 'feature learning' I mean the process where intermediate neurons come to represent task-relevant features such as curves, elements of grammar, or cats. Closely related to feature learning is transfer learning, the typical practice whereby a neural net is trained on one task, then 'fine-tuned' with a lower learning rate to fit another task, usually with less data than the first. This is often a powerful way to approach learning in the low-data regime, but NTK/GP models can't do it at all.

The reason for this is pretty simple. During training on the 'old task', NTK stays in the 'tangent space' of the network's initialization. This means that, to first order, none of the functions/derivatives computed by the individual neurons change at all; only the output function does.[1] Feature learning requires the intermediate neurons to adapt to structures in the data that are relevant to the task being learned, but in the NTK limit the intermediate neurons' functions don't change at all. Any meaningful function like a 'car detector' would need to be there at initialization -- extremely unlikely for functions of any complexity. This lack of feature learning implies a lack of meaningful transfer learning as well: since the NTK is just doing linear regression using an (infinite) fixed set of functions, the only 'transfer' that can occur is shifting where the regression starts in this space. This could potentially speed up convergence, but it wouldn't provide any benefits in terms of representation efficiency for tasks with few data points[2]. This property holds for the GP limit as well -- the distribution of functions computed by intermediate neurons doesn't change after conditioning on the outputs, so networks sampled from the GP posterior wouldn't be useful for transfer learning either.

This also makes me skeptical of the Mingard et al. result about SGD being equivalent to picking a random neural net with given performance, given that picking a random net is equivalent to running a GP regression in the wide-width limit. In particular, it makes me skeptical that this result will generalize to the complex models and tasks we care about. 'GP/NTK performs similarly to SGD on simple tasks' has been found before, but it tends to break down as the tasks become more complex.[3]

In essence, NTK/GP can't transfer learn because it stays where it started in the tangent space of the initialization, and this doesn't change even in the NTK limit.

A link to the post is below:

Sorry for stupid question but when you say $Y = f(X)$ is “(almost) linear”, I presume that means $f(X_1 + X_2) \approx f(X_1) + f(X_2)$ and $f(kX) \approx k f(X)$. But what are X & Y here? Activations? Weights? Weight-changes-versus-initialization? Some-kind-of-semantic-meaning-encoded-as-a-high-dimensional-vector-of-pattern-matches-against-every-possible-abstract-concept? More than one of the above? Something else?

For an image-classification network, if we remove the softmax nonlinearity from the very end, then $X$ would represent the input image in pixel space, and $Y$ would represent the class logits. Then $f(X_1 + X_2) \approx f(X_1) + f(X_2)$ would represent an image with two objects leading to an ambiguous classification (high log-probability for both classes), and $f(kX) \approx k f(X)$ would represent higher class certainty (softmax temperature $= 1/k$) when the image has higher contrast. I guess that kind of makes sense, but yeah, I think for real neural networks, this will only be linear-ish at best.

Well, averaging / adding two images in pixel space usually gives a thing that looks like two semi-transparent images overlaid, as opposed to “an image with two objects”.

If both images have the main object near the middle of the image or taking up most of the space (which is usually the case for single-class photos taken by humans), then yes. Otherwise, summing two images with small, off-center items will just look like a low-contrast, noisy image of two items.

Either way, though, I would expect this to result in class-label ambiguity. However, in some cases of semi-transparent-object-overlay, the overlay may end up mixing features in such a jumbled way that neither of the "true" classes is discernible. This would be a case where the almost-linearity of the network breaks down.

Maybe this linearity story would work better for generative models, where adding latent vector representations of two different objects would lead the network to generate an image with both objects included (an image that would have an ambiguous class label to a second network). It would need to be tested whether this sort of thing happens by default (e.g., with Stable Diffusion) or whether I'm just making stuff up here.

Yes, this is exactly right. This is precisely the kind of linearity that I am talking about, not the input->output mapping, which is clearly nonlinear. The idea is that hidden inside the network is a linear latent space where we can perform linear operations and they (mostly) work. In the points of evidence in the post there is discussion of exactly this kind of latent space editing for Stable Diffusion. A nice example is this paper. Interestingly, this also works for fine-tuning weight diffs for e.g. style transfer.

Natural abstractions hypothesis. Most abstractions are naturally linear and compositional in some sense (why?).

One of my main current hypotheses about natural abstractions is that natural summary statistics are approximately additive across subsystems. It's the same idea as "extensivity" in statistical physics, i.e. how energy and entropy are both approximately-additive across mesoscale subsystems. And it would occur for similar reasons: if not-too-close-together parts of the system are independent given some natural abstract latent variables, then we can break the system into a bunch of mesosize chunks with some space between each of them, ignore the relatively-small handful of variables in between the mesoscale chunks, and find that log probability of state is approximately additive across the chunks. That log probability is, in turn, "approximately a sufficient statistic" in some sense, because log likelihood is a universal sufficient statistic. So, we get an approximate sufficient statistic which is additive across the chunks.

... unfortunately the approximation is very loose, and more generally this whole argument dovetails with open questions about how to handle approximation for natural abstractions. So the math is not yet ready for prime time. But there is at least a qualitative argument for why we'd expect additivity across subsystems from natural abstractions.

My guess is that that rough argument is the main step in understanding why linearity seems to capture natural abstractions so well empirically.

I think this is a good intuition. I think this comes down to the natural structure of the graph and the fact that information disappears at larger distances. This means that for dense graphs such as lattices etc., regions only implicitly interact through much lower dimensional max-ent variables which are then additive, while for other causal graph structures, such as the power-law small-world graphs that are probably sensible for many real-world datasets, you also get a similar thing where each cluster can be modelled mostly independently apart from a few long-range interactions, which can be modelled as interacting with some general 'cluster sum'. Interestingly, this is what many approximate Bayesian inference algorithms for factor graphs look like -- such as the region graph algorithm.

I definitely agree it would be really nice to have the math of this all properly worked out, as I think this, as well as the reason why we see power-law spectra of features so often in natural datasets (which must have a max-ent explanation), is a super common and deep feature of the world.

I wonder if perhaps a weaker and more defensible thesis would be that deep learning models are mostly linear (and maybe the few non-linearities could be separated and identified? Has anyone tried applying ReLUs only to some outputs, leaving most of the rest untouched?). It would seem really weird to me if they really were linear. If that were the case, it would mean that:

  1. activation functions are essentially unnecessary
  2. forget SGD, you can just do one shot linear regression to train them (well, ok, no, they're still so big that you probably need gradient descent, but it's a much more deterministic process if it's a linear function that you're fitting)

You wouldn't even need multiple layers, just one big tensor. It feels weird that an entire field might have just overlooked such a trivial solution.

Isn't this extremely easy to directly verify empirically? 

Take a neural network $f$ trained on some standard task, like ImageNet or something. Evaluate $|f(kx) - kf(x)|$ on a bunch of samples $x$ from the dataset, and $|f(x+y) - f(x) - f(y)|$ on samples $x, y$. If it's "almost linear", then the difference should be very small on average. I'm not sure right now how to define "very small", but you could compare it e.g. to the distance distribution $|f(x) - f(y)|$ of independent samples, also depending on what the head is.
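
A rough sketch of what that check could look like is below. It is a stand-in, not the experiment itself: `f` is a random two-layer ReLU MLP rather than a trained ImageNet model, the inputs are Gaussian noise rather than dataset samples, and the sizes, the scale factor `k = 2` and the helper name `linearity_gaps` are all hypothetical choices. The gaps are reported relative to the typical distance $|f(x) - f(y)|$, as suggested above.

```python
import numpy as np

rng = np.random.default_rng(0)
d_in, d_hidden, d_out = 32, 128, 10      # hypothetical sizes

# Stand-in for a trained network: a random two-layer ReLU MLP with a bias.
# (Without the bias, ReLU would make f exactly homogeneous for k > 0.)
W1 = rng.normal(size=(d_in, d_hidden)) / np.sqrt(d_in)
b1 = rng.normal(size=d_hidden)
W2 = rng.normal(size=(d_hidden, d_out)) / np.sqrt(d_hidden)

def f(x):
    return np.maximum(x @ W1 + b1, 0.0) @ W2

def linearity_gaps(fn, x, y, k=2.0):
    """Per-sample homogeneity gap, additivity gap, and a baseline distance."""
    homog = np.linalg.norm(fn(k * x) - k * fn(x), axis=-1)
    addit = np.linalg.norm(fn(x + y) - fn(x) - fn(y), axis=-1)
    baseline = np.linalg.norm(fn(x) - fn(y), axis=-1)
    return homog, addit, baseline

# Random inputs stand in for dataset samples.
x = rng.normal(size=(1000, d_in))
y = rng.normal(size=(1000, d_in))

homog, addit, baseline = linearity_gaps(f, x, y)
print("relative homogeneity gap:", homog.mean() / baseline.mean())
print("relative additivity gap: ", addit.mean() / baseline.mean())

# Sanity check: an exactly linear map should give (numerically) zero gaps.
A = rng.normal(size=(d_in, d_out))
lin_h, lin_a, _ = linearity_gaps(lambda v: v @ A, x, y)
print("linear-map gaps:", lin_h.max(), lin_a.max())
```

Note that this tests linearity of the input-output map, which is a different claim from linearity of features in a latent space, so a large gap here would not by itself settle the question discussed in the replies.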

FWIW my opinion is that all this "circumstantial evidence" is a big non sequitur, and the base statement is fundamentally wrong. But it seems like such an easily testable hypothesis that it's more effort to discuss it than actually verify it.

At least how I would put this -- I don't think the important part is that NNs are literally almost linear, when viewed as input-output functions. More like, they have linearly represented features (i.e. directions in activation space, either in the network as a whole or at a fixed layer), or there are other important linear statistics of their weights (linear mode connectivity) or activations (linear probing). 

Maybe beren can clarify what they had in mind, though.

Yes. The idea is that the latent space of the neural network's 'features' is 'almost linear', which is reflected in the linear-ish properties of both the weights and the activations. Not that the literal I/O mapping of the NN is linear, which is clearly false.


More concretely, as an oversimplified version of what I am saying, it might be possible to think of neural networks as a combined encoder and decoder to a linear vector space. I.e. we have nonlinear functions f, which encodes the input x to a latent space z, and g, which decodes it to the output y, i.e. f(x) = z and g(z) = y. We then hypothesise that the latent space z is approximately linear, such that we can perform addition and weighted sums of zs as well as scaling individual directions in z, and these get decoded to the appropriate outputs, which correspond to sums or scalings of 'natural' semantic features we should expect in the input or output.

Linear decoding also works pretty well for others' beliefs in humans: Single-neuronal predictions of others’ beliefs in humans

@beren in this post, we find that our method (Causal Direction Extraction) captures a lot of the gender difference with 2 dimensions in a linearly separable way. Skimming that post might be of interest to you and your hypothesis.

In the same post, though, we suggest that it's unclear how much logit lens "works", to the extent that the direction that best encodes a given concept likely changes by a small angle at each layer, which causes two directions that best encode a concept 15 layers apart to have a cosine similarity <0.5.
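
As a back-of-the-envelope illustration of how small that per-layer angle can be, assume (hypothetically, this is not a claim from the post) that the best direction rotates by the same angle $\theta$ at every layer and always within one plane, so that the angles simply add up over 15 layers:

$$\cos(15\theta) < 0.5 \iff 15\theta > 60^\circ \iff \theta > 4^\circ \qquad (\text{for } 15\theta \le 180^\circ)$$

Under that assumption, a rotation of just over 4 degrees per layer, i.e. a cosine similarity of about 0.9976 between adjacent layers, would already be enough. If the rotations do not all lie in one plane, the accumulated angle is smaller than the sum, so larger per-layer angles would be needed to reach the same result.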

But what seems plausible to me is that almost all of the information relevant to a feature is encoded in a very small number of directions, which are slightly different for each layer.

Great discussion here!

Leaving a meta-comment about priors: on one hand, almost-linear features seem very plausible (a priori) for almost-linear neural networks; on the other, linear algebra is probably the single mathematical tool I'd expect ML researchers to be incredibly well-versed in, and the fact that we haven't found a "smoking gun" at this point with so much potential scrutiny makes me suspicious.

And while this is a very natural hypothesis to test, and I'm excited for people to do so, it seems possible that the field's familiarity with linear methods is a hammer that makes everything look like a nail. It's easy to focus on linear interpretability because the alternative seems too hard (a response I often get) - I think this is wrong, and there are tractable directions in the nonlinear case too, as long as you're willing to go slightly further afield.

I also have some skepticism on the object-level here too, but it was taking me too long to write it up, so that will have to wait. I think this is definitely a topic worth spending more time on - appreciate the post!

Note: not an NN expert, I'm speculating out of my depth.

This is related to the Gaussian process limit of infinitely wide neural networks. Bayesian Gaussian process regression can be expressed as Bayesian linear regression, they are really the same thing mathematically. They mostly perform worse than finitely wide neural networks, but they work decently. The fact that they work at all suggests that indeed linear stuff is fine, and that even actual NNs may be working in a regime close to linearity in some sense.

A counterargument to linearity could be: since the activations are ReLUs, i.e., flat then linear, the NN is locally mostly linear, which allows weight averaging and linear interpretations, but then the global behavior is not linear. Then a counter-counter-argument is: is the nonlinear region actually useful?

Combining the hypotheses/interpretations 1) the GP limit works although it is worse 2) various local stuff based on linearity works, I make the guess that maybe the in-distribution behavior of the NN is mostly linear, and going out-of-distribution brings the NN to strongly non-linear territory?

A key distinction is between linearity in the weights vs. linearity in the input data.

For example, the function $g(a, b, x, y) = a \sin x + b \cos y$ is linear in the arguments $a$ and $b$ but nonlinear in the arguments $x$ and $y$, since $\sin$ and $\cos$ are nonlinear.

Similarly, we have evidence that wide neural networks $f(x; \theta)$ are (almost) linear in the parameters $\theta$, despite being nonlinear in the input data $x$ (due e.g. to nonlinear activation functions such as ReLU). So nonlinear activation functions are not a counterargument to the idea of linearity with respect to the parameters.

If this is so, then neural networks are almost a type of kernel machine, doing linear learning in a space of features which are themselves a fixed nonlinear function of the input data.
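
To spell out the kernel-machine picture, one standard way to write "linear in the parameters" is a first-order Taylor expansion around the initial parameters. The symbols $\theta_0$ (parameters at initialization), $\phi$ and $k$ below are notation introduced here for illustration, not taken from the comment:

$$f(x; \theta) \approx f(x; \theta_0) + \nabla_\theta f(x; \theta_0)^\top (\theta - \theta_0)$$

The right-hand side is linear in $\theta$, while the feature map $\phi(x) = \nabla_\theta f(x; \theta_0)$ is a fixed nonlinear function of the input $x$. The corresponding kernel $k(x, x') = \phi(x)^\top \phi(x')$ is the neural tangent kernel, and how well this approximation holds for realistic finite-width networks is exactly what the NTK criticisms earlier in the thread dispute.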

Another reason to expect approximate linearity in deep learning models: point 12 + arguments about approximate (linear) isomorphism between human and artificial representations (e.g. search for 'isomorph' in Understanding models understanding language and in Grounding the Vector Space of an Octopus: Word Meaning from Raw Text).
