I am working on empirical AI safety.
Book a call with me if you want advice on a concrete empirical safety project.
the small differences in logits on non-harmful data are quite important
My guess is that if you used mech interp on RMU models, you would find that the internals look a lot like if(harmful) then add a big vector to the residual stream else keep it as is. If this is the case, then I don't see why there would be a difference in logprobs on non-harmful tokens.
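To make the conjecture above concrete, here is a toy caricature of the "if(harmful) then add a big vector" picture. It is a sketch of the hypothesis only, not of how RMU is actually implemented: `harmful_score`, `threshold`, `scale`, and `noise_direction` are hypothetical names introduced for illustration.

```python
from typing import List


def caricature_rmu_layer(
    residual: List[float],
    harmful_score: float,
    noise_direction: List[float],
    threshold: float = 0.5,  # hypothetical detection threshold
    scale: float = 100.0,    # hypothetical "big vector" magnitude
) -> List[float]:
    """Toy model of the conjecture: perturb the residual stream only when
    an (assumed) internal detector flags the input as harmful."""
    if harmful_score > threshold:
        # Harmful branch: add a large fixed vector to the residual stream.
        return [r + scale * d for r, d in zip(residual, noise_direction)]
    # Non-harmful branch: the residual stream is returned unchanged.
    return list(residual)


benign = caricature_rmu_layer([1.0, 2.0], 0.1, [1.0, 0.0])
harmful = caricature_rmu_layer([1.0, 2.0], 0.9, [1.0, 0.0])
```

Under this caricature the residual stream is exactly unchanged on non-harmful inputs, so the logprobs on those tokens would be unchanged too.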
I was just singling out RMU because I believe I understand its effects a bit more than for other methods.
We did find that RMU+distillation was less robust in the arithmetic setting than the other initial unlearning methods.
This is interesting! I think I would have guessed the opposite. I don't have a great hypothesis for what GradDiff does mechanistically.
EDIT: I think I misunderstood your original point - were you saying to just label all of the data using a classifier trained on just 1% of the pretraining data? (Neither of your schemes say what to do after step 3.)
Oops I was more unclear than I thought.
I am imagining schemes of the form:
(and this produces a base model that does not have the harmful knowledge, which you use for your regular post-training pipeline then deployment).
Why do you claim that no one is interested in this? Lots of labs do data filtering, which is known to be effective but quite costly to iterate on.
I think using UNDO at p=50% of full retraining compute is not much cheaper than regular distillation (on an unlearned / filtered model), adds a lot of risk to a potentially very expensive operation, and has fewer robustness benefits than full retraining. But maybe I am wrong here; I expressed too much confidence. (I also think it doesn't really matter: my guess is that future work will find much stronger positive results in this part of the space and push the Pareto frontier beyond UNDO.)
quite costly to iterate on.
[edit] actually I maybe missed this part. I did not take into account that an UNDO(10%) could be a great de-risking strategy for a full distillation run, which makes UNDO(10%) much more relevant than I thought. Good point.
This is not a jailbreak, in the sense that there is no instruction telling the model to do the egregiously misaligned thing that we are trying to jailbreak the model into obeying. Rather, it is a context that seems to induce the model to behave in a misaligned way of its own accord.
I think this is not totally obvious a priori. Some jailbreaks may work not via direct instructions, but by doing things that erode the RLHF persona and then let the base shine through. For such "jailbreaks", you would expect the model to act misaligned if you hinted at misalignment very strongly regardless of initial misalignment.
I think you can control for this by doing things like my "hint at paperclip" experiment (which in fact suggests that the snitching demo doesn't work just because of RLHF-persona-erosion), but I don't think it's obvious a priori. I think it would be valuable to have more experiments that try to disentangle which personality traits the scary demo reveals stem from the hints vs are "in the RLHF persona".
You don't get it for free, but I think it's reasonable to assume that P(concentrated power | US wins) is smaller than P(concentrated power | China wins) given that the latter is close to 1 (except if you are very doomy about power concentration, which I am not), right? I am not claiming the US is more reasonable, just that power concentration seems less likely if a Western democracy wins than if a one-party state does. It's also possible I am overestimating P(concentrated power | China wins); I am not an expert in Chinese politics.
With unlearning, you can iteratively refine until you achieve the desired behavior, then distill.
How? Because current unlearning is shallow, you don't know if it produces noise on the relevant pretraining tokens. So my guess is that iterating against unlearning by seeing if the model helps you build bioweapons is worse than also seeing if it produces noise on relevant pretraining tokens, and the latter gives roughly as much signal as iterating against a classifier by looking at whether it detects bioweapons advice and relevant pretraining tokens. This might be wrong if my conjecture explained in Addie's comment is wrong though.
Unlearn + distill requires significantly less labeled data than data filtering
I think you missed the point here. My suggested scheme is 1. label a small amount of data 2. train a classifier 3. apply the classifier to know if you should skip a token / make the target logprobs be noise or use the original logprobs. This is spiritually the same as 1. label a small amount of data 2. use that for unlearning 3. apply the unlearned model to know if the target logprobs should be noise or something close to the original logprobs.
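Step 3 of the classifier scheme can be sketched as follows. This is a minimal illustration under stated assumptions: "noise" is taken to mean a uniform distribution over the vocabulary, and the `flagged` list stands in for the output of the hypothetical classifier trained in step 2.

```python
import math
from typing import List


def distillation_targets(
    teacher_logprobs: List[List[float]],
    flagged: List[bool],
) -> List[List[float]]:
    """Per-position distillation targets: uniform logprobs (assumed form of
    "noise") where the classifier flags the token, teacher logprobs elsewhere."""
    vocab_size = len(teacher_logprobs[0])
    uniform = [-math.log(vocab_size)] * vocab_size
    return [
        list(uniform) if is_flagged else list(logprobs)
        for logprobs, is_flagged in zip(teacher_logprobs, flagged)
    ]


teacher = [
    [math.log(0.5), math.log(0.5)],
    [math.log(0.9), math.log(0.1)],
]
targets = distillation_targets(teacher, flagged=[False, True])
```

The "skip a token" variant would instead drop flagged positions from the loss entirely, which is the filtering case.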
Our robustness metric undersells the method's performance
I agree with that, and I'd bet that UNDO likely increases jailbreak robustness even in the 1%-of-pretrain-compute regime. But you did not run experiments that show the value of UNDO in the 1%-of-pretrain-compute regime, right?
Separately, while 30% compute for 50% robustness (compared to data filtering) isn't cheap, this tradeoff didn't exist before. The value add of UNDO over Unlearn-and-Distill is that it provides a tunable compute/robustness knob between the conventional unlearning and full reinitialization/data filtering
Fair, I also agree. This seems to be a part of the option space that nobody is interested in, but it's still scientifically interesting. But I think it's noteworthy that the results are so negative: if I had been asked to predict the results of UNDO, I would have predicted much stronger results.
Could you elaborate on what you mean about unlearning techniques during pretraining?
I mean (for GradDiff) train on Loss = [- NTP(tokens)] if tokens are harmful else NTP(tokens), instead of Loss = 0 if tokens are harmful else NTP(tokens).
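The two per-token losses being contrasted can be sketched like this. It is a toy illustration, assuming the per-token NTP losses and harmful labels are already available as plain lists; the function names are made up for the example.

```python
from typing import List


def graddiff_pretraining_loss(ntp_losses: List[float], harmful: List[bool]) -> float:
    """Mean loss with the sign flipped on harmful tokens, so training
    pushes NTP loss up on them instead of down."""
    signed = [-loss if is_harmful else loss
              for loss, is_harmful in zip(ntp_losses, harmful)]
    return sum(signed) / len(signed)


def filtering_loss(ntp_losses: List[float], harmful: List[bool]) -> float:
    """Mean loss with harmful tokens zeroed out, i.e. plain data filtering."""
    kept = [0.0 if is_harmful else loss
            for loss, is_harmful in zip(ntp_losses, harmful)]
    return sum(kept) / len(kept)


ntp_losses = [2.0, 1.0, 3.0]
harmful = [False, True, False]
graddiff_value = graddiff_pretraining_loss(ntp_losses, harmful)
filtering_value = filtering_loss(ntp_losses, harmful)
```

In practice an unbounded negative loss term usually needs clipping or weighting to stay stable; the sketch leaves that out.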
I don't think datafiltering+distillation is analogous to unlearning+distillation
I do think unlearning+distillation is conceptually analogous to datafiltering+pretraining
Ok I see your point. I wonder if this is true in practice for unlearning techniques like RMU though. My understanding of RMU is that the logprobs are roughly "noise if detect harmful else original", in which case filtering+distillation would be roughly the same as unlearning (except if training on noise is better than training on 0). I expect that for most tokens where an RMU base model does not produce noise, it would produce a tiny KL divergence with the original model, and to the extent that your RMU train set is bad enough that RMU "misclassified" some datapoints and does not produce noise on them, I expect that if the original model would have leaked information about those datapoints, RMU will leak them too. But that's an empirical question, and I am only ~60% sure that I am correct here. Did you run experiments to test this?
(The fact that you are using an RMU base model and using tokens from pretrain as opposed to tokens generated by the model itself matters a little bit here. I think you would get more robustness but also less distillation efficiency by fine-tuning on sequences generated by a model trained to refuse to talk about harmful topics. RMU = the noise thing + refusal, but you would not use refusal for a base model, and it would not help anyway if you used pretrain tokens because refusal is often just a few tokens deep.)
Clarifying what I mean by UNDO being underwhelming: you might have hoped that you could use 1% of pre-training compute to train a model with near-perfect robust forget accuracy. But if you want to get to even 50% robustness with UNDO you need to spend around 30% of the train-from-scratch compute, which is brutally expensive.
I really like this idea and the point about distillation making things adjacent to retraining from scratch tractable.
Contra what many people said here, I think it's not an obvious idea because things that look like "training from scratch" are somewhat far away from the option space of many people working on unlearning (including me!).
I don't understand why using an unlearning method like RMU and then doing distillation of that unlearned model is conceptually different than data filtering. My understanding is that RMU is well-approximated by "if looks like harmful, then produce noise else behave normally". Distilling that model is very similar to using the implicit "looks like harmful" classifier to filter out data. My understanding is that the main reason to use RMU instead of training a classifier is that you save 50% of your compute by using the same model both for classification and distillation, is that right? If this is really the main value add, my guess is that there are better strategies (e.g. training a value head for classification, or re-using the KV cache for classification).
To the extent that training models to produce noise / have high NTP loss on harmful tokens is load-bearing, I think this is another cool idea that can be applied independently. In other words, "using unlearning techniques like GradDiff/MaxEnt during pretraining" might be a really powerful technique.
I find the UNDO results underwhelming, but I like the idea. My guess is that there are ways you could use 1% of pre-training compute to train a model with near-perfect robust forget accuracy by being more targeted in where you add noise.
Cool work! I appreciate the list of negative results and that the paper focuses on metrics post-relearning (though a few relearning curves like in the unlearning distillation would have helped understand how much of this is because you didn't do enough relearning).
Two ways to understand this:
Some more minor comments:
I find it somewhat sus to have the retain loss go up that much on the left. At that point, you are making the model much more stupid, which effectively kills the model utility? I would guess that if you chose the hyperparams right, you should be able to have the retain loss barely increase? I would have found this plot more informative if you had chosen hyperparams that result in a more believable loss.
I appreciate that you report results on OOD validation sets. Is the relearning on WMDP or on pile-bio? The results seem surprisingly weak, right? Wouldn't you expect a lower accuracy for a model that gets a +0.3 loss on pile-bio? I think this fits into a broader trend where meta-learning-like approaches are relatively weak against attacks they did not explicitly meta-learn against. This is why I expect just scrubbing tons of stuff from the weights with random noise (like UNDO) to be more promising for high-stakes applications where you don't know the attack you'll need to defend against. But maybe this is still promising for the RL application you mentioned, since you understand the attack better? (But as described above, I don't really understand the RL hope.)
Nit: in the paper, the Appendix scatterplots have too many points which cause lag when browsing.