Crossposted from the AI Alignment Forum. May contain more technical jargon than usual.

Thanks to Evan Hubinger and Beth Barnes for comments on these ideas.

Language models exhibit scaling laws, where the loss is a power law in model size. This offers a lot of predictive power, and seems like a useful thing to know. By contrast, individual capabilities can exhibit sharp discontinuities in performance as a function of model size and training time.

It would be great if individual capabilities just gradually improved like the broader loss. Then we wouldn’t have to worry quite so much about surprising new capabilities emerging suddenly during training.

Is there a way to change the loss function so that it incentivizes more gradual capability improvements?

Grouped Loss

Imagine grouping training examples by the kind of capability they exhibit. For instance, arithmetic problems go in one group, “parse JSON” could go in another, and so on. With these groups, we could define a new loss function

$$L_{\text{grouped}} = \sum_i L_i^2,$$

where $L$ is the loss function we originally used (e.g. cross-entropy loss) and $L_i$ means to compute the sum of $L$ over examples from group $G_i$:

$$L_i = \sum_{x \in G_i} L(x),$$

which may be estimated using random examples drawn from $G_i$.

Because we have squared the group losses, the overall loss is dominated by the worst group. As a result, the model is incentivized to develop capabilities in each group at comparable rates, and so has little incentive to e.g. finely hone its poetry skills while being unable to multiply numbers.
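As a rough illustration, here is a plain-Python sketch of the grouped loss, assuming per-example losses have already been computed and each example carries a hand-assigned group label (the numbers and labels are made up). Because the derivative of a squared group total T is 2T times the derivative of T, each group's gradient gets scaled by that group's current loss, which is why the worst group dominates.

```python
from collections import defaultdict


def grouped_loss(losses, groups):
    """Sum per-example losses within each group, then sum the squared group totals.

    losses: per-example losses L(x), e.g. cross-entropy values.
    groups: one (hypothetical, hand-assigned) group label per example.
    """
    totals = defaultdict(float)
    for loss, group in zip(losses, groups):
        totals[group] += loss
    return sum(t * t for t in totals.values()), dict(totals)


# Toy batch: the model is already decent at poetry but bad at arithmetic.
losses = [0.1, 0.1, 0.1, 0.1, 0.9, 0.9]
groups = ["poetry"] * 4 + ["arithmetic"] * 2

total, per_group = grouped_loss(losses, groups)

# d(T**2)/d(theta) = 2 * T * dT/d(theta), so every example in a group has its
# gradient scaled by 2 * (that group's total loss).
weights = {g: 2 * t for g, t in per_group.items()}
print(per_group)  # poetry ≈ 0.4, arithmetic ≈ 1.8
print(weights)    # arithmetic examples get about 4.5x the weight of poetry examples
```

By contrast, under the ordinary summed loss every example's gradient has weight 1 regardless of how well its group is doing.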

Challenge: Defining Groups

It’s possible that using grouped loss results in smooth development of capabilities that aren’t represented in the groups. For instance, it seems plausible that if “adding Arabic numerals” and “translating words into Arabic numerals” are two groups but “adding numbers written as words” is not, performance on the latter could nonetheless develop smoothly as the model gets better at the others. It would certainly be weird if performance on “adding numbers written as words” advanced in a sudden leap in this case.

This points to a general problem, though: if we have to define the groups manually, we have to foresee the capabilities we’re worried about. That seems bad.

Gradient Cluster Grouping

If we could automatically group examples we wouldn’t need to do it manually. How could we do this?

I think the key feature of a group is that when the model updates, the loss of most examples in a group changes in a similar way. When that happens, it seems intuitive to say that there’s a discrete capability somewhere in the model and that those examples all depend on it.

This suggests looking for examples where the loss has similar gradients, because these probably make use of similar machinery in the model.

Concretely, I’m imagining the following procedure:

  1. Draw $N$ examples from the training set.
  2. Evaluate the gradient of the original loss $L$ for each example.
  3. Group the examples by clustering their gradients, evaluate the grouped loss, and perform gradient descent on that.
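Below is a toy sketch of these three steps, assuming a linear-regression model whose per-example gradients have a closed form, and using spherical k-means with a fixed number of clusters as the clustering method. The post does not specify a model, a clustering algorithm, or a cluster count; all of these are illustrative choices, and `per_example_grads` and `kmeans_cosine` are hypothetical helpers.

```python
import numpy as np


def per_example_grads(w, X, y):
    # Toy model (assumption): linear regression with L(x) = 0.5 * (w.x - y)**2,
    # whose gradient w.r.t. w is (w.x - y) * x. Row j is example j's gradient.
    residuals = X @ w - y
    return residuals[:, None] * X


def kmeans_cosine(G, k, iters=20):
    # Spherical k-means on unit-normalized gradients (a swapped-in choice;
    # the post does not name a clustering algorithm).
    Z = G / np.linalg.norm(G, axis=1, keepdims=True)
    # Farthest-point initialization keeps this toy run deterministic.
    centers = [Z[0]]
    for _ in range(1, k):
        dist = np.min([np.linalg.norm(Z - c, axis=1) for c in centers], axis=0)
        centers.append(Z[np.argmax(dist)])
    centers = np.array(centers)
    labels = np.zeros(len(Z), dtype=int)
    for _ in range(iters):
        # Assign each gradient to its nearest center, then recompute centers.
        labels = np.argmin(((Z[:, None, :] - centers[None]) ** 2).sum(-1), axis=1)
        for c in range(k):
            if np.any(labels == c):
                centers[c] = Z[labels == c].mean(axis=0)
    return labels


# Step 1: draw N toy examples. Two hypothetical "capabilities" use disjoint features.
rng = np.random.default_rng(0)
N, P = 20, 4
X = np.zeros((N, P))
X[:10, :2] = rng.uniform(0.5, 1.5, size=(10, 2))
X[10:, 2:] = rng.uniform(0.5, 1.5, size=(10, 2))
y = np.ones(N)
w = np.zeros(P)

# Step 2: one gradient per example.
G = per_example_grads(w, X, y)

# Step 3: cluster the gradients, then evaluate the grouped loss.
labels = kmeans_cosine(G, k=2)
losses = 0.5 * (X @ w - y) ** 2
group_totals = np.array([losses[labels == c].sum() for c in range(2)])
grouped = float(np.sum(group_totals ** 2))

# One gradient-descent step on the grouped loss, holding the labels fixed:
# d/dw sum_c T_c**2 = sum_c 2 * T_c * (sum of gradients in cluster c).
grad = sum(2 * group_totals[c] * G[labels == c].sum(axis=0) for c in range(2))
w_new = w - 0.001 * grad
new_losses = 0.5 * (X @ w_new - y) ** 2
print(labels, grouped, losses.sum(), new_losses.sum())
```

Clustering the normalized gradients (i.e. by cosine similarity) groups examples by the direction in which they would move the parameters, not by how large their loss is.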

As a technical note: in this approach, the grouped loss is a moving target. As the model learns and capabilities form, the groups shift. This means that SGD is no longer minimizing a fixed loss. I don’t think that’s a problem, in part because all versions of the loss agree when the model has reached zero loss (with a non-negative per-example loss, every version of the grouped loss is zero exactly when every example’s loss is zero), so the different iterations of the loss function all point towards better capabilities.

Challenge: How many groups?

I don’t know of a principled way to pick the number of groups to cluster examples into, and that seems like a problem. Guessing too many groups loses the advantage of grouping because each group reflects an extremely narrow task. Guessing too few groups also loses the advantage of grouping, because then the capabilities that show gradual improvements will be very broad ones, and narrow capabilities will still show discontinuous improvements.

SVD-Grouped Loss

(I don’t think this specific loss is necessarily the best idea, but I think it illustrates the kind of approach that might solve the challenge of identifying appropriate groups.)

An improvement over clustering by gradients is to use the singular value decomposition (SVD), which provides a more continuous way to talk about the similarity between gradients.

The idea here is that the SVD of the gradients of different examples will identify the most important directions in loss-space, which I (weakly) suspect correspond to directions that improve distinct capabilities.

Construction

We begin as before by drawing $N$ examples from the training set and evaluating the gradient of the loss $L$ for each example. Each gradient has length equal to the number of parameters $P$ in the model. Combining the gradients, we form the $N \times P$ matrix $G$, whose $j$-th row $g_j$ is the gradient for example $j$.

We next compute the SVD of $G$. This produces singular values $\sigma_i$ and pairs of singular vectors $(u_i, v_i)$, where $u_i$ has length $N$ and $v_i$ has length $P$. Importantly, $v_i$ lives in the same space as the loss gradients, and the set of $v_i$ spans the space of the gradients. As such, we can write each gradient in terms of the $v_i$ as:

$$g_j = \sum_i \sigma_i u_{ij} v_i,$$

where $u_{ij}$ is the $j$-th component of $u_i$.
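A quick numerical check of this decomposition with NumPy, using a random stand-in for the gradient matrix (toy sizes, not real model gradients). With the rows of `G` as per-example gradients, `np.linalg.svd` returns `U`, `S` and `Vt`, and each row of `G` is recovered as a combination of the rows of `Vt` weighted by `S[i] * U[j, i]`.

```python
import numpy as np

rng = np.random.default_rng(0)
N, P = 6, 10                 # N examples, P parameters (toy sizes)
G = rng.normal(size=(N, P))  # row j stands in for example j's loss gradient

# Thin SVD: G = U @ diag(S) @ Vt. The rows of Vt live in parameter space,
# i.e. the same space as the per-example gradients.
U, S, Vt = np.linalg.svd(G, full_matrices=False)

# Each gradient is a combination of the right singular vectors:
# g_j = sum_i S[i] * U[j, i] * Vt[i]
j = 3
g_rebuilt = sum(S[i] * U[j, i] * Vt[i] for i in range(len(S)))
print(np.allclose(g_rebuilt, G[j]))  # True
```

NumPy returns the singular values sorted in descending order, so the first rows of `Vt` are the directions shared most strongly across the batch.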

We can then define the loss

where  is a component of the normalized gradient. For purposes of evaluating gradients of the loss, we treat the ’s and ’s as constant. This should not be a problem because, regardless of the values of , all versions of the loss agree when the model has reached zero loss. So, as in the Grouped Loss case, the different iterations of the loss function all point towards better capabilities.

Interpretation

In this loss function, groups correspond to singular vectors and are weighted by their singular values. Examples are attributed continuously to groups (i.e. each example belongs to multiple groups to varying degrees) in accordance with how closely their gradients align with the groups’ singular vectors.

My intuition here is that singular vectors with large singular values correspond to individual capabilities, because they are directions in gradient-space that improve many examples (roughly, the more examples whose gradients point along a direction, the higher its singular value).
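To illustrate this intuition, here is a toy NumPy example in which eight gradients share one direction and two do not. The attribution rule used here (the squared cosine between an example's gradient and each singular vector) is a hypothetical stand-in for continuous group membership, not the loss defined in the post; with it, each example's attributions sum to 1 across singular vectors.

```python
import numpy as np

rng = np.random.default_rng(1)
P = 5  # toy number of parameters

# Eight examples whose gradients mostly point along one shared direction (e0),
# standing in for examples that lean on one common capability...
scale = rng.uniform(0.8, 1.2, size=8)
G_shared = np.outer(scale, np.eye(P)[0]) + 0.05 * rng.normal(size=(8, P))
# ...and two examples whose gradients point in unrelated directions.
G_other = np.eye(P)[[3, 4]]
G = np.vstack([G_shared, G_other])

U, S, Vt = np.linalg.svd(G, full_matrices=False)

# Hypothetical soft attribution: squared cosine between gradient g_j and singular
# vector v_i. Since g_j . v_i = S[i] * U[j, i], this is (S[i] * U[j, i])**2 / |g_j|**2,
# and each example's attributions sum to 1 over i.
A = (U * S) ** 2 / (G ** 2).sum(axis=1, keepdims=True)

print(np.round(S, 2))        # the shared direction gets the largest singular value
print(np.round(A[:, 0], 2))  # each example's attribution to the top singular vector
```

In this toy setup the eight "shared" examples are attributed almost entirely to the top singular vector, while the two outliers load on other directions.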

Summary

I would like to see capabilities arise more gradually during training, rather than sudden grokking. That could make it easier to notice dangerous capabilities developing.

I think grouped loss functions are one way to do this, and they work (if they work) by making SGD care most about the model’s weakest capability at all times.

Assuming grouped losses are feasible to implement and indeed behave this way, they would also provide a weak guarantee that the model’s performance on one task is representative of its performance on other tasks (so long as both tasks appeared during training). This seems like a really useful (if unintended) property, because it means that we can understand a model’s capabilities with much sparser testing.

Comments

> For instance, it seems plausible that if “adding Arabic numerals” and “translating words into Arabic numerals” are two groups but “adding numbers written as words” is not, performance on the latter could nonetheless develop smoothly as the model gets better at the others. It would certainly be weird if performance on “adding numbers written as words” advanced in a sudden leap in this case.

I wouldn't say this is weird. This is kind of the point of meta-learning, or 'transfer' in a broad sense: you train on X, and Y gets better! Or look at emergent capabilities: they don't spike because of additional data being added (the token count is similar or identical), so it has to be because of larger models in some way transferring from other datapoints.

There also seems to be a premise running through this proposal that learning is, in some sense, simple and independent, and that you are mostly just oversampling/undersampling as a throttle, as it were: avoiding spikes by throttling each task individually instead of only the global loss, which is too loose and leaves too much wiggle room because each individual task is a minuscule fraction of the overall average 'task'. But we have plenty of evidence that how you weight or group data changes the dynamics and capabilities, quantitatively and qualitatively; the most striking recent research result implying that how you group data can change what is learned qualitatively is DM's "Data Distributional Properties Drive Emergent In-Context Learning in Transformers", Chan et al 2022:

> Large transformer-based models are able to perform in-context few-shot learning, without being explicitly trained for it. This observation raises the question: what aspects of the training regime lead to this emergent behavior? Here, we show that this behavior is driven by the distributions of the training data itself.

> In-context learning emerges when the training data exhibits particular distributional properties such as burstiness (items appear in clusters rather than being uniformly distributed over time) and having large numbers of rarely occurring classes. In-context learning also emerges more strongly when item meanings or interpretations are dynamic rather than fixed. These properties are exemplified by natural language, but are also inherent to naturalistic data in a wide range of other domains. They also depart significantly from the uniform, i.i.d. training distributions typically used for standard supervised learning.

> In our initial experiments, we found that in-context learning traded off against more conventional weight-based learning, and models were unable to achieve both simultaneously. However, our later experiments uncovered that the two modes of learning could co-exist in a single model when it was trained on data following a skewed Zipfian distribution -- another common property of naturalistic data, including language. In further experiments, we found that naturalistic data distributions were only able to elicit in-context learning in transformers, and not in recurrent models.

> In sum, our findings indicate how the transformer architecture works together with particular properties of the training data to drive the intriguing emergent in-context learning behaviour of large language models, and how future work might encourage both in-context and in-weights learning in domains beyond language.

Here, the distribution of tasks (known image classes) affects which kind of learning the model uses for other tasks (classes): the presence of a common class or a rare class, as opposed to a middle class, skews the model as a whole, across all future classes, away from meta-learning.

I take this as implying that if you did something like extract the implicit tasks of a big Internet scrape and did the obvious thing of rebalancing classes away from a Zipfian distribution to a uniform distribution, closer to something like ImageNet with 1000 classes of roughly the same size, you would get models which might be much more efficient to train or might have the same or lower training loss, but would have a very different set of strengths and weaknesses; possibly, in the extreme case, they might have no few-shot capability at all! (This alternative model is probably very far away in model space from the normal meta-learning one, having learned a fundamentally different approach, so I doubt any considerations of local gradients or model properties are going to be useful.) This is a more extreme version of my concern with MoEs, that using experts to solve specific problems rather than a single universal dense model will tend to sabotage learning of interesting capabilities: here, it's not merely that MoEs seem to do slightly better on memorization-heavy benchmarks than reasoning ones, it's that the meta-learning doesn't happen at all!

And the strangeness probably doesn't stop there. If you trained some large model in such a manner and it was completely crippled in some respects (while presumably having perhaps more than offsetting gains elsewhere), what would happen if you then further trained it on a Zipfian dataset which hadn't been rebalanced? I would hazard the guess that it might learn the suppressed capabilities relatively rapidly. This would be very bad for safety purposes if you thought you trained a safe model you could release publicly, say, which did all sorts of useful things but couldn't be made to do dangerous new things; and yet all you did was create a capabilities overhang for the first person to come along to unlock by finetuning.

> This is kind of the point of meta-learning, or 'transfer' in a broad sense: you train on X, and Y gets better!

I'm not saying that the knowledge doesn't transfer; I'm saying it would seem weird if it transferred sharply. Specifically, if task Z is composed of performing task X then task Y, I would expect improving X to improve Z, and improving Y to improve Z, and I would expect P(Z performed correctly) to be the product of P(X performed correctly) and P(Y performed correctly), if the two steps succeed or fail independently. I think that means Z will improve a bit more sharply than either X or Y, but not drastically so?
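A small numerical sketch of this argument, assuming (hypothetically) that P(X correct) and P(Y correct) each follow the same logistic curve in some measure of training progress t, and that the two steps succeed independently. It measures how far t must advance for accuracy to go from 10% to 90%.

```python
import math


def sigmoid(t):
    return 1.0 / (1.0 + math.exp(-t))


def transition_width(f, lo=-20.0, hi=20.0, steps=400_001):
    # Width of the interval over which f rises from 0.1 to 0.9, found on a fine grid.
    ts = [lo + (hi - lo) * k / (steps - 1) for k in range(steps)]
    t10 = next(t for t in ts if f(t) >= 0.1)
    t90 = next(t for t in ts if f(t) >= 0.9)
    return t90 - t10


p_x = sigmoid                     # P(X correct) vs. hypothetical training progress t
p_y = sigmoid                     # P(Y correct), improving on the same schedule
p_z = lambda t: p_x(t) * p_y(t)   # P(Z correct) if Z = "do X, then do Y", independently

print(round(transition_width(p_x), 2))  # 4.39 (= 2 ln 9)
print(round(transition_width(p_z), 2))  # 3.69: sharper, but not drastically
```

Under these assumptions the composite task's 10%-to-90% transition is about 16% narrower than either component's: a somewhat steeper curve, not a discontinuity.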

But I could absolutely be wrong here! Real models do things undreamt of in theory.

> But we have plenty of evidence that how you weight or group data would change the dynamics and capabilities quantitatively and qualitatively ... it's not merely that MoEs seem to do slightly better on memorization-heavy benchmarks than reasoning ones, it's that the meta-learning doesn't happen at all!

The first part is what I'm hoping for: I want it to have different dynamics and capabilities, at least at intermediate stages... it's fine if it eventually gets to the same place.

The second part would definitely be bad, if only because it's a heavy alignment tax and if this incurs a large tax it's a non-starter. Thanks for your intuition around this!

> I would hazard the guess that it might learn the suppressed capabilities relatively rapidly. This would be very bad for safety purposes if you thought you trained a safe model you could release publicly, say, which did all sorts of useful things but couldn't be made to do dangerous new things; and yet all you did was create a capabilities overhang for the first person to come along to unlock by finetuning.

That indeed seems bad. And to make sure I've got it right, the intuition here is that the model strongly "wants" to learn the suppressed features (because they're very instrumental on the simple loss)? I guess the other thing that could happen is that you've screwed the model up too badly by training it on this grouped loss, so that those features are really far out of reach. I'm not quite sure how to think about this.

My takeaway is that to the extent this helps with safety, it's a brittle strategy, and it has a good chance of incurring too large a performance penalty to be viable in a competitive world.