We discuss the possibility that causal confusion will be a significant alignment and/or capabilities limitation for current approaches based on “the scaling paradigm”: unsupervised offline training of increasingly large neural nets with empirical risk minimization on a large diverse dataset. In particular, this approach may produce a model which uses unreliable (“spurious”) correlations to make predictions, and so fails on “out-of-distribution” data taken from situations where these correlations don’t exist or are reversed. We argue that such failures are particularly likely to be problematic for alignment and/or safety when a system trained to do prediction is then applied in a control or decision-making setting. We cover:
- Arguments for this position
- Possible approaches to solving the problem
- Key cruxes for this position and possible fixes
- Practical implications for capability and alignment
- Relevant research directions
We believe this topic is important because many researchers seem to view scaling as a path toward AI systems that 1) are highly competent (e.g. human-level or superhuman), 2) understand human concepts, and 3) reason with human concepts.
We believe the issues we present here are likely to prevent (3), somewhat less likely to prevent (2), and even less likely to prevent (1) (but still likely enough to be worth considering). Note that (1) and (2) have to do with systems’ capabilities, and (3) with their alignment; thus this issue seems likely to be differentially bad from an alignment point of view.
Our goal in writing this document is to clearly elaborate our thoughts, attempt to correct what we believe may be common misunderstandings, and surface disagreements and topics for further discussion and research.
Introduction and Framing
GPT-3 and Scaling Laws (among other works) have made the case that scaling will be a key part of future transformative AI systems or AGI. Many people now believe that there’s a possibility of AGI happening in the next 5-10 years from simply scaling up current approaches dramatically (inevitably with a few tweaks, and probably added modalities, but importantly still performing the large bulk of training offline using ERM). If this is the case, then it’s more likely we can do useful empirical work on AI safety and alignment by focusing on these systems, and much current research effort (Anthropic, Redwood, Scalable Alignment@DeepMind, Safety@OpenAI) is focused on aligning systems primarily based on large language models (which are the current bleeding edge of scaled-up systems).
However, we think there is a potential flaw or limitation in this scaling approach. While this flaw perhaps isn’t apparent in current systems, it will likely become more apparent as these systems are deployed in wider and more autonomous settings. To sketch the argument (which will be made more concrete in the rest of this post): Current scaling systems are based on offline Empirical Risk Minimisation (ERM): using SGD to make a simple loss go as low as possible on a static dataset. ERM often leads to learning spurious correlations or causally confused models, which would result in bad performance Out-Of-Distribution (OOD). There are theoretical reasons for believing that this problem won’t be solved by using more data or more diverse data, making this a fundamental limitation of offline training with ERM. Since it may be practically difficult or infeasible to collect massive amounts of data “online” (i.e. from training AI systems in a deployment context), this may be a major limitation of the scaling paradigm. Finally, the OOD situations which would produce this bad performance are very likely to occur at deployment time through the model’s own actions/interventions in the environment, which the model hasn’t seen during training. It’s an open question (which we also discuss in this post) whether fine-tuning can work sufficiently well to fix the issues arising from the ERM-based pretraining approach.
We think resolving to what extent this is a fundamental flaw or limitation in the scaling approach is important for two main reasons. Firstly, from a forecasting perspective, knowing whether the current paradigm of large-scale but simply-trained models will get us to AGI or not is important for predicting when AGI will be developed, and what that resulting AGI will look like. Resolving the questions around fine-tuning is also important here: if fine-tuning can fix the problem, but only with large amounts of hard-to-collect data, then this makes AGI harder to develop. Secondly, from a technical alignment perspective, we want the alignment techniques we develop to work on the type of models that will be used for AGI, and so ensuring that our techniques work for large-scale models despite their deficiencies is important if we expect these large-scale models to still be used. These deficiencies will also likely affect which kinds of fine-tuning approaches work.
In this post we describe this problem in more detail, motivating it theoretically, as well as discussing the key cruxes for whether this is a real issue or not. We then discuss what this means for current scaling and alignment research, and what promising research directions this perspective informs.
Preliminaries: ERM, Offline vs. Online Learning, Scaling, Causal Confusion, and Out-of-Distribution (OOD) Generalisation
Here we define ERM, offline and online learning, scaling, causal confusion, and out-of-distribution generalisation.
Empirical Risk Minimization (ERM) is a nearly universal principle in machine learning. Given some risk (or loss function), ERM dictates that we should try to minimise this risk over the empirical distribution of data we have access to. This is often easy to do: running standard SGD on mini-batches drawn iid from the data will produce an approximate ERM solution.
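As a minimal sketch of what “approximate ERM via minibatch SGD” means in practice, the code below assumes a linear model with squared loss; the model, loss and hyperparameters are illustrative choices, not anything specific to large-scale pretraining. The point is that the optimiser only ever sees the fixed dataset, sampled iid in minibatches.

```python
import numpy as np

def erm_sgd(X, y, lr=0.1, batch_size=32, epochs=200, seed=0):
    """Approximately minimise the empirical risk (mean squared error of a
    linear model over the fixed dataset X, y) with minibatch SGD."""
    rng = np.random.default_rng(seed)
    n, d = X.shape
    w = np.zeros(d)
    steps = epochs * max(1, n // batch_size)
    for _ in range(steps):
        idx = rng.integers(0, n, size=batch_size)  # minibatch drawn iid from the data
        Xb, yb = X[idx], y[idx]
        grad = 2.0 * Xb.T @ (Xb @ w - yb) / batch_size  # gradient of the minibatch MSE
        w -= lr * grad
    return w

def empirical_risk(w, X, y):
    """Average loss over the empirical (training) distribution."""
    return float(np.mean((X @ w - y) ** 2))
```

For example, on data generated as y = 2·x1 - x2 plus small noise, `erm_sgd` returns weights close to [2, -1]. Nothing in this procedure asks whether the learned weights reflect causes or merely correlates.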
Offline learning, roughly speaking, refers to training an AI system on a fixed data set. In contrast, in online learning data is collected during the learning process, typically in a manner informed or influenced by the learning system. This potential for interactivity makes online learning more powerful, but it can also be dangerous (e.g. because the system’s behaviour may change unexpectedly) and impractical (e.g. since offline data can be much easier/cheaper to collect). Offline learning is currently much more popular in research and in practice.
Scaling is a less well-defined term, but in this post, we mean something like: Increasing amounts of (static) data, compute and model size can lead to (a strong foundation for) generally competent AI using only simple (ERM-based) losses and offline training algorithms (e.g. minimising cross-entropy loss of next token with SGD in language modelling on a large static corpus of text). This idea is based on work finding smooth scaling laws between model size, training time, dataset size and loss; the scaling hypothesis, which simply states that these laws will continue to hold even in regimes where we haven’t yet tested them; and the fact that large models trained in this simple way produce impressive capabilities now, and hence an obvious recipe for increasing capabilities is just to scale up on the axes described in the scaling laws.
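To make this recipe concrete, the offline ERM objective for language modelling and the power-law shape of the scaling laws can be written roughly as follows. This is a sketch: $\mathcal{D}$ is the static corpus, $p_\theta$ the model, $N$ the parameter count and $D$ the dataset size. The constants $N_c$, $D_c$, $\alpha_N$, $\alpha_D$ are fitted empirically, and we leave them symbolic rather than quoting values; each law describes the regime where the other factor is not the bottleneck.

$$\hat{\theta} = \arg\min_{\theta} \frac{1}{|\mathcal{D}|} \sum_{x \in \mathcal{D}} \sum_{t} -\log p_{\theta}(x_t \mid x_{<t})$$

$$L(N) \approx \left(\frac{N_c}{N}\right)^{\alpha_N}, \qquad L(D) \approx \left(\frac{D_c}{D}\right)^{\alpha_D}$$

These curves are measured on held-out data from (roughly) the training distribution, so on their own they do not speak to the concept-shift setting discussed below.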
Of course, even researchers who are fervent believers in scaling don’t think that a very large model, without any further training, will be generally competent and aligned. The standard next step after pretraining a large-scale model is to perform a much smaller amount of fine-tuning based on either task-specific labelled data (in a supervised learning setting), or some form of learning from human preferences over model outputs (often in an RL setting). This is often combined with prompting the model (specifically in the case of language models, but possibly applicable in other scenarios). Methods of prompting and fine-tuning have improved rapidly in the last year or so, but it’s unclear whether such improvements can solve the underlying problems this paradigm may face.
Causal confusion is a possible property of models, whereby the model is confused as to which parts of the environment cause other parts. For example, suppose that whenever the weather is sunny, I wear shorts and also buy ice cream. If the model doesn’t observe the weather, but just my clothing choice and purchases, it might believe that wearing shorts causes me to buy ice cream, which would mean it would make incorrect predictions if I wore shorts for reasons other than the weather (e.g. I ran out of trousers, or I was planning to exercise). This is particularly likely to happen when not all parts of the environment can be observed by the model (i.e. if the causal factors, like the weather in the previous example, aren’t observed). It’s also likely to occur if a model only ever observes the environment without acting in it. In the previous example, if the model was just trying to predict whether I would buy ice cream, it would do a pretty good job by looking at my clothing choice (although it would occasionally be incorrect). However, if the model was acting in the environment with the goal of making me buy ice cream, suggesting I wear shorts would be entirely ineffective.
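A small simulation makes the gap between prediction and intervention concrete. This is a hypothetical structural model, not data: the weather drives both shorts and ice cream, shorts have no causal effect on ice cream, and the probabilities (0.9/0.1 for shorts, 0.8/0.1 for ice cream) are made up for illustration.

```python
import numpy as np

def simulate(n, p_sun, do_shorts=None, seed=0):
    """Toy structural causal model: sun -> shorts, sun -> ice cream.
    Shorts have no causal effect on ice cream (hypothetical numbers).
    If do_shorts is given, shorts are set by intervention, not by the weather."""
    rng = np.random.default_rng(seed)
    sun = rng.random(n) < p_sun
    if do_shorts is None:
        shorts = rng.random(n) < np.where(sun, 0.9, 0.1)
    else:
        shorts = np.full(n, do_shorts)
    ice_cream = rng.random(n) < np.where(sun, 0.8, 0.1)  # depends only on the sun
    return shorts, ice_cream

def p_ice_given_shorts(p_sun, n=200_000, seed=0):
    """Observational P(ice cream | shorts): what a purely predictive model learns."""
    shorts, ice = simulate(n, p_sun, seed=seed)
    return ice[shorts].mean()

def p_ice_do_shorts(p_sun, n=200_000, seed=0):
    """Interventional P(ice cream | do(shorts)): what matters for decisions."""
    _, ice = simulate(n, p_sun, do_shorts=True, seed=seed)
    return ice.mean()
```

With `p_sun=0.5`, the observational estimate comes out near 0.73, while the interventional one is near 0.45 (the base rate of ice cream), so a controller that puts shorts on someone to get them to buy ice cream gains nothing.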
If a model is causally confused, this can have several consequences.
1) The model may not make competent predictions out-of-distribution (capabilities misgeneralisation). We discuss this further in the section “ERM leads to causally confused models that are flawed OOD”.
2) If the model is causally confused about objects related to its goals or incentives, then it might competently pursue changes in the environment that don’t actually optimise the reward function used for training (goal misgeneralisation).
3) Another issue is incentive mismanagement: Krueger et al. (HI-ADS) show that causal confusion can lead models to optimise over what Farquhar et al. subsequently define as “delicate” parts of the state, which they are not meant to optimise over, yielding higher rewards via undesirable means.
Further, if a model suddenly becomes deconfused during training or fine-tuning, it’s likely to exhibit a sudden leap in competence and generality, as it can now perform well in a much wider range of situations. This is relevant from a forecasting/timelines perspective: if current language models (for example) are partially causally confused and this limitation is addressed (e.g. via online fine-tuning or some other fix), this could lead to a sudden increase in language model capabilities. On the other hand, it could be that fine-tuning is unlikely to solve issues of causal confusion.
Out-of-distribution (OOD) generalisation is a model’s ability to perform well on data that is not drawn from the training distribution. Historically, most work in machine learning has focused on IID generalisation: generalising to new examples from the same training distribution. There has been recent interest in tackling OOD generalisation challenges, although the field has struggled to settle on a satisfactory formal definition of the problem, and has had difficulty ensuring that results are robust. OOD generalisation is closely related to causal confusion: causal confusion is one possible reason why models fail to generalise OOD (to situations where their causal model is no longer correct), and we can often only demonstrate causal confusion in OOD settings (as otherwise the spurious correlations the model learned during training will continue to hold).
Stating the Case
The argument for why scaling may be flawed comes in two parts. The first is a more theoretical (and more mathematically rigorous) argument that ERM is flawed in certain OOD settings, even with large amounts of diverse data, as it leads to causally confused models. The second part builds on this point, arguing that it applies to current approaches to scaling (due to models trained with offline prediction being used for online interaction and control, leading to OOD settings).
ERM leads to causally confused models that are flawed OOD
As defined above, OOD generalisation concerns performance on data not drawn from the training distribution. Different distributions are sometimes called different domains or environments, because these differences are assumed to result from the data being collected under different conditions. One might suspect that OOD generalisation can be tackled within the scaling paradigm by using sufficiently diverse training data, for example by including data sampled from every possible test environment. Here, we present a simple argument that this is not the case, loosely adapted from Remark 1 of Krueger et al. (REx):
The reason data diversity isn’t enough comes down to concept shift: a change in $P(Y \mid X)$ across environments. Such changes can be induced by changes in the distribution of unobserved causal factors, $Z$. Returning to the ice cream ($Y$), shorts ($X$) and sun ($Z$) example, shorts are a very reliable predictor of ice cream when it is sunny, but not otherwise: $P(Y=1 \mid X=1, Z=1)$ is high, while $P(Y=1 \mid X=1, Z=0)$ is much lower. Since the model doesn’t observe $Z$, there is no single setting of $P(Y \mid X)$ that will work reliably across environments with different climates (different $P(Z)$). Instead,

$$P(Y \mid X) = \sum_{z} P(Y \mid X, Z=z)\, P(Z=z \mid X)$$

depends on $P(Z \mid X)$, and hence on $P(Z)$, which in turn depends on the climate in the locations where the data was collected. In this setting, to ensure a model trained with ERM can make good predictions in a new “target” location, you would have to ensure that the target location is as sunny as the average training location, so that $P(Z)$ is the same at training and test time. It is not enough to include data from the target location in the training set, even in the limit of infinite training data: including data from other locations changes the overall $P(Z)$ of the training distribution. This means that without domain/environment labels (which would allow you to learn a different $P(Y \mid X)$ for each environment, even if you can’t observe $Z$), ERM can never learn a non-causally confused model.
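The argument can be checked numerically. The sketch below uses hypothetical parameters and assumes all locations share the same mechanisms $P(X \mid Z)$ and $P(Y \mid Z)$ and differ only in $P(Z)$. Under that assumption, pooling locations is equivalent to a single location whose sun rate is the weighted average, which is why adding target-location data does not recover the target’s $P(Y \mid X)$.

```python
def p_y_given_x1(p_sun, p_shorts=(0.1, 0.9), p_ice=(0.1, 0.8)):
    """Exact P(Y=1 | X=1) in one location with sun probability p_sun.
    p_shorts[z] = P(X=1 | Z=z) and p_ice[z] = P(Y=1 | Z=z) (hypothetical values)."""
    joint = [(1 - p_sun) * p_shorts[0], p_sun * p_shorts[1]]  # P(Z=z, X=1)
    return (joint[0] * p_ice[0] + joint[1] * p_ice[1]) / (joint[0] + joint[1])

def pooled_p_y_given_x1(env_p_sun, env_weights, **kw):
    """P(Y=1 | X=1) under the pooled training distribution, i.e. what ERM fits.
    Since only P(Z) differs between locations, the pooled data behaves like one
    location whose sun rate is the weighted average of the locations' rates."""
    p_sun = sum(w * p for w, p in zip(env_weights, env_p_sun)) / sum(env_weights)
    return p_y_given_x1(p_sun, **kw)
```

For a sunny target location (`p_sun=0.9`) the correct value is about 0.79, but pooling it with locations at 0.2 and 0.5 (equal weights) gives about 0.74, however many samples each location contributes. The two agree only if the mixture’s average sun rate matches the target’s.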
Note that there is, however, still a correct causal model for how wearing shorts affects your desire for ice cream: the effect is probably weak and plausibly even negative (since you might want ice cream less if you are already cooler from wearing shorts). While this model might not make very good predictions, it will correctly predict that getting you to put on shorts is not an effective way of getting you to want ice cream, and thus will be a more reliable guide for decision-making (about whether to wear shorts). There are some approaches for learning causally correct models in machine learning, but this is considered a significant unsolved problem and is a focus of research for luminaries such as Yoshua Bengio, who views this as a key limitation of current deep learning approaches, and a necessary step towards AGI.
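For reference, if $Z$ were observed, the standard backdoor adjustment from causal inference would give the decision-relevant quantity. It differs from the observational conditional only in how the strata of $Z$ are weighted:

$$P(Y \mid X=x) = \sum_z P(Y \mid X=x, Z=z)\, P(Z=z \mid X=x)$$

$$P(Y \mid \mathrm{do}(X=x)) = \sum_z P(Y \mid X=x, Z=z)\, P(Z=z)$$

The first weights by $P(Z \mid X)$, so it credits shorts with the sunshine that tends to accompany them; the second weights by $P(Z)$, which breaks that link. In the simplified case where ice cream depends on the weather but not on shorts, $P(Y \mid X=x, Z=z) = P(Y \mid Z=z)$, and the second expression is the same for $x=0$ and $x=1$: putting on shorts does nothing. This requires observing $Z$ (or some other valid adjustment set), which is exactly what the offline setting above lacks.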
To summarise, this argument gives us three points:
(a) More data isn’t useful if it’s from the same or similar distributions over domains: we need to distributionally match the deployment domain(s) (match $P(Z)$), or have domain labels (make $Z$ observed).
(b) This happens when there are unobserved confounding variables, or more generally partial observability, meaning that we can’t achieve 0 training loss (which could imply a perfect causal model). We assume this will be the case for the kind of large-scale offline pretraining used for foundation models.
(c) The above two points combine to imply that ERM-trained models will fail in OOD settings due to being causally confused, in particular under concept shift (when $P(Y \mid X)$ changes).
Scaling is hence flawed OOD
On top of the argument above, we need several additional claims and points to argue that current approaches to scaling could be unsafe and/or incompetent:
(i) Current scaling approaches use simple ERM losses, on a large diverse data set.
(ii) While scaling produces models trained on static data, these models will be used in interactive and control settings.
(Note here that (i, b =>) means points (i) and (b) from above imply this point, and similarly for the rest of the list.)
1) (i, b =>) Zero training loss on the training data (and hence possibly a perfect causal model) isn’t achievable with the loss functions used, and there are spurious correlations and the potential for causal confusion in this data.
2) (1 & a, b =>) Scaling will produce models that capture and utilise these spurious correlations for lower loss, and so are causally confused.
3) (ii =>) At deployment time, these models will be used in environments not seen during training, and their actions/interventions can easily lead to OOD situations. These shifts may be incentivised at both training and deployment time, and may be difficult to spot due to a misspecification problem (see Hidden Incentives for Auto-Induced Distributional Shift).
4) (2, 3 & c =>) These models’ representations will be causally confused and misleading in many (OOD) settings at deployment time, leading the models to fail to generalise OOD. This could be a generalisation failure of capabilities or of objectives, depending on the type of shift that occurs. For example, objective misgeneralisation could occur if the internal representation of the goal the agent is optimising for is causally confused (e.g. a proxy of the true goal), and so comes apart from the true goal under distributional shift.
Possible objections to the argument
Here we deal with possible disagreements with the argument as it stands; we cover its implications and relevance below. If the argument above is valid, then any disagreement must come from rejecting some of its premises. Some likely places people will disagree:
- While in theory, ERM will result in a model utilising spurious correlations to get lower training loss, in practice this won’t be a big issue. This position probably stems from an intuition that the causally correct model is the best model, and so if we expect large-scale models to get almost 0 training error, then it’s likely these large-scale models will have found this causally correct model. This effectively disagrees with claims 1 and 2 (that there are spurious correlations in static training data that ERM-SGD will exploit for lower loss).
- There’s possibly a less well-argued position that Deep Learning is kind of magical, and hence these issues will probably just disappear with more data. For instance, you might expect something like this to happen for the same sorts of reasons you might expect a model trained offline with SGD+ERM to spawn an inner optimizer that behaves as if it has goals with respect to the outside world: at some level of complexity/intelligence, learning a good causal model of the world might be the best way of quickly/easily/simply explaining the data. We do not dismiss such views but note that they are speculative.
- Disagree with claim 3: Some people might think that we’ll have a wide enough data distribution that the model won’t encounter OOD situations at deployment time. To us, this seems unlikely, especially in the limit where we use AIs to do tasks that we ourselves can’t perform.
- We’ll be able to fine-tune in the test environment, so we won’t experience OOD at deployment; while changes will happen, continual fine-tuning will be good enough to stop the model from ever being truly OOD.
- We think this may apply in settings where we’re using the model for prediction, but it’s unclear whether continual fine-tuning will be able to help models learn and adapt to the rapid OOD shifts that could occur when the models are transferred from offline learning to online interaction at deployment.
As with some other alignment problems, it could be argued that this is more of a capabilities issue than an alignment issue, and hence that mainstream ML is likely to solve it (if it can be solved). We think it’s still important to discuss whether (and how) the issue can be solved. Suppose we accept that mainstream ML will solve it. To argue that we can still do useful empirical alignment research on large language models, we’d need that solution not to change these models so much that the alignment research won’t generalise. Furthermore, many powerful capabilities might be accessible without proper causal understanding, and mainstream ML research might focus on developing those capabilities instead.
How might we fix this?
If the issue described above is real, then it’s likely it will need to be solved if we are to build aligned AGI (or perhaps AGI at all). Here we describe several possible solutions, ranging from obvious but probably not good enough to more speculative:
- We could just find the right data distribution. This seems intractable for current practice, especially when considering tasks that humans haven’t demonstrated frequently or ever in the pretraining data.
- We could use something like invariant prediction or a domain generalisation approach, which are methods from supervised learning aimed at tackling OOD generalisation. However, this requires knowledge of the “domains” in the training data, which might be hard to come by. Further, it’s also unclear whether these methods really work in practice; and even in principle, these methods are likely insufficient for addressing causal confusion in general, which is a harder problem.
- We could use online training: continually updating our model with new data collected as it interacts with its deployment environment, to compensate for the distribution shift. For a system with general capabilities, this would likely mean training in open-ended real-world environments. This seems dangerous: if a shift happens quickly (e.g. due to the agent’s own actions) and then causes a catastrophic outcome, we won’t have time to update our model with new data. Further, this approach might require much greater sample efficiency than is currently available, because it would be bottlenecked by the speed of the model’s deployment environment.
- We could do online fine-tuning, for example, RL from human preferences. This approach has received attention recently, but it’s currently unknown to what extent it can address the causal confusion problem. For this to work, fine-tuning will likely have to override and correct spurious correlations in the pretrained model’s representations. This seems like the biggest open question in this argument: is it possible for fine-tuning to fix a pretrained model’s representations, and if so, then how? Is it possible with small enough data requirements that human feedback is feasible?
- We could extract causal understanding from the model via its natural language capabilities. That is, perhaps the pretrained model has “read” about causality in its pretraining data, and if correctly prompted can generate causally correct completions or data. This could then be injected back into the model (e.g. via fine-tuning) or used in some hybrid system that combines the pretrained model with a causal inference component, so as not to rely on the model itself to do correct causal reasoning when unprompted in its internal computations (by analogy, see this paper on eliciting reasoning through prompting). This approach seems potentially viable but is so far very speculative, and more research is needed.
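As a rough illustration of the invariant-prediction / domain-generalisation option, here is a sketch of the variance-penalised objective (V-REx) from Krueger et al.’s REx paper, written as a NumPy function of per-example losses rather than as a full training loop. The function names are our own. Note that the objective needs the losses grouped by environment: this is the “knowledge of the domains” requirement mentioned above.

```python
import numpy as np

def erm_objective(per_env_losses):
    """Plain ERM on the pooled data: average loss over all examples,
    ignoring which environment each example came from."""
    return float(np.mean(np.concatenate(per_env_losses)))

def vrex_objective(per_env_losses, beta):
    """V-REx-style objective: sum of per-environment risks plus beta times
    the variance of those risks.

    per_env_losses: list of 1-D arrays of per-example losses, one per
    training environment (requires environment labels)."""
    risks = np.array([np.mean(losses) for losses in per_env_losses])
    return float(risks.sum() + beta * risks.var())
```

For two environments with mean risks 1 and 3, pooled ERM gives 2, while V-REx with `beta=10` gives 4 + 10·1 = 14; a large `beta` pushes the optimiser towards solutions whose risk is equal across environments, rather than ones that do well on average by exploiting environment-specific correlations.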
We can extract several key cruxes from the arguments for and against our position here, and for whether this issue merits concern. These partly pertain to specific future training scenarios, and hence can’t be resolved entirely now (e.g. 1). They’re also determined by (currently) unknown facts about how large-scale pretraining and fine-tuning work, which we hope will be resolved through future work on the following questions:
(1) Are there spurious correlations in the training data?
(2) Will spurious correlations be picked up during large-scale ERM pretraining?
(3) Will deploying these models in interaction and control settings lead to changes too sudden to be handled by continual learning on new data? More generally, how well would continual learning on new data work?
(4) Can fine-tuning correct or override spurious correlations in pretrained representations?
To us, it seems like (4) is the most important (and most uncertain) crux, as we currently feel like (1), (2) and (3) will probably hold true to an extent that makes fine-tuning essential to achieving sufficient competence and alignment when deploying models trained offline with ERM in decision-making contexts.
Of course, this is all a matter of degrees:
- Pretraining will likely pick up some spurious correlations, and some of them may be removed by fine-tuning, and deployed models will change the world to some extent.
- The key question is whether the problems arising from deployed models changing the world, combined with the spurious correlations learned in pretraining, can be counteracted by fine-tuning well enough to avoid OOD generalisation failure, especially in a way that leads to objective robustness and alignment.
There are two strands of relevant implications if this flaw turns out to exist in practice. One strand concerns the implications for capabilities of AI systems and the other concerns alignment research.
In terms of capabilities (and forecasting their increase), this flaw would be relevant as follows: If scaling up doesn’t work, then standard estimates for various AI capabilities which are based on scaling up no longer apply - in effect, everything is more uncertain. This also suggests that current large language models won’t be as useful or pervasive - if they start behaving badly in OOD situations, and we can’t find methods to fix this, then they’ll likely be used less.
In terms of alignment research, it mostly just makes everything less certain. If we’re less certain that scaling up will get to AGI, then we’re less certain in prosaic alignment methods generally. If it’s the case that these issues apply differently to different capabilities or properties (i.e. different properties generalise better or worse OOD), then understanding whether properties related to the representation of the model’s goal and human values generalise correctly is important.
Foundation models might be very powerful and transformative and widely deployed even if they do suffer such flaws. Adversarial examples indicate flaws in models’ representations, and remain an outstanding problem but do not seem likely to prevent deployment. A lack of recognition of this problem might lead to undue optimism about alignment methods that rely on the generalisation abilities of deep learning systems, especially given their impressive in-distribution generalisation abilities.
We think this line of reasoning implies two main directions for further research. First, it’s important to clarify whether this argument actually holds in practice: are large-scale pretrained models causally confused, in what way, and what are the consequences? Do they become less confused with scale? Answering this requires both empirical work and theoretical work investigating to what extent the mathematical arguments about ERM apply to current approaches to large-scale pretraining.
Secondly, under the assumption that this issue is real, we need to build solutions to fix it. In particular, we’re excited about work investigating to what extent fine-tuning addresses these issues, and work building methods designed to fix causal confusion in pretrained models, possibly through fine-tuning or causal knowledge extraction.
Appendix: Related Work
Causality, Transformative AI and alignment - part I discusses how causality (i.e. learning causal models of the world) is likely to be an important part of transformative AI (TAI), and discusses relevant considerations from an alignment perspective. Causality is closely related to the issues with ERM, as ERM doesn’t necessarily produce a causal model if it can exploit spurious correlations for lower loss. That work mostly argues that causality is under-appreciated in alignment research, but doesn’t make the specific arguments in this post. Some of its suggested research directions do overlap with ours.
Shaking the foundations: delusions in sequence models for interaction and control and Behavior Cloning is Miscalibrated both point out the issue of causally confused models arising from static pretraining when these models are then deployed in an interactive setting. Focusing on the first work, consider delusions in large language models (LLMs): an incorrect generation early on throws the LLM off the rails later. The LLM assumes its (incorrect) generation came from the expert/human generating the text, as that's the setting it was trained in, and hence deludes itself. In these settings, the human generating the text has access to much more information than the model, which makes generation harder for the model and delusions more likely: an incorrect generation makes it more likely that the model infers the task or context incorrectly.
The work explains this problem using tools from causality and argues that these models should treat their previous actions as causal interventions rather than observations. However, training a model in this way requires access to a model of the environment and to the expert demonstrating trajectories in an online way, and the authors don't describe a way to do this with purely offline data (it may be fundamentally impossible). The phenomenon of delusions is a perfect example of causal confusion arising from static offline pretraining. Together with the theory showing that even training on all the correct data isn’t sufficient when using ERM, it serves as additional motivation for our argument.
The second work makes similar points to the first, framed in terms of models being miscalibrated when trained on offline data. It goes on to discuss possible solutions, such as combining offline pretraining with online RL fine-tuning, possibly with a penalty that keeps the model close to human behaviour (to avoid Goodharting) except in scenarios where it’s fixing calibration errors.
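A toy Bayesian version of the delusion argument: a hidden task θ determines how likely the expert is to emit token 1, and the model conditions on a prefix that may include tokens it generated itself. The likelihoods (0.9 vs 0.1), the uniform prior and the function name are hypothetical. Treating its own tokens as observations lets the model talk itself into a confident belief about θ; treating them as interventions, as the paper recommends, leaves the belief unchanged.

```python
def posterior_theta1(tokens, own, treat_own_as_intervention, prior=0.5, p_one=(0.1, 0.9)):
    """P(theta=1 | context) for a toy expert with P(token=1 | theta) = p_one[theta].

    tokens: sequence of 0/1 tokens in the context.
    own:    parallel sequence of bools, True if the model generated that token.
    """
    like1, like0 = prior, 1.0 - prior
    for tok, is_own in zip(tokens, own):
        if is_own and treat_own_as_intervention:
            continue  # do(token): the model's own action says nothing about the expert's task
        like1 *= p_one[1] if tok == 1 else 1 - p_one[1]
        like0 *= p_one[0] if tok == 1 else 1 - p_one[0]
    return like1 / (like1 + like0)
```

If the model itself produced the prefix [1, 1, 1], treating those tokens as evidence drives its belief that θ = 1 to about 0.999, whereas treating them as interventions leaves it at the prior of 0.5.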
Performative Prediction is the phenomenon (introduced in the paper of the same name) where a classifier trained on a training distribution, when deployed, induces a change in the distribution at deployment time (due to the effects of its predictions). For example (quoting from the paper): “A bank might estimate that a loan applicant has an elevated risk of default, and will act on it by assigning a high interest rate. In a self-fulfilling prophecy, the high interest rate further increases the customer's default risk.” This is an example of causal confusion arising from modelling a model’s outputs as mere predictions rather than as actions or interventions, and it relates to the danger of learning purely predictive models on static data and then using them in settings where their predictions are actually actions/interventions.
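The bank example can be turned into a toy loop showing how a deployed predictor shifts its own data distribution. The linear feedback model (true risk = base risk + feedback × predicted risk), its numbers and the function name are hypothetical; “retraining” here simply sets the next prediction to the risk observed under the current one.

```python
def repeated_retraining(base_risk, feedback, p0, rounds=50):
    """Toy performative-prediction loop (hypothetical linear feedback).

    Deploying a predicted default risk p_hat changes the true risk to
    base_risk + feedback * p_hat (e.g. via the interest rate it triggers).
    Each round, the model is 'retrained' to match the risk it just induced."""
    p_hat = p0
    history = [p_hat]
    for _ in range(rounds):
        true_risk = min(1.0, max(0.0, base_risk + feedback * p_hat))  # keep it a probability
        p_hat = true_risk  # retrain on data generated under the deployed prediction
        history.append(p_hat)
    return history
```

With `base_risk=0.1` and `feedback=0.5`, a first prediction of 0.1 raises the realised risk to 0.15, and repeated retraining converges to 0.2 = 0.1 / (1 - 0.5): a prediction that is only correct because it was deployed. A model trained once on pre-deployment data would keep predicting 0.1.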
Causal Confusion in Imitation Learning introduces the term causal confusion (although similar issues have been described in the past) in the context of imitation learning (which language model pretraining can be viewed as a version of). As described above, this is when a model learns an incorrect causal model of the world (i.e. causal misidentification), which then leads it to act in a confused manner.
They suggest an approach for solving the problem, assuming access to the variables (nodes) of the causal graph but not its edges: first use (offline) imitation learning to learn a policy conditioned on causal graphs, training this policy on many different possible causal graphs. Then perform targeted interventions, either by consulting an expert or by executing the policy in the environment, to identify the correct causal graph, and condition the policy on it. This method relies on having the right causal variables in the first place, which seems infeasible in the more general case, but it may still provide intuition or inspiration for other approaches to tackling this problem.
A Study of Causal Confusion in Preference-Based Reward Learning investigates whether preference-based reward learning can result in causally confused preference models (or at least, preference models that pick up on spurious correlations). Unsurprisingly, in the regime they investigate (limited preference data, learning reward models for robotic continuous control tasks), the preference models don’t produce good behaviour when optimised by a policy, which the authors take as evidence that they’re causally confused. Given the setting, it doesn’t make sense to extrapolate these results to other settings; it seems more likely that they can be explained by “learning a hard task (preference modelling) on limited amounts of static data without pretraining leads to models that don’t generalise well out of distribution”, rather than by any specific statement about preference modelling. Of course, preference modelling is a task where OOD inputs are almost guaranteed to occur: when we use the preference model to train a policy, it’s likely (and even desired) that the policy will produce behaviour not seen by the preference model during training (otherwise we could just do imitation on the preference model training data).
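To connect this to the ERM argument: preference models are typically fit by minimising a Bradley-Terry style loss on a static dataset of comparisons, the usual choice in RL from human preferences. The data, features and weights below are hypothetical. If a spurious cue always co-occurs with what humans actually care about, a reward model keyed on the cue achieves exactly the same training loss, and the two only come apart on out-of-distribution pairs, such as those a policy produces while optimising the reward.

```python
import numpy as np

def bradley_terry_loss(r_preferred, r_other):
    """Mean -log P(preferred beats other), with P = sigmoid(r_preferred - r_other)."""
    diff = np.asarray(r_preferred, dtype=float) - np.asarray(r_other, dtype=float)
    return float(np.mean(np.logaddexp(0.0, -diff)))  # stable form of -log(sigmoid(diff))

def linear_reward(features, w):
    """Hypothetical linear reward model over hand-made features."""
    return np.asarray(features, dtype=float) @ w

# Hypothetical comparisons with features [task_success, spurious_cue]:
# in every training pair, the preferred trajectory has both features set.
preferred = np.tile([1.0, 1.0], (100, 1))
other = np.tile([0.0, 0.0], (100, 1))

w_causal = np.array([3.0, 0.0])    # rewards what humans actually care about
w_spurious = np.array([0.0, 3.0])  # rewards the correlated cue instead

loss_causal = bradley_terry_loss(linear_reward(preferred, w_causal), linear_reward(other, w_causal))
loss_spurious = bradley_terry_loss(linear_reward(preferred, w_spurious), linear_reward(other, w_spurious))

# An out-of-distribution pair on which the two features come apart.
success_only = np.array([1.0, 0.0])
cue_only = np.array([0.0, 1.0])
```

Both reward models reach the same training loss, so ERM has no reason to prefer one over the other, yet on the new pair they rank the trajectories in opposite orders.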
...although it is a bit of a spectrum and our impression is that it is common to retrain models regularly on data that has been influenced by previous versions of the model in deployment.