Review
This is really cool. Great followup work!
I think this is enough to make a hypothesis on how the network works and how the goal misgeneralization happens:
- Somewhere inside the model, there is a set of individual components that respond to different inputs, and when they activate, they push for a particular action. Channel 121 is an example of such a component.
- The last layers somehow aggregate information from all of the individual components.
- Components sometimes activate for the action that leads to the cheese and sometimes for the action that leads to the top right corner.[9]
- If the aggregated "push" for the action leading to the cheese is higher than for the action leading to the top right corner, the mouse goes to the cheese. Otherwise, it goes to the top right corner.
I think this is basically a shard theory picture/framing of how the network works: Inside the model there are multiple motivational circuits ("shards") which are contextually activated (i.e. step 3) and whose outputs are aggregated into a final decision (i.e. step 4).
Thanks! Indeed, shard theory fits here pretty well. I didn't think about that while writing the post.
TLDR: I am trying to understand how goal misgeneralization happens in the same maze-solving network TurnTrout et al. work on. Nothing groundbreaking, but if we are ever to fully understand this model, this is probably an important step. Key findings:
These results suggest that goal misgeneralization may be localizable to specific channels that are not robust to out-of-distribution mazes.
This is my capstone project, created during the last edition of the ARENA program. I want to thank @Joseph Bloom and @Paul Colognese for mentorship, @rusheb and @TheMcDouglas for remarks on the draft of this post, many ARENA teachers and participants for fruitful discussions, and authors of procgen-tools for an excellent toolset.
Introduction
Lauro Langosco et al. trained a maze-solving network, a mouse looking for cheese. During the training, the cheese was always in the top right 5x5 part of the maze. When deployed in an environment where the cheese could be anywhere, the mouse sometimes goes to the cheese, sometimes to the top right corner, and sometimes (although very rarely) gets stuck in some unexpected part of the maze.
This model has already appeared on LW, in the article about the cheese vector by TurnTrout et al. and a follow-up post about the top-right corner vector. I recommend looking at the first one - it has a good description of the model (and I will not repeat it here) and a lot of pretty visuals. I don't directly build on their work, but I used tools they developed while working on it. This toolset is excellent - without it, my research would have been incomparably more challenging and time-consuming (it probably wouldn't have happened at all).
The goal of my research can be summarized as "understand why the mouse sometimes decides to go to the cheese and sometimes to the top right corner". My code is here. All the data in this article was generated using scripts in this repository.
Methods
A first step towards building an understanding of a complex model is usually to split it into smaller parts that will be easy to understand separately. There are many different ways to split a neural network into smaller parts. Goal misgeneralization, in this case, is defined as "mouse going to the top right corner instead of the cheese" - it makes sense to start by investigating which parts of the network are responsible for deciding whether the mouse goes to the cheese or not.
One method of finding parts of the model which encode specific features is to create pairs of inputs where one feature of the environment varies, run a forward pass on each and look for differences between the activations. Parts of the network where activations are similar don't matter; parts where activations change a lot are somehow related to the thing you are looking for.
I created a pair of mazes that differ only by a single wall position, but this wall position is crucial for the action that leads to the cheese:
In the first maze, the mouse should go up to get the cheese; in the second maze - right. If a part of the network has the same activations for both mazes, it is not related to the distinction between "go up to get the cheese" and "go right to get the cheese". On the other hand, if the difference in activations is high, we might suspect this part somehow carries the information we are interested in.
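The paired-input comparison can be sketched as below. This is a minimal illustration, not the code from the repository: plain Python lists stand in for the activations of one layer, the helper names (`channel_sums`, `rank_channels_by_difference`) are hypothetical, the toy numbers are made up, and comparing per-channel sums rather than individual activations is a simplifying assumption.

```python
from typing import List, Tuple

Channel = List[List[float]]  # one H x W activation map
Layer = List[Channel]        # C channels of a single layer


def channel_sums(layer: Layer) -> List[float]:
    """Sum every channel over its spatial positions."""
    return [sum(sum(row) for row in channel) for channel in layer]


def rank_channels_by_difference(acts_a: Layer, acts_b: Layer) -> List[Tuple[int, float]]:
    """Return (channel index, sum_a - sum_b), largest absolute difference first."""
    sums_a = channel_sums(acts_a)
    sums_b = channel_sums(acts_b)
    diffs = [(i, a - b) for i, (a, b) in enumerate(zip(sums_a, sums_b))]
    return sorted(diffs, key=lambda item: abs(item[1]), reverse=True)


# Made-up activations for the two mazes: 3 channels, each 2 x 2.
maze_up = [
    [[0.0, 1.0], [0.0, 0.0]],
    [[5.0, 4.0], [3.0, 0.0]],
    [[1.0, 1.0], [1.0, 1.0]],
]
maze_right = [
    [[0.0, 1.0], [0.0, 0.5]],
    [[0.0, 0.0], [1.0, 0.0]],
    [[1.0, 1.0], [1.0, 1.0]],
]

ranking = rank_channels_by_difference(maze_up, maze_right)
print(ranking)  # channels that differ most between the two mazes come first
```

Channels at the top of the ranking are candidates for carrying the "which direction to the cheese" information; channels at the bottom behave the same in both mazes.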
We could do a systematic search[1] over the parts of the network, but this is not the main topic of this post. Fast forward, channel 121[2] in the last convolutional layer (relu3, the layer just before the Flatten on the model graph) differs a lot between these two mazes. Here are the pairs of mazes (with rotations), the number below each maze is the sum[3][4] of the activations of this channel:
You might have noticed the pattern: when the sum is high, the path to the cheese leads UP. Fast forward again, the same pattern is visible also for other mazes of variable sizes (details in the following sections). We found a candidate for a part of the network that corresponds to "which direction to the cheese" - let's check if we can find a mechanism related to goal misgeneralization there.
In the following part of the post, I will dig deeper into channel 121, but more channels exhibit similar behaviour - I've put a few examples in the appendix.
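The systematic search described in footnote 1 could look roughly like this. It is only a sketch under assumptions: it works on per-channel sums instead of every single activation, the function name `surprise_scores` is hypothetical, and the numbers are invented.

```python
from statistics import pstdev
from typing import List


def surprise_scores(
    random_sums: List[List[float]],  # random_sums[m][c]: sum of channel c in random maze m
    sums_a: List[float],             # channel sums for the first maze of the pair
    sums_b: List[float],             # channel sums for the second maze of the pair
) -> List[float]:
    """Difference within the pair, in units of the channel's spread over random mazes."""
    scores = []
    for c in range(len(sums_a)):
        spread = pstdev([maze[c] for maze in random_sums])
        diff = abs(sums_a[c] - sums_b[c])
        if spread > 0:
            scores.append(diff / spread)
        else:
            # A channel that never varies: any difference at all is unusual.
            scores.append(0.0 if diff == 0 else float("inf"))
    return scores


# Made-up data: 3 random mazes, 2 channels.
random_sums = [[1.0, 10.0], [3.0, 10.5], [2.0, 9.5]]
scores = surprise_scores(random_sums, sums_a=[2.0, 14.0], sums_b=[2.5, 10.0])
print(scores)
```

A high score means the channel changes between the two mazes far more than it typically varies across random mazes, which is the signal the footnote asks for.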
Looking for goal misgeneralization in channel 121
The sum of channel 121 grouped by cheese position (up/not up)
Let's take a look at the same data as above, but aggregated for a set of random mazes[5], split into "in distribution" group (i.e. with cheese in the top right 5x5) and "out of distribution" (i.e. cheese outside of the top right 5x5):
As expected, a high sum of channel 121 strongly correlates with needing to go up to get the cheese. This pattern is similar both in and out of distribution, but it is much more robust in distribution. For example, in distribution, a value above 30 indicates that the cheese is almost certainly up - this doesn't hold for out-of-distribution mazes. But does this matter at all? Does this difference have any impact on the behaviour of the mouse?
Mouse behaviour when the sum of channel 121 is high
Let's now take a look at a narrow subset of 25x25 mazes:
I will assess mouse cheese-seeking accuracy, i.e. compare "is the path to the cheese up" and "is the most likely path selected by the mouse up".
For in-distribution mazes, the correlation is close to 1 (n=1000):
Not only is the cheese almost always up in this subset of mazes, but the mouse indeed goes up.
Same table for out-of-distribution mazes (n=1000)[7]:
What is striking here is that the mouse still goes up in 95% of mazes - the main difference is that it usually doesn't find the cheese.
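The comparison behind both tables ("is the path to the cheese up" vs "is the most likely path selected by the mouse up") can be sketched as a simple cross-tabulation. The helper names and the five example mazes below are made up for illustration; they are not the data from the post.

```python
from collections import Counter
from typing import Dict, List, Tuple


def cross_tabulate(cheese_is_up: List[bool], mouse_goes_up: List[bool]) -> Dict[Tuple[bool, bool], int]:
    """Count mazes for every (cheese_is_up, mouse_goes_up) combination."""
    if len(cheese_is_up) != len(mouse_goes_up):
        raise ValueError("both lists must describe the same mazes")
    return dict(Counter(zip(cheese_is_up, mouse_goes_up)))


def agreement_rate(cheese_is_up: List[bool], mouse_goes_up: List[bool]) -> float:
    """Fraction of mazes where the mouse's most likely move matches the cheese direction."""
    matches = sum(c == m for c, m in zip(cheese_is_up, mouse_goes_up))
    return matches / len(cheese_is_up)


# Made-up labels for five mazes.
cheese_is_up = [True, True, False, False, True]
mouse_goes_up = [True, True, True, True, False]

table = cross_tabulate(cheese_is_up, mouse_goes_up)
rate = agreement_rate(cheese_is_up, mouse_goes_up)
print(table, rate)
```

Keeping the full table rather than only the agreement rate matters here: the out-of-distribution result is precisely that the mouse still goes up while the cheese is elsewhere, which only shows up in the off-diagonal cell.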
So what happens here? My interpretation:
If this interpretation is correct, we should be able to influence mouse behaviour by modifying the values of channel 121. I'll try that in the next section.
Causal intervention experiment
I created yet another set of mazes, this time with cheese exactly in the bottom right corner[8]. I compared the results of full rollouts of two policies:
The hypothesis is that:
This is indeed the case. On a sample of 1000 mazes:
The modified policy finds cheese in 52% of cases, compared to 40% for the original policy. I think this clearly shows two things:
Summary & hypothesis on how the network works
What we know:
I think this is enough to make a hypothesis on how the network works and how the goal misgeneralization happens:
If this hypothesis is correct, the question "How does goal misgeneralization happen?" is reduced to "Why do the components activate on the path to the top right?". We know how to find the components, and we have a good starting point (what happens to channel 121 when the cheese is far from the top right? - first section in the appendix) - I might try to look into this next.
Appendix
"Will the mouse go to the cheese?" vs the distance between the cheese and the top right corner.
We see that the further the cheese is from the top right corner, the lower the chance channel 121 shows the path to the cheese. This is consistent with the behavioural statistics.
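A relationship like this can be computed by binning mazes by distance and taking the success fraction per bin. The sketch below assumes the distance is already available as a number per maze (the post does not say which metric was used), and the function name and data are hypothetical.

```python
from collections import defaultdict
from typing import Dict, List


def rate_by_distance_bin(distances: List[float], outcomes: List[bool], bin_width: float) -> Dict[float, float]:
    """Fraction of True outcomes per distance bin, keyed by the bin's lower edge."""
    totals = defaultdict(int)
    hits = defaultdict(int)
    for distance, outcome in zip(distances, outcomes):
        lower_edge = int(distance // bin_width) * bin_width
        totals[lower_edge] += 1
        hits[lower_edge] += int(outcome)
    return {edge: hits[edge] / totals[edge] for edge in sorted(totals)}


# Made-up data: distance between the cheese and the top right corner,
# and whether the channel pointed towards the cheese in that maze.
distances = [1, 3, 6, 8, 12, 14]
points_to_cheese = [True, True, True, False, False, False]

rates = rate_by_distance_bin(distances, points_to_cheese, bin_width=5)
print(rates)
```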
Other channels
Channel 73
Data for (73, "go right") is so similar to (121, "go up") that I triple-checked that I wasn't computing the same thing twice. Up and right are fully symmetrical in the environment, so this should not be a big surprise, but such similarities are not common in neural networks.
The sum of channel 73 > 30, mouse on the decision square, in distribution:
The same, out of distribution:
Impact of the distance between the cheese and the top right:
Channel 21
The sum of channel 21 activations for different environments:
A high value of this channel seems to mean "go down or go left".
Top 10 channels
Here are a few other channels and their interpretations based on the same comparison as for channel 21 above, ordered by the effect size.
I didn't check if all these interpretations generalize to random mazes, but they do generalize for channels 121 and 73 (and they were not cherry-picked).
What happens if I zero a lot of channels
I selected 16 channels that seem the most important from the point of view of the original pair of environments (in the Methods section). This is a vector field difference between the original policy and a policy with these 16 channels zeroed[10] (if the plot is unclear, consult the cheese vector post).
On the one hand, zeroing these 16 channels changes a lot, but on the other hand - the mouse would still go to the cheese and not to the top right. I think the only takeaway here is that even though we have some channels with straightforward interpretation, information is distributed between lots of different channels (that's not a surprise).
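Zeroing a set of channels is simple to express. The sketch below is a toy, not the experiment itself: nested lists stand in for the layer's activations, and `toy_action_scores` is a hypothetical linear readout used only to show that zeroing changes what downstream computation sees. In the real network the zeroed activations would be fed through the remaining layers instead.

```python
from typing import Dict, Iterable, List

Channel = List[List[float]]
Layer = List[Channel]


def zero_channels(layer: Layer, channels_to_zero: Iterable[int]) -> Layer:
    """Return a copy of the layer with the selected channels set to 0 everywhere."""
    zeroed = set(channels_to_zero)
    return [
        [[0.0 for _ in row] for row in channel] if idx in zeroed
        else [list(row) for row in channel]
        for idx, channel in enumerate(layer)
    ]


def toy_action_scores(layer: Layer, readout: Dict[str, List[float]]) -> Dict[str, float]:
    """Hypothetical stand-in for the rest of the network: weighted channel sums per action."""
    sums = [sum(sum(row) for row in channel) for channel in layer]
    return {
        action: sum(w * s for w, s in zip(weights, sums))
        for action, weights in readout.items()
    }


# Made-up example: 2 channels of shape 1 x 2 and an invented readout.
layer = [[[3.0, 1.0]], [[2.0, 0.0]]]
readout = {"up": [1.0, 0.0], "right": [0.0, 1.0]}

original = toy_action_scores(layer, readout)
ablated = toy_action_scores(zero_channels(layer, [0]), readout)
print(original, ablated)
```

As footnote 10 points out, zeroing is only a natural baseline for channels that are often near 0 anyway; for channels that never get close to 0 it pushes the network into a regime it has not seen.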
Notes on the maze generator
There are some known constraints on the mazes - they are squares with odd sizes, the bottom left and top right corners are always corridors, they are simply connected (i.e. no loops/islands), and there are no inaccessible sections. When you watch random mazes long enough (a week in my case), you might also notice that:
This doesn't look that important at first glance, but:
I think some of this might matter if we are ever to try full mech interp on this model, but I also consider this a general lesson that one should carefully analyse the exact world a model operates in.
Generate a bunch of random mazes, make a forward pass on them and for every activation calculate the standard deviation (or some other similar metric), and compare it to the difference in this particular case.
This channel has the strongest effect but is not unique. I briefly analyse other channels in the appendix.
This is the output of a ReLU layer -> there are no negative values -> simple sum makes sense.
A natural question: this is a convolution, why look at the sum only, ignoring the spatial structure? Answer: I checked the spatial structure, and the only pattern I found is "high values happen only around the mouse location", and I don't think this matters from the point of view of what I'm trying to do.
Size 25 x 25, mouse in a random square where move UP is legal, 1000 mazes in distribution (i.e. with cheese in the top right 5x5), 1000 mazes out of distribution.
There is no column "neither cheese nor top right corner is up" because this just never happened for this subset of mazes.
Note: the extreme difference between these two tables should probably be discounted by the fact that the in-distribution decision square is, on average, closer to the cheese/top right than the out-of-distribution one - I didn't control for that.
Also, this time the maze is 15 x 15. This is because on 25 x 25 mazes with cheese in the bottom right corner, the success rate is extremely low. Channels in layer relu3 have the same interpretation for mazes of different sizes.
A wild guess why this might be the case: during the training, the mouse first learned to go to the top right corner as a proxy goal and then started to update towards "go to the cheese", and once it updated enough to achieve 100% accuracy we stopped the training - but the old goal was not fully purged.
Zeroing makes sense for channels like 121 or 21 because they often have values close to 0 in normal activations. But there are also important channels that never go down to 0 (e.g. 7 has a value range between ~ 20 and 55) - setting them to 0 doesn't make much sense -> this test is not very good.