LawrenceC

I do AI Alignment research. Currently independent, but previously at: METR, Redwood, UC Berkeley, Good Judgment Project. 

I'm also a part-time fund manager for the LTFF.

Obligatory research billboard website: https://chanlawrence.me/

Sequences

(Lawrence's) Reflections on Research
[Redwood Research] Causal Scrubbing

Comments

But I think it's quite important for minimising misuse of models, which is also important:

To put it another way, things can be important even if they're not existential. 

LawrenceC8hΩ81510

I agree pretty strongly with Neel's first point here, and I want to expand on it a bit: one of the biggest issues with interp is fooling yourself and thinking you've discovered something profound when in reality you've misinterpreted the evidence. Sure, you've "understood grokking"[1] or "found induction heads", but why should anyone think that you've done something "real", let alone something that will help with future dangerous AI systems? Getting rigorous results in deep learning in general is hard, and it seems empirically even harder in (mech) interp. 

You can try to get around this by being extra rigorous and building from the ground up anyways. If you can present a ton of compelling evidence at every stage of resolution for your explanation, which in turn explains all of the behavior you care about (let alone a proof), then you can be pretty sure you're not fooling yourself. (But that's really hard, and deep learning especially has not been kind to this approach.) Or, you can try to do something hard and novel on a real system, that can't be done with existing knowledge or techniques. If you succeed at this, then even if your specific theory is not necessarily true, you've at least shown that it's real enough to produce something of value. (This is a fancy way of saying, "new theories should make novel predictions/discoveries and test them if possible".)

From this perspective, studying refusal in LLMs is not necessarily more x-risk relevant than studying, say, why LLMs seem to hallucinate, why linear probes seem to be so good for many use cases (and where they break), or the effects of helpfulness/agency/tool-use finetuning in general. (And I suspect that poking hard at some of the weird results from the cyborgism crowd may be more relevant.) But it's a hard topic that many people care about, and so succeeding here provides a better argument for the usefulness of their specific model internals based approach than studying something more niche. 

  • It's "easier" to study harmlessness than other comparably important or hard topics. Not only is there a lot of financial interest from companies, there's a lot of supporting infrastructure already in place to study harmlessness. If you wanted to study the exact mechanism by which Gemini Ultra is e.g. so good at confabulating undergrad-level mathematical theorems, you'd immediately run into the problem that you don't have Gemini internals access (and even if you do, the code is almost certainly not set up for easily poking around inside the model). But if you study a mechanism like refusal training, where there are open source models that are refusal trained and where datasets and prior work are plentiful, you're able to leverage existing resources.
  • Many of the other things AI Labs are pushing hard on are just clear capability gains, which many people morally object to. For example, I'm sure many people would be very interested if mech interp could significantly improve pretraining, or suggest more efficient sparse architectures. But I suspect most x-risk focused people would not want to contribute to these topics. 

Now, of course, there are the standard reasons why it's bad to study popular/trendy topics, including conflating your line of research with contingent properties of the topics (AI Alignment is just RLHF++, AI Safety is just harmlessness training), getting into a crowded field, being misled by prior work, etc. But I'm a fan of model internals researchers (esp mech interp researchers) applying their research to problems like harmlessness, even if it's just to highlight the ways in which mech interp is currently inadequate for these applications. 

Also, I would be upset if people started going "the reason this work is x-risk relevant is because of preventing jailbreaks" unless they actually believed this, but this is more of a general distaste for dishonesty as opposed to jailbreaks or harmlessness training in general. 

 

 

(Also, harmlessness training may be important under some catastrophic misuse scenarios, though I struggle to imagine a concrete case where end user-side jailbreak-style catastrophic misuse causes x-risk in practice, before we get more direct x-risk scenarios from e.g. people just finetuning their AIs to act in dangerous ways.)

 

  1. ^

    For example, I think our understanding of Grokking in late 2022 turned out to be importantly incomplete. 

LawrenceC2dΩ220

Thanks!

I was grouping that with “the computation may require mixing together ‘natural’ concepts” in my head. After all, entropy isn’t an observable in the environment, it’s something you derive to better model the environment. But I agree that “the concept may not be one you understand” seems more central.

It's actually worse than what you say -- the first two datasets studied here have a privileged basis 45 degrees off from the standard one, which is why the SAEs seem to continue learning the same 45 degree off features. Unpacking this sentence a bit: it turns out that both datasets have principal components 45 degrees off from the basis the authors present as natural, and so, since SAEs are in a sense trying to capture the principal directions of variation in the activation space, they will also naturally use features 45 degrees off from the "natural" basis. 

Consider the first example -- by construction, since x_1 and x_2 are anticorrelated perfectly, as are y_1 and y_2, the data is 2 dimensional and can be represented as x = x_1 - x_2 and y = y_1 - y_2. Indeed, this is exactly what their diagram is assuming. But here, x and y have the same absolute magnitude by construction, and so the dataset lies entirely on the diagonals of the unit square, and the principal components are obviously the diagonals. 

Now, why does the SAE want to learn the principal components? This is because it allows the SAE to have smaller activations on average for a given weight norm.

Consider the representation that is axis aligned, in that the SAE neurons are x_1, x_2, y_1, y_2 -- since there's weight decay, the encoding and decoding weights want to be of the same magnitude. Let's suppose that the encoding and decoding weights are of size s. Now, if the features are axis aligned, the total size of the activations will be 2A/s^2. But if you instead use the neurons aligned with x_1 + y_1, x_1 + y_2, x_2 + y_1, x_2 + y_2, the activations only need to be of size \sqrt 2 A/s^2. This means that a non-axis aligned representation will have lower loss. Indeed, something like this story is why we expect the L1 penalty to recover "true features" in the first place. 
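A quick numerical check of the 2A vs \sqrt 2 A comparison, as a rough sketch rather than the actual SAE setup. It assumes the 2-dimensional (x, y) view of the data, unit-norm dictionary directions (so s = 1), and a single point with x = y = A. The helper names `l1_cost` and `reconstruct` are made up for illustration.

```python
import math

def l1_cost(point, dictionary):
    """Total ReLU activation when each unit-norm direction reads off its projection."""
    acts = [max(0.0, sum(p * d for p, d in zip(point, direction)))
            for direction in dictionary]
    return sum(acts), acts

def reconstruct(acts, dictionary):
    """Decode by summing activation * direction (assumes tied, unit-norm weights)."""
    return [sum(a * direction[i] for a, direction in zip(acts, dictionary))
            for i in range(2)]

A = 1.0
point = (A, A)  # |x| = |y| = A, so the point sits on a diagonal

r = 1.0 / math.sqrt(2.0)
axis_dict = [(1, 0), (-1, 0), (0, 1), (0, -1)]      # axis-aligned features
diag_dict = [(r, r), (r, -r), (-r, r), (-r, -r)]    # features 45 degrees off

axis_l1, axis_acts = l1_cost(point, axis_dict)
diag_l1, diag_acts = l1_cost(point, diag_dict)

print("axis-aligned L1:", axis_l1)   # two features fire, each with activation A
print("diagonal L1:    ", diag_l1)   # one feature fires, with activation sqrt(2) * A
```

Both dictionaries reconstruct the point exactly, but the diagonal one does it with a smaller total activation, which is the quantity the L1 penalty pushes down.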

The story for the second dataset is pretty similar to the first -- when the data is uniformly distributed over a unit square, the principal directions are the diagonals of the square, not the standard basis.

LawrenceC3dΩ220

My speculation for Omni-Grok in particular is that in settings like MNIST you already have two of the ingredients for grokking (that there are both memorising and generalising solutions, and that the generalising solution is more efficient), and then having large parameter norms at initialisation provides the third ingredient (generalising solutions are learned more slowly), for some reason I still don't know.

Higher weight norm means lower effective learning rate with Adam, no? In that paper they used a constant learning rate across weight norms, but Adam tries to normalize the gradients to be of size 1 per parameter, regardless of the size of the weights. So the weights change more slowly with larger initializations (especially since they constrain the weights to be of fixed norm by projecting after the Adam step). 
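To make the "lower effective learning rate" point concrete, here is a minimal sketch of a single bias-corrected Adam step on one scalar parameter. It is not the paper's training setup: the function names, the gradient value, and the weight scales are all illustrative assumptions, and it leaves out the fixed-norm projection entirely.

```python
import math

def adam_first_step(w, g, lr=1e-3, b1=0.9, b2=0.999, eps=1e-8):
    """One Adam update starting from zero moment estimates (i.e. step t = 1)."""
    m = (1 - b1) * g            # first-moment estimate
    v = (1 - b2) * g * g        # second-moment estimate
    m_hat = m / (1 - b1)        # bias correction at t = 1
    v_hat = v / (1 - b2)
    return w - lr * m_hat / (math.sqrt(v_hat) + eps)

def step_sizes(scale, lr=1e-3):
    """Absolute and relative change of a weight whose init is multiplied by `scale`."""
    w = 0.5 * scale
    g = 0.3                     # arbitrary gradient; Adam normalizes its size away
    w_new = adam_first_step(w, g, lr=lr)
    return abs(w_new - w), abs(w_new - w) / abs(w)

abs_small, rel_small = step_sizes(scale=1.0)
abs_large, rel_large = step_sizes(scale=10.0)
print("absolute step:", abs_small, abs_large)   # both roughly lr
print("relative step:", rel_small, rel_large)   # 10x smaller at 10x the init scale
```

The absolute step is about lr in both cases, so the relative change per step shrinks in proportion to the weight scale.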

Yeah, "strongest" doesn't mean "strong" here! 

LawrenceC1moΩ220

I mean, yeah, as your footnote says:

Another simpler but less illuminating way to put this is that higher serial reasoning depth can't be parallelized.[1]

Transformers do get more computation per token on longer sequences, but they also don't get more serial depth, so I'm not sure if this is actually an issue in practice?

 

  1. ^

    [C]ompactly represent f ∘ g (f composed with g) in a way that makes computing f ∘ g more efficient for general choices of f and g.

    As an aside, I actually can't think of any class of interesting functions with this property -- when reading the paper, the closest I could think of are functions on discrete sets (lol), polynomials (but simplifying these is often more expensive than just computing the terms serially), and rational functions (ditto).

     

LawrenceC1moΩ305218

I finally got around to reading the Mamba paper. H/t Ryan Greenblatt and Vivek Hebbar for helpful comments that got me unstuck. 

TL;DR: authors propose a new deep learning architecture for sequence modeling with scaling laws that match transformers while being much more efficient to sample from.

A brief historical digression

As of ~2017, the three primary ways people had for doing sequence modeling were RNNs, Conv Nets, and Transformers, each with a unique “trick” for handling sequence data: recurrence, 1d convolutions, and self-attention.
 

  • RNNs are easy to sample from — to compute the logit for x_t+1, you only need the most recent hidden state h_t and the last token x_t, which means it’s both fast and memory efficient. RNNs generate a sequence of length L with O(1) memory and O(L) time. However, they’re super hard to train, because you need to sequentially generate all the hidden states and then (reverse) sequentially calculate the gradients. The way you actually did this is called backpropagation through time — you basically unroll the RNN over time — which requires constructing a graph of depth equal to the sequence length. Not only was this slow, but the graph being so deep caused vanishing/exploding gradients without careful normalization. The strategy that people used was to train on short sequences and finetune on longer ones. That being said, in practice, this meant you couldn’t train on long sequences (>a few hundred tokens) at all. The best LSTMs for modeling raw audio could only handle being trained on ~5s of speech, if you chunk up the data into 25ms segments.
  • Conv Nets had a fixed receptive field size and pattern, so weren’t that suited for long sequence modeling. Also, generating each token takes O(L) time, assuming the receptive field is about the same size as the sequence. But they had significantly more stability (the depth was small, and could be as low as O(log(L))), which meant you could train them a lot more easily. (Also, you could use FFT to efficiently compute the conv, meaning it trains one sequence in O(L log(L)) time.) That being said, you still couldn’t make them that big. The most impressive example was DeepMind’s WaveNet, a conv net used to model human speech, which could handle sequences of up to 4800 samples … which was 0.3s of actual speech at 16k samples/second (note that most audio is sampled at 44k samples/second…), and even to get to that amount, they had to really gimp the model’s ability to focus on particular inputs.
  • Transformers are easy to train, can handle variable length sequences, and also allow the model to “decide” which tokens it should pay attention to. In addition to both being parallelizable and having relatively shallow computation graphs (like conv nets), you could do the RNN trick of pretraining on short sequences and then finetuning on longer sequences to save even more compute. Transformers could be trained with comparable sequence length to conv nets but get much better performance; for example, OpenAI’s MuseNet was trained on length-4096 sequences of MIDI files. But as we all know, transformers have the unfortunate downside of being expensive to sample from — it takes O(L) time and O(L) memory to generate a single token (!).

The better performance of transformers over conv nets and their ability to handle variable length data let them win out.

That being said, people have been trying to get around the O(L) time and memory requirements for transformers since basically their inception. For a while, people were super into sparse or linear attention of various kinds, which could reduce the per-token compute/memory requirements to O(log(L)) or O(1).

The what and why of Mamba

If the input -> hidden and hidden -> hidden map for RNNs were linear (h_t+1 = A h_t + B x_t), then it’d be possible to train an entire sequence in parallel — this is because you can just … compose the transformation with itself (computing A^k for k in 2…L-1) a bunch, and effectively unroll the graph with the convolutional kernel defined by B, A B, A^2 B, A^3 B, … A^{L-1} B. Not only can you FFT during training to get the O(L log (L)) time of a conv net forward/backward pass (as opposed to O(L^2) for the transformer), you still keep the O(1) sampling time/memory of the RNN!
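As a sanity check on the "linear RNN = convolution" claim, here is a small sketch in the scalar case (A and B are plain numbers `a` and `b`, and h_0 = 0). The function names are made up for illustration, and the convolution is computed directly in O(L^2) rather than with an FFT, just to show that the two views give the same hidden states.

```python
import math

def recurrent_states(a, b, xs):
    """Run h_{t+1} = a * h_t + b * x_t step by step, starting from h_0 = 0."""
    h, out = 0.0, []
    for x in xs:
        h = a * h + b * x
        out.append(h)
    return out

def convolutional_states(a, b, xs):
    """Same states via a causal convolution with kernel K_k = a^k * b."""
    kernel = [a ** k * b for k in range(len(xs))]
    return [sum(kernel[k] * xs[t - k] for k in range(t + 1))
            for t in range(len(xs))]

xs = [1.0, -2.0, 0.5, 3.0, 0.0, 1.5]
rec = recurrent_states(0.9, 0.4, xs)
conv = convolutional_states(0.9, 0.4, xs)
print(all(math.isclose(r, c, abs_tol=1e-12) for r, c in zip(rec, conv)))
```

The recurrent form is what you would use at sampling time (constant state), while the convolutional form is what lets training process the whole sequence at once.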

The problem is that linear hidden state dynamics are kinda boring. For example, you can’t even learn to update your existing hidden state in a different way if you see particular tokens! And indeed, previous results gave scaling laws that were much worse than transformers in terms of performance/training compute. 

In Mamba, you basically learn a time varying A and B. The parameterization is a bit wonky here, because of historical reasons, but it goes something like: A_t is exp(-\delta(x_t) * exp(A)), B_t = \delta(x_t) B x_t, where \delta(x_t) = softplus ( W_\delta x_t). Note that in Mamba, they also constrain A to be diagonal and W_\delta to be low rank, for computational reasons.

Since exp(A) is diagonal and has only positive entries, we can interpret the model as follows: \delta controls how much to “learn” from the current example — with high \delta, A_t approaches 0 and B_t is large, causing h_t+1 ~= B_t x_t, while with \delta approaching 0, A_t approaches 1 and B_t approaches 0, meaning h_t+1 ~= h_t.
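Here is a toy, single-channel sketch of that gating behaviour. It is one reading of the description above and not the Mamba implementation: everything is a scalar, B_t is taken to be \delta * B so that the update reads h_t+1 = A_t h_t + B_t x_t, and the names `selective_step`, `a_log`, and `w_delta` are hypothetical.

```python
import math

def softplus(z):
    return math.log1p(math.exp(z))

def selective_step(h, x, a_log, b, w_delta):
    """One input-dependent update: returns (new hidden state, A_t, B_t)."""
    delta = softplus(w_delta * x)            # input-dependent step size
    a_t = math.exp(-delta * math.exp(a_log)) # in (0, 1) since exp(a_log) > 0
    b_t = delta * b                          # assumption: B_t = delta * B
    return a_t * h + b_t * x, a_t, b_t

h_prev = 5.0

# Large delta: the old state is (almost) forgotten and the input dominates.
h_hi, a_hi, b_hi = selective_step(h_prev, x=1.0, a_log=0.0, b=1.0, w_delta=20.0)

# Delta near 0: the input is (almost) ignored and the state is carried over.
h_lo, a_lo, b_lo = selective_step(h_prev, x=1.0, a_log=0.0, b=1.0, w_delta=-20.0)

print("high delta:", a_hi, h_hi)
print("low delta: ", a_lo, h_lo)
```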

Now, you can’t exactly unroll the hidden state as a convolution with a predefined convolution kernel anymore, but you can still efficiently compute the implied “convolution” using parallel scanning.
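The reason a scan works is that each step h -> a_t * h + b_t is an affine map, and composing affine maps is associative. Below is a minimal sketch with scalar a_t and b_t and h_0 = 0. It uses a simple doubling (Hillis-Steele style) scan written as ordinary Python loops; this illustrates the idea and is not the hardware-aware kernel from the paper, and the function names are made up.

```python
import math

def sequential_scan(a, b):
    """Reference: h_t = a_t * h_{t-1} + b_t, computed one step at a time."""
    h, out = 0.0, []
    for a_t, b_t in zip(a, b):
        h = a_t * h + b_t
        out.append(h)
    return out

def combine(left, right):
    """Compose two affine maps: apply `left` first, then `right`."""
    a_l, b_l = left
    a_r, b_r = right
    return (a_l * a_r, a_r * b_l + b_r)

def doubling_scan(a, b):
    """Inclusive scan in O(log L) passes; each pass is elementwise independent."""
    elems = list(zip(a, b))
    d = 1
    while d < len(elems):
        # Every position only reads the previous pass, so this could run in parallel.
        elems = [elems[i] if i < d else combine(elems[i - d], elems[i])
                 for i in range(len(elems))]
        d *= 2
    return [h for _, h in elems]  # with h_0 = 0, the offset term is h_t

a = [0.5, 0.9, 0.1, 1.0, 0.3]
b = [1.0, -2.0, 0.5, 3.0, 0.25]
seq = sequential_scan(a, b)
par = doubling_scan(a, b)
print(all(math.isclose(s, p, abs_tol=1e-12) for s, p in zip(seq, par)))
```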

Despite being much cheaper to sample from, Mamba matches the pretraining flops efficiency of modern transformers (Transformer++ = the current SOTA open source Transformer with RMSNorm, a better learning rate schedule, and corrected AdamW hyperparameters, etc.). And on a toy induction task, it generalizes to much longer sequences than it was trained on.

So, about those capability externalities from mech interp...

Yes, those are the same induction heads from the Anthropic ICL paper! 

Like the previous Hippo and Hyena papers, they cite mech interp as one of their inspirations, in that it inspired them to think about what the linear hidden state model could not model and how to fix that. I still don’t think mech interp has that much Shapley here (the idea of studying how models perform toy tasks is not new, and the authors don't even use the induction metric or the RRT task from the Olsson et al paper), but I'm not super sure on this.

IMO, this line of work is the strongest argument for mech interp (or maybe interp in general) having concrete capabilities externalities. In addition, I think the previous argument Neel and I gave of "these advances are extremely unlikely to improve frontier models" feels substantially weaker now.

Is this a big deal?

I don't know, tbh.

That seems correct, at least directionally, yes.

I don't want to say things that have any chance of annoying METR without checking with METR comm people, and I don't think it's worth their time to check the things I wanted to say. 
