Really interesting and impressive work, Joseph.
Here are a few possibly dumb questions which spring to mind:
My apologies! I thought I had responded to this. Better late than never. All reasonable questions.
"What is the distribution of RTG in the dataset taken from the PPO agent? Presumably it is quite biased towards positive reward? Does this help to explain the state embedding having a left/right preference?"
The RTG distribution taken from the PPO agent is shown in Figure 4. It is the marginal distribution of reward: since the reward only occurs at the end of the episode, the RTG and reward distributions are more or less identical.
If you are referring to the tendency for the agent conditioned on low RTG (0.3) to get RTG closer to 0.8 as "biased towards positive reward", then yes. I think this is a function of the training distribution, which includes far more RTG ~0.8 than 0.3.
The state embedding having a left/right preference is probably a function of the PPO agent getting success going clockwise, this being reinforced and then this dominating the training data for the DT.
"Is a good approximation of the RTG=-1 model just the RTG=1 model with a linear left bias?"
I'm not sure what you mean by "linear left bias". The RTG=-1 model behaves characteristically differently (see Table 1).
"Does the state tokenizer allow the DT to see that similar positions are close to each other in state space even after you flatten? If not, might this be introducing some weird effects?"
Yes, I believe the encoding used had that property. I've since replicated these results with a less nasty encoding, and the results are mostly similar, if a little easier to interpret.
The ablations seem surprisingly clear-cut. I consider myself to be very on board with "RL-trained behaviors are contextually activated and modulated", but even I wasn't expecting such strong localization.
To be calibrated for RTG values between 0 (non-inclusive) and 1, the decision transformer must reach the goal in a precise number of steps which is difficult given that obstacles move randomly and can extend the amount of time the agent must take to reach the obstacle. To do this well is likely hard, given that there isn’t much training data for low positive RTG values (Figure 4).
Seems to me that randomness wouldn't prevent the agent from being calibrated, because even though any given episode might deviate from the prescribed number of steps, on average the randomness can (presumably) be made to add up to that number. E.g. it might be hard to bump into the goal after exactly 14 steps due to random obstacles, but I'd imagine ensuring this falls between 10 and 18 steps is feasible?
This inductive bias led to a weird behaviour where agents would successfully avoid the obstacles but also try to avoid the goal square.
This seems like an important manifestation of "models don't 'get reward', they are shaped by reward"; even on a simple task where presumably agents can fully explore all relevant options. The e.g. observation encoding (where an obstacle is "3/4" of a goal) matters when predicting what behavioral shards & subroutines get trained into the policy, or considering what e.g. early-stage policies will be like.
(I thiiink I weakly predicted this particular behavior in advance, when I read about the encoding.)
You might have expected that the state embedding wouldn’t project directly in any preference directions. In practice, sometimes we see it does, suggesting that besides likely encoding features, it is directly informing the model's output (which we saw in the residual stream dot product contributions earlier).
Wait, how? Isn't the state observation constant in this task? I'm guessing you're discussing something else?
The ablations seem surprisingly clear-cut. I consider myself to be very on board with "RL-trained behaviors are contextually activated and modulated", but even I wasn't expecting such strong localization.
Neither was I. After the fact, it seems easy to come up with reasons why this might be the case. I think measuring this with something like excluded loss might enable a more precise quantification of exactly how localised it was. I also don't see strong reasons to expect this to generalise to larger models. If I see similar results when the task is more complicated, the model is bigger and especially with a larger context window, then I will be more interested in trying to precisely describe how/why/when you get more localisation.
Seems to me that randomness wouldn't prevent the agent from being calibrated, because even though any given episode might deviate from the prescribed number of steps, on average the randomness can (presumably) be made to add up to that number. E.g. it might be hard to bump into the goal after exactly 14 steps due to random obstacles, but I'd imagine ensuring this falls between 10 and 18 steps is feasible?
I think you might be right in the infinite training data regime, where I would expect it to be unbiased. However, I suspect that the training data being sparse, especially in the low positive RTG region, is enough to make the signal weak. The loss incurred by finishing in the wrong number of steps is likely very small compared to failing when you have positive RTG or succeeding when you have negative RTG, so it could also be that the model doesn't allocate much capacity to being well-calibrated in this sense.
This seems like an important manifestation of "models don't 'get reward', they are shaped by reward"; even on a simple task where presumably agents can fully explore all relevant options. The e.g. observation encoding (where an obstacle is "3/4" of a goal) matters when predicting what behavioral shards & subroutines get trained into the policy, or considering what e.g. early-stage policies will be like.
I thought this too and should have remembered to cite "Reward is not the Optimization target". I feel like that concept is now more visceral for me. A relevant takeaway that may or may not be obvious is that simulators with inductive biases might be better/worse at simulating particular stuff as a function of their inductive biases. In this case, the positive RTG range was more miscalibrated as a result of the bias.
Wait, how? Isn't the state observation constant in this task? I'm guessing you're discussing something else?
I'm not sure what you mean by constant. If it were constant then the agent/obstacles wouldn't be moving? I'll elaborate a little bit in case that helps.
Since RTG gives information about whether the agent should go forward/not in many contexts, you might have expected (although now I think I wouldn't) the residual stream embedding for the state token not to directly contribute to one logit over another before RTG has been seen. In practice, it seems like it does. For example, in this situation (picture below from the app) with RTG = 90, there is a wall to the left of the agent and the state appears to strongly encourage forward/right. I interpret this as "some agent behaviours are independent of RTG and can be encouraged as a function of the observation/state before RTG is seen".
To clear up some language as well:
- state encoding -> how is the state represented? weird minigrid schema.
- state token -> the value of a state represented as a vector input to the model.
- state embedding -> an internal representation of the state at some point in the model.
I should have written state-token, not state-embedding, in the quoted paragraph. Apologies if this led to confusion.
This looks cool, going to read in detail later.
Start working on larger/harder RL tasks that involve more complicated algorithms and/or search and/or alignment relevant phenomena such as goal misgeneralization. The way I see this going is Minigrid D4RL < ProcGen < Atari < whatever tasks Gato does.
Note that team shard (MATS stream) is already doing MI on procgen. I've been supervising/working with them on procgen for the last few weeks. We have a nice set of visualization techniques, have reproduced some of Langosco et al.'s cheese-maze policies, and overall it's been going quite well.
Thank you for letting me know about your work on procgen with MI. It sounds like you're making progress. I'd be particularly interested in your visualisation techniques (how do they compare to what was done in Understanding RL Vision?) and the reproduction of the cheese-maze policies (is this tricky? Do you think a DT could be well-calibrated on this problem?).
Some questions that might be useful to discuss more:
Glad to hear your progress is going well! I'll be in the Bay Area for EAG if anyone from the team would like to chat.
We're studying a net with the structure I commented below, trained via PPO. I'd be happy to discuss more at EAG.
Not posting much publicly right now so that we can a: work on the research sprint and b: let people preregister credences in various mechint / generalization propositions, so that they can calibrate / see how their opinions evolve over time.
Are you using decision transformers or other RL agents on procgen? Also, do you plan to work on coinrun?
We're analyzing the mech-int-ungodly Impala architecture, from the paper. Basically
=== Impala
conv
maxpool2D
---- residual x2:
relu
conv
relu
conv
residual add from input to this residual block
=== /IMPALA
(repeat 2 more impalas)
---
relu
flatten
fully connected
relu
---
linear policy and value heads
so this mess has sixteen conv layers, was trained on pixels. We're not doing coinrun for this MATS sprint, although a good amount of tooling should cross over.
This has presented some challenges -- no linearity from decomposing an ongoing residual stream into head-contributions.
Oh nice, I was interested in doing mechanistic interpretability on decision transformers myself and had gotten started during SERI MATS, but I then became more interested in algorithm distillation and the decision transformer stuff fell by the wayside (plus I haven't been very productive during the last few weeks, unfortunately). It's too late to read the post in detail today, but I will probably read it in detail and look at the repo tomorrow. I'm interested in helping with this, and I'm likely going to be working on some related research in the near future anyway. Also, btw, I think that once we understand what's going on in the setup from the original DT paper, it would be interesting to look into this: https://arxiv.org/abs/2201.12122
Also, the DT paper finds their model generalizes to bigger RTG than the training set in the Seaquest env, and it would be interesting to get a mechanistic explanation of why that happens (though that's an Atari task, and I think you are right that that's probably going to have to come later because CNNs are probably harder to work with).
Another thing to note is that OpenAI's VPT, while it's technically not a decision transformer (because it doesn't predict rewards, if I remember correctly), is a similar kind of thing in that it is offline RL as sequence prediction, and it is probably one of the biggest publicly available pretrained models of this kind. There are also multiple open source implementations of Gato that could be interesting to try to do interpretability on. https://github.com/Shanghai-Digital-Brain-Laboratory/BDM-DB1
Also, training decision transformers on MineRL (or in Eleuther's future Minetest environment) seems like what might come next after Atari (the tasks Gato is trained on are mostly Atari games and Google stuff that is not publicly available, if I remember correctly).
(Sorry if this is too rambly. I'm half asleep and got excited because I think work on DTs is a potentially very promising area in alignment. I was procrastinating on writing a post trying to convince more people to work on it, and I'm pleasantly surprised other people had the same idea.)
Hi Victor,
Glad you are keen on this area, I'd be very happy to collaborate. I'll respond to your comments here but am happy to talk more after you've read the post.
The linked paper (Can Wikipedia help Offline Reinforcement Learning) is very interesting in a few ways, however, I'd be interested in targeted reasons to investigate this specifically. I think working with larger models is often justified but it might make sense to squeeze more juice out of the small models before moving to larger models. Happy to hear the arguments though.
Also, the DT paper finds their model generalizes to bigger RTG than the training set in the Seaquest env, and it would be interesting to get a mechanistic explanation of why that happens (though that's an Atari task, and I think you are right that that's probably going to have to come later because CNNs are probably harder to work with).
I'm glad you asked about this. In terms of extrapolation, an earlier model I trained seemed to behave like this. Analysis techniques like the RTG Scan functionality in the app present ways to explore the mechanisms behind this which I decided not to explore in this post (and possibly in general) for a few reasons:
It's not clear to me that this is more than a coincidence. I think it could be that in the space of functions that map RTG to behaviour, for certain games, it is possible to learn coincidentally extrapolating functions. If the model were to develop qualitatively different behaviour (under some definition) in out-of-distribution RTG ranges for any task, then my interest will be renewed.
I suspect doing something like integrated gradients for CNN layers is pretty doable (maybe that's what the MATS shard team have done, see one of Alex Turner's comments) but yeah, they are probably harder to work with.
Another thing to note is that OpenAI's VPT, while it's technically not a decision transformer (because it doesn't predict rewards, if I remember correctly), is a similar kind of thing in that it is offline RL as sequence prediction, and it is probably one of the biggest publicly available pretrained models of this kind. There are also multiple open source implementations of Gato that could be interesting to try to do interpretability on. https://github.com/Shanghai-Digital-Brain-Laboratory/BDM-DB1
Thank you for sharing this! I'd be very excited to see attempts to understand these models. I've started with toy models for reasons like simplicity and complete control but I can see many arguments in favour of jumping to larger models. The main challenge I see would be loading the weights into a TransformerLens model so we can get the cache enabling easy analysis. This is likely quite doable.
Also, training decision transformers on MineRL (or in Eleuther's future Minetest environment) seems like what might come next after Atari (the tasks Gato is trained on are mostly Atari games and Google stuff that is not publicly available, if I remember correctly).
Interesting!
About the sampling thing: I think a better way to do it, one that will work for other kinds of models, would be training a few different models that do better or worse on the task and use different policies, and then making a dataset of sampled trajectories from several of them. This should be cleaner, in terms of knowing what is going on in the training set, than getting the data as the model trains (which, on the other hand, is actually better for doing AD).
That also has the benefit of letting you study how the choice of agents used to generate the training data affects the model. For example, if you have two agents that get similar rewards using different policies, does the DT learn to output a mix of the two policies, or what exactly? The agents don't even need to be neural nets; they could be random samplers or a handcrafted policy.
For example, I tried training a DT-like model (though I didn't do the time encoding) on a mixture of random actions and a DQN that played the dumb toy env I was using (the Frozen Lake env from Gym, which is basically solved by memorizing the correct path). It apparently learned to output the correct actions the DQN took at RTG 1 and a uniform distribution over the action tokens at RTG 0.
Agreed about the better way to sample. I think in the long run this is the way to go. Might make sense to deal with this at the same time we deal with environment learning schedules (learning on a range of environments where the difficulty gets harder over time).
I hadn't thought about learning different policies at the same RTG. This could be interesting. Uniform action preference at RTG = 0 makes sense to me, since while getting a high RTG is usually low entropy (there's one or a few ways to do it), there are many ways not to get reward. In practice, I might expect the actions to just match the base frequencies of the training data, which might be uniform.
TLDR: We analyse how a small Decision Transformer learns to simulate agents on a grid world task, providing evidence that it is possible to do circuit analysis on small models which simulate goal-directedness. We think Decision Transformers are worth exploring further and may provide opportunities to explore many alignment-relevant deep learning phenomena in game-like contexts.
Link to the GitHub Repository. Link to the Analysis App. I highly recommend using the app if you have experience with mechanistic interpretability. All of the mechanistic analysis should be reproducible via the app.
Key Claims
If you are short on time, I recommend reading:
I would welcome assistance with:
I’m also happy to collaborate on related projects.
Introduction
For my ARENA Capstone project, I (Joseph) started working on decision transformer interpretability at the suggestion of Paul Colognese. Decision transformers can solve reinforcement learning tasks when conditioned on generating high rewards via the specified “Reward-to-Go” (RTG). However, they can also generate agents of varying quality based on the RTG, making them simultaneously simulators, small transformers and RL agents. As such, it seems possible that identifying and understanding circuits in decision transformers would not only be interesting as an extension of current mechanistic interpretability research but possibly lead to alignment-relevant insights.
Previous Work
The most important background for this post is:
Methods
Environment - RL Environments.
GridWorld Environments are loaded from the Minigrid python package. Such environments have a discrete state space and action space where an embedded agent can turn right or left, move forward, pick up and drop objects such as keys, and use objects such as a key to open a door. Each gridworld involves a different task/environment, and most involve very sparse rewards. The observations can be rendered full or partial (from the agent’s point of view). In this work, we use partial views.
Dynamic Obstacles Environment
This environment involves no keys or doors. An agent must proceed to a green goal square but will receive a -1 reward if it walks into a wall or obstacle (ending the episode). The obstacles move randomly every timestep, regardless of what the agent does. The agent receives a reward of 1 for reaching the goal (or a time-discounted reward with PPO). The goal square does not move, and the space is always a fixed-size grid.
Decision Transformer Training
Trajectories are generated with an implementation of the PPO algorithm developed as part of the ARENA Coursework. We modified code from the original decision transformer paper for the model architecture to work with arbitrary mini-grid environments and to use TransformerLens. We flatten observations and use a linear projection to create the state token. We remove tanh activations (used in the original DT architecture in state/RTG/action encodings) and LayerNorm everywhere to have more linearity to facilitate analysis.
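As a rough illustration of the state tokenisation described above, the sketch below flattens a 7x7x3 partial observation and applies a bias-free linear projection with no tanh. The helper names, the random stand-in weights and the example observation are assumptions made for illustration, not the repository's code.

```python
import random

def flatten_obs(obs):
    """Flatten a [height][width][channel] observation into one flat list."""
    return [float(v) for row in obs for cell in row for v in cell]

def linear_project(x, weights):
    """Project x (length n) with weights of shape [d_model][n]; no bias, no tanh."""
    return [sum(w_i * x_i for w_i, x_i in zip(row, x)) for row in weights]

random.seed(0)
H, W, C, D_MODEL = 7, 7, 3, 128  # 7x7 partial view, 3 channels, hidden size 128

# Hypothetical observation: every square is empty (object 1, colour 0, state 0).
obs = [[[1, 0, 0] for _ in range(W)] for _ in range(H)]

# Random stand-in for the learned projection matrix.
weights = [[random.gauss(0, 0.02) for _ in range(H * W * C)] for _ in range(D_MODEL)]

flat = flatten_obs(obs)
state_token = linear_project(flat, weights)
print(len(flat), len(state_token))
```

Because the projection is purely linear, each position in each channel contributes an additive term to the state token, which is what makes the per-position analysis later in the post possible.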
Training of the decision transformer is performed as described in the original decision transformer paper, although we have not implemented learning rate schedules.
Attribution and Preference Directions
We can perform logit attributions by decomposing the final logits for [forward, left, and right] actions into a sum of contributions for each component (attention heads, MLPs, input tokens). This analysis can be conceived of as a single-step version of the integrated gradients method described in Appendix B of Understanding RL Vision.
However, since softmax is invariant under translation, logits aren't inherently meaningful. It is useful to construct the notion of a “preference direction”, which is the difference between the attribution to one action, such as forward, minus the attribution to another, such as right. Figure 3 shows preference direction decomposition. In the Dynamic Obstacles task studied, the actions possible are only forward/right/left, meaning that the pairwise preference of the model loses some information (an action is left out). Still, we find that right and left logits correlate heavily, making a pairwise analysis useful.
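A minimal sketch of this decomposition, assuming a toy 4-dimensional residual stream, made-up component outputs and made-up unembedding vectors (biases are ignored). It shows that, without LayerNorm, per-component attributions along the forward/right preference direction sum to the actual logit difference.

```python
def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def add(a, b):
    return [x + y for x, y in zip(a, b)]

# Hypothetical unembedding vectors, one per action.
W_U = {
    "forward": [0.5, -0.2, 0.1, 0.3],
    "left":    [-0.1, 0.4, 0.0, -0.2],
    "right":   [-0.2, 0.3, 0.1, -0.1],
}

# Hypothetical outputs that each component writes into the residual stream.
components = {
    "state_token": [1.0, 0.0, 0.5, 0.2],
    "head_0":      [-0.3, 0.1, 0.0, 0.4],
    "head_1":      [0.2, -0.5, 0.1, 0.0],
    "mlp":         [0.0, 0.3, -0.2, 0.1],
}

# Preference direction: forward minus right in unembedding space.
forward_minus_right = [f - r for f, r in zip(W_U["forward"], W_U["right"])]

# Attribution of each component to the forward-over-right preference.
attribution = {name: dot(vec, forward_minus_right) for name, vec in components.items()}

# The residual stream is the sum of the components, so attributions add up exactly.
residual = [0.0] * 4
for vec in components.values():
    residual = add(residual, vec)
logit_diff = dot(residual, W_U["forward"]) - dot(residual, W_U["right"])

for name, value in attribution.items():
    print(f"{name:12s} {value:+.3f}")
print(f"{'total':12s} {logit_diff:+.3f}")
```

Working with a difference of logits sidesteps the translation invariance of softmax: adding a constant to every logit leaves `logit_diff` unchanged.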
An Interface for live Circuit Analysis
Inspired by OpenAI’s approach to the Understanding RL Vision project, I built an app in Streamlit which would enable me to generate trajectories by playing Minigrid games myself whilst observing analysis of model activations and preferences over actions.
Results
Training a small Decision Transformer
We first trained an RL agent via the proximal policy optimisation (PPO) algorithm, which solved the task well and generated trajectories of varying rewards. Ideally, we might have written specific code to sample PPO agents of varying quality; however, it was much quicker to simply store the trajectories generated during PPO training.
With these trajectories, we trained decision transformers of varying sizes (0, 1, 2 layers) as well as width/hidden dimension size (32, 64, 128), number of heads (2, 4) and context length (1, 3 and 9 timesteps worth of tokens). We evaluated the decision transformer during training by placing it in the dynamic obstacle environment and observing performance when conditioned on high reward.
We found that the transformer that seemed to robustly achieve good performance was a 1 layer transformer with a hidden dimension of 128, two heads and a single time step.
To characterise the calibration of this transformer (to confirm it was a good simulator of trajectories of varying quality and not simply a good agent), we generated calibration curves showing the average performance achieved for a range of Reward-to-Go values (RTGs, the target reward the agent is conditioned to achieve).
Now Figure 5 appears to show a highly miscalibrated simulator, at least insofar as the curve is s-shaped rather than linear. I would argue, however, that it is certainly good enough for analysis for the following reasons.
Having trained a decision transformer and verified that it can generate trajectories consistent with various levels of reward, we proceeded to analyse the decision transformer.
Black Box Model Characterisation
To achieve calibrated performance at RTG = -1, RTG = 0 and RTG ~ 1, as shown above, the model must robustly learn 3 different behaviours that activate at different levels of RTG. Table 1 lists these and their corresponding RTGs. RTG = 1 is impossible to achieve with a time-discounted reward, so we use RTG = 0.90 instead.
| RTG | -1 | 0 | 0.90 |
| --- | --- | --- | --- |
| Wall Avoidance | yes* | yes | yes |
| Obstacle Avoidance | no | yes | yes |
| Goal Avoidance | yes | yes | no |
Table 1: RTG Modulated Behaviours. Wall/Obstacle/Goal Avoidance means not walking into those objects. The agent must not reach the goal when RTG is negative nor hit walls or obstacles when it is positive. At 0 RTG, it must avoid walls, obstacles and the goal. *Notably, the agent could learn to achieve -1 RTG by walking into walls but appears not to do so.
Decision Transformer Architecture
The model analysed has a single layer, two attention heads and a context window that only includes the current state and the RTG. Because of the lack of layer norms or non-linearities, which we removed, we can perform an analysis where we decompose the contributions of each component to the preferred direction. Figure 6 describes the architecture.
Worth noting:
Circuit Results Summary
The analysis in this post involves playing the game, doing ablations, and visualising attributions and weights for different components to try to work out what's happening. I don’t consider any of this rigorous evidence but proof of concept (and practice!).
That being said, the transformer appears to learn various trigrams (RTG, f(State) -> Action) combinations, where f(state) includes things like whether the object in front of the agent is the goal, an obstacle or a wall. There are also broader biases that I haven’t fully understood, such as the tendency to manoeuvre the grid clockwise.
The extent to which these trigrams can be localised to specific components varies, but broadly speaking:
The rest of the analysis shows:
If you have no time, skip ahead to the obstacle avoidance circuit. It’s the most interesting section.
Attention and Head Ablation
Inspecting the dot products of components of the residual stream with the forward/right direction can hint at the responsibilities of model components, but performing ablations rapidly demonstrates the value of different components. Looking at each of the behaviours defined above, we ablated Head 0, Head 1 and the MLP one by one and found the following responsibilities, summarised in Figure 7. We used mean ablations for this write-up but found similar results for zero ablation.
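For readers unfamiliar with the two ablation types, here is a small sketch of the difference, assuming made-up 2-dimensional head outputs for three inputs. Zero ablation replaces a component's output with zeros, while mean ablation replaces it with its average over a reference set of inputs. The function names are hypothetical.

```python
def mean_vector(vectors):
    """Element-wise mean of a list of equal-length vectors."""
    n = len(vectors)
    return [sum(col) / n for col in zip(*vectors)]

def ablate(head_outputs, mode):
    """Return the replacement outputs for one component across a batch of inputs."""
    if mode == "zero":
        replacement = [0.0] * len(head_outputs[0])
    elif mode == "mean":
        replacement = mean_vector(head_outputs)
    else:
        raise ValueError(f"unknown ablation mode: {mode}")
    # Every input receives the same input-independent replacement.
    return [list(replacement) for _ in head_outputs]

# Hypothetical outputs of one head on three different inputs.
head_1_outputs = [[1.0, -2.0], [3.0, 0.0], [2.0, 2.0]]

mean_ablated = ablate(head_1_outputs, "mean")
zero_ablated = ablate(head_1_outputs, "zero")
print(mean_ablated[0], zero_ablated[0])
```

In both cases the component can no longer respond to the current input; mean ablation additionally keeps the downstream layers closer to the activation scale they saw in training.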
To substantiate these results, I’ll share a few example cases. Wall avoidance and goal seeking are associated with the initial state embedding. This can be seen by inspecting the dot product of this component in the forward/right direction, which is clearly significant when not facing walls (Figure 9) and when facing the goal (Figures 8, 10). Wall following in the clockwise direction appears to be contributed to by both heads and the MLP.
Ablation of Head 0 makes the DT reach the Goal at RTG = -1 from the top
Since an agent which reaches the goal receives a positive reward, the transformer refuses to enter the goal in cases where RTG = -1. However, ablation of Head 0 both directly reduces the projections against the forward/right direction and leads the MLP to contribute less to this negative direction (i.e. right over forward).
Ablation of Head 1 makes the DT hit obstacles at RTG = 0.9
Since an agent which walks into an obstacle receives a negative reward, the transformer refuses to walk into obstacles on the way to the goal when RTG is above roughly 0.2-0.3. However, the ablation of Head 1 directly reduces the projection against the forward/right direction and leads the MLP to contribute less to the opposing direction.
Ablation of MLP makes the DT approach the Goal at RTG = -1 from the left
Since an agent which reaches the goal receives a positive reward, the transformer refuses to enter the goal in cases where RTG = -1. We’ve seen that ablating Head 0 can encourage entering the goal from the top. However, it appears that Head 0 does not likewise inhibit moving forward into the goal from the left. In fact, ablation of the MLP will do this.
Analysing Embeddings
Having found concrete evidence showing us where various behaviours originate in our model, we can start going through the model architecture to see how the model reasons about each of the inputs.
Analysing the RTG and Time Tokens
Reward-to-Go
The Reward-to-Go token is a linear projection of one number, the RTG. For the RTG to affect the decision of the transformer, the “information” in this token must be moved to the state token, which will be projected onto the action prediction.
Nevertheless, it appears the network learnt to project almost 1:1 in the forward/right direction, as the dot product of the RTG embedding with the forward/right direction is ~1.13. (This means that if the information were moved to the state token, it would contribute to forward/right.) The dot product of the right/left direction with the RTG embedding is ~ -0.16. This can be interpreted as the RTG embedding encouraging forward movement and probably being fairly ambivalent with respect to right/left when this embedding is attended to in the attention heads later in the model.
Time
The Time embedding is a learned embedding that essentially functions as a learned look-up table of vectors as a function of the integer time value. Unlike a positional embedding in a regular transformer, which is of the same form, this embedding is added to the group of 3 tokens that make up one timestep: an RTG, a state and an action. The time embedding is added to the state token (and RTG token).
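The look-up behaviour can be sketched as follows, assuming a tiny hidden size, a random stand-in for the learned table and made-up token values. The point is only that every token in a timestep receives the same time vector, unlike a per-token positional embedding.

```python
import random

random.seed(0)
D_MODEL, MAX_TIMESTEP = 8, 10  # illustrative sizes, not the model's

# Learned look-up table: one vector per integer timestep (random stand-in here).
time_table = [[random.gauss(0, 0.02) for _ in range(D_MODEL)] for _ in range(MAX_TIMESTEP)]

def add_time_embedding(tokens, t):
    """Add the same time vector to every token belonging to timestep t."""
    time_vec = time_table[t]
    return {name: [v + e for v, e in zip(vec, time_vec)] for name, vec in tokens.items()}

# Hypothetical RTG and state tokens for timestep 3.
tokens = {"rtg": [0.9] * D_MODEL, "state": [0.1] * D_MODEL}
embedded = add_time_embedding(tokens, t=3)
```

Since the same vector is added to both tokens, the difference between the RTG and state tokens is unchanged by the time embedding.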
Analysing the State Embeddings
Minigrid uses a 3-channel encoding scheme, providing the agent with information about the objects, colours and states in a 7 by 7 grid around itself (when you use a partial observation view, which we do in this project). Figure 12 is useful for understanding the Minigrid observation encoding schema.
Since I hadn’t thought about it until a decent way through my original analysis, I didn’t one-hot encode the state. This introduces a significant inductive bias induced by spurious index ordinality. For example, in the object channel, the goal is 8, and the obstacle (ball) is 6. So an obstacle looks like ¾ of a goal, which is really dumb.
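To make the ordinality problem concrete, the sketch below contrasts the raw index encoding with a one-hot encoding for the two object indices quoted above. The number of object types (11) is an assumption about Minigrid's object index range, and the helper names are hypothetical.

```python
NUM_OBJECT_TYPES = 11  # assumption: Minigrid object indices run from 0 to 10

def one_hot(index, size=NUM_OBJECT_TYPES):
    """Return a vector with a single 1.0 at position `index`."""
    vec = [0.0] * size
    vec[index] = 1.0
    return vec

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

OBSTACLE, GOAL = 6, 8  # object-channel indices quoted in the post

# Raw index encoding: any linear map sees an obstacle as a scaled-down goal.
ratio = OBSTACLE / GOAL

# One-hot encoding: the two objects occupy orthogonal input directions,
# so a linear map can weight them independently.
overlap = dot(one_hot(OBSTACLE), one_hot(GOAL))

print(ratio, overlap)
```

With the raw encoding, a single weight per position has to serve every object type at once; with one-hot inputs, the projection gets a separate weight per object type per position.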
Figure 13 visualises each channel of the state/observation, with colour representing the index. In the object view, obstacles are more distinguishable from empty space than in the colour channel. The state channel provides no information since it’s used to show whether doors/boxes are locked, unlocked or open, which isn’t relevant to the Dynamic Obstacles task.
This inductive bias led to a weird behaviour where agents would successfully avoid the obstacles but also try to avoid the goal square. This was fixed when I started over-sampling the final step of trajectories, a hack I think wouldn’t be necessary for the one-hot encoded version of the model but is also a generally good trick for getting agents to train faster on mini-grid tasks.
It’s possible to look at the activation at each position in each channel projected into the forward/right direction since each position in each channel projects into the residual stream (128 dimensions) which is itself eventually projected into the action logits (see Figure 14).
There are a few interpretations worth taking away from here:
Nevertheless, I do think there is some valuable information in Figure 14. We know from our observations above that the state embedding mostly works as a wall detector.
The square in front of the agent in the colour embedding has a large negative value (-0.742). Empty Squares are 0 in colour, obstacles are 2, and walls are 5. This square thus will project weakly onto “don’t walk forward when there is an obstacle in front” and strongly onto “don’t walk forward when there is a wall in front”.
The square in front of the agent in the object embedding (left in Figure 14) has a slightly positive value (0.27), so it encourages walking into walls and obstacles but encourages walking into the goal square more than other objects.
This is cool because it suggests that the inductive bias associated with my idiotically not one-hot-encoding the vision is directly why the model avoids walls much more strongly than it avoids obstacles. It also explains why early versions of the model avoided the goal square.
I’ll make a table to show how the two neurons for the object/colour embedding for the square directly in front of the agent combine to respond to objects/colours.
| Object | Colour | Dot product with forward/right |
| --- | --- | --- |
| 1 | 0 | 0.27 * 1 = 0.27 |
| 2 | 5 | 0.27 * 2 - 5 * 0.742 = -3.17 |
| 6 | 2 | 0.27 * 6 - 2 * 0.742 = 0.136 |
| 8 | 1 | 0.27 * 8 - 1 * 0.742 = 1.42 |
Table 3: Back-of-envelope calculation of the dot product between the object and colour embedding activations and the forward/right direction, for the position directly in front of the agent. The results suggest the embedding functions as a strong wall avoider and a weak goal square seeker.
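The arithmetic in Table 3 can be reproduced with a few lines, using the two values quoted above (0.27 for the object channel and -0.742 for the colour channel). The labels attached to each (object, colour) pair are my reading of the index scheme, and the function name is hypothetical.

```python
OBJECT_WEIGHT = 0.27    # object-channel value for the square in front of the agent
COLOUR_WEIGHT = -0.742  # colour-channel value for the same square

def forward_right_projection(object_idx, colour_idx):
    """Combined contribution of the two channels to the forward/right direction."""
    return OBJECT_WEIGHT * object_idx + COLOUR_WEIGHT * colour_idx

# (object index, colour index) pairs from Table 3; the names are assumed labels.
cases = {"empty": (1, 0), "wall": (2, 5), "obstacle": (6, 2), "goal": (8, 1)}

projections = {name: forward_right_projection(obj, col) for name, (obj, col) in cases.items()}
for name, value in projections.items():
    print(f"{name:8s} {value:+.3f}")
```

The values agree with the table up to rounding: strongly negative for a wall, near zero for an obstacle and clearly positive for the goal.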
Table 3 provides evidence for how the state embedding detects walls directly in front of the agent. However, there are lots of other neurons projecting into the forward/right direction in the state embedding, as shown in Figure 15.
Let’s summarise:
Explaining Obstacle Avoidance at positive RTG using QK and OV circuits
To explain how our decision transformer avoids obstacles at positive RTG, we can reference the QK and OV circuits, the encoding scheme, and the RTG attribution. Briefly, the QK (Query-Key) circuit controls which token the head attends to, and the OV (Output-Value) circuit determines how attending to each token affects the logits. More details are here.
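To make the two circuits concrete, here is a toy sketch of a single attention head. It assumes random placeholder weights and toy sizes, and it ignores layer norm, biases and positional embeddings. None of the names or shapes are taken from the actual model.

```python
import math
import random

rng = random.Random(0)
D_MODEL, D_HEAD, N_ACTIONS = 8, 4, 3  # toy sizes, not the post's model


def rand_matrix(rows, cols):
    """Random placeholder matrix standing in for trained weights."""
    return [[rng.gauss(0.0, 1.0) for _ in range(cols)] for _ in range(rows)]


def matvec(m, v):
    """Matrix-vector product for a list-of-rows matrix."""
    return [sum(a * b for a, b in zip(row, v)) for row in m]


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def softmax(scores):
    top = max(scores)
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


W_Q = rand_matrix(D_HEAD, D_MODEL)
W_K = rand_matrix(D_HEAD, D_MODEL)
W_V = rand_matrix(D_HEAD, D_MODEL)
W_O = rand_matrix(D_MODEL, D_HEAD)
W_U = rand_matrix(N_ACTIONS, D_MODEL)


def qk_score(x_dst, x_src):
    """QK circuit: pre-softmax score for destination x_dst attending to source x_src."""
    return dot(matvec(W_Q, x_dst), matvec(W_K, x_src)) / math.sqrt(D_HEAD)


def ov_logit_effect(x_src):
    """OV circuit: direct change in action logits if the head fully attends to x_src."""
    return matvec(W_U, matvec(W_O, matvec(W_V, x_src)))


# Hypothetical residual-stream vectors for the RTG token and the state token.
rtg_token = [rng.gauss(0.0, 1.0) for _ in range(D_MODEL)]
state_token = [rng.gauss(0.0, 1.0) for _ in range(D_MODEL)]

# The state token is the destination; it can attend to the RTG or to itself.
pattern = softmax([qk_score(state_token, rtg_token),
                   qk_score(state_token, state_token)])

# The head's logit contribution is the attention-weighted mix of OV effects.
head_logits = [
    pattern[0] * a + pattern[1] * b
    for a, b in zip(ov_logit_effect(rtg_token), ov_logit_effect(state_token))
]
print(pattern, head_logits)
```

The split is what makes the analysis tractable: `qk_score` depends only on the query and key weights and decides where attention goes, while `ov_logit_effect` depends only on the value and output weights and decides what attending there does to the logits.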
First, we need the circuit to activate at a higher RTG. The circuit needs to inhibit moving forward (which we’ll approximate roughly with the forward/right direction), so it will want to attend to the state rather than the RTG, since we know the RTG projects positively onto the forward/right direction and the forward logit.
Attention to the RTG will be high where the query and key vectors match (the key/source is the RTG token and the query/destination is the state token). The QK circuit visualisation in Figure 16 shows how attention to the RTG is inhibited by anything that isn’t an empty square in front of the agent.
Now, we can look at the OV circuit for attention head 1 to determine the effect on the output from attending to the state rather than the RTG (Figure 17). It appears that high object channel values in front of the agent will lead to a decrease in the forward logit. This includes obstacles (6) and the goal (8) (see Table 3).
Lastly, it’s worth noting that the MLP tends to accentuate the output of either attention head, helping them project strongly enough to counteract the strong forward/right projection we usually see coming from the state embedding when not facing a wall. Thanks to Callum for pointing out that we can measure cosine similarity between input directions of MLP neurons and directions of the OV circuit to provide more detail here. I plan to do this soon.
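One way such a cosine-similarity check could look is sketched below. This is a sketch only, not the analysis itself: the weights are random placeholders, the sizes are toy values, and taking the OV "direction" to be the head's residual-stream output for one chosen source token is an assumption. The names `W_in`, `x_src` and `ov_direction` are hypothetical.

```python
import math
import random

rng = random.Random(0)
D_MODEL, D_HEAD, N_NEURONS = 8, 4, 6  # toy sizes, not the post's model


def rand_matrix(rows, cols):
    """Random placeholder matrix standing in for trained weights."""
    return [[rng.gauss(0.0, 1.0) for _ in range(cols)] for _ in range(rows)]


def matvec(m, v):
    return [sum(a * b for a, b in zip(row, v)) for row in m]


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def cosine(a, b):
    """Cosine similarity; returns 0.0 if either vector has zero norm."""
    norm = math.sqrt(dot(a, a)) * math.sqrt(dot(b, b))
    return dot(a, b) / norm if norm else 0.0


W_V = rand_matrix(D_HEAD, D_MODEL)
W_O = rand_matrix(D_MODEL, D_HEAD)
W_in = rand_matrix(N_NEURONS, D_MODEL)  # one input direction (row) per MLP neuron

# Hypothetical source token, e.g. a state with an obstacle in front of the agent.
x_src = [rng.gauss(0.0, 1.0) for _ in range(D_MODEL)]

# Direction the head writes into the residual stream when attending to x_src.
ov_direction = matvec(W_O, matvec(W_V, x_src))

# Neurons whose input direction aligns with the head's output will be
# driven (positive) or suppressed (negative) by that head.
similarities = [cosine(row, ov_direction) for row in W_in]
for neuron, sim in enumerate(similarities):
    print(neuron, round(sim, 3))
```

Neurons with a similarity far from zero would be the natural candidates to inspect when asking which parts of the MLP amplify a given head.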
Discussion
Analysis takeaways
Limitations
Some limitations worth highlighting:
I tweaked hyperparameters to encourage as interpretable a model as possible. This means training for a long time with no meaningful loss decrease to encourage weight decay to favour a generalising solution over a memorising solution. I suspect that, in many cases, models can be performant and lack nice abstractions.
Alignment Relevance
A more general list of why MI might be useful to AI alignment is provided here. I think that Decision Transformer MI might be useful to AI alignment in several ways:
Capabilities Externalities Assessment
I believe strongly that, given the precipice we find ourselves standing on, it behoves everyone publishing empirical work to consider how it might backfire from an alignment perspective. Part of this means making a public commitment to care about capabilities, which you can consider this as doing. I promise to take you seriously if you think this or any other work I’m doing might enhance capabilities. If you feel this is the case, please consider the best way to notify me, probably by email (so as not to highlight any capabilities-enhancing insight further in a public forum), and feel free to CC anyone you think would be a reasonable third party.
Before publishing this work, I considered several factors and spoke to various alignment researchers. I have published it because it is a reasonable trade-off between empirical progress on alignment and possible capabilities enhancement.
For reference, here are some posts on capabilities/MI.
Future Directions
This project was designed to provide early feedback for what could be a much longer research investigation involving larger systems and more comprehensive analyses followed by interventions such as model editing.
Meta
Acknowledgements
I’m very grateful to
Many thanks to all the people who have given feedback on this draft, including Callum McDougall, Dan Braun, Jacob Hilton, Tom Lieberum, Shmuli Bloom and Arun Jose.
Author Contributions
Paul Colognese wrote up the initial project description, and he, Joseph and Callum McDougall had a few early meetings to discuss the project. The PPO algorithm was almost entirely taken from the ARENA curriculum written by Callum, based on MLAB content. The Decision Transformer code was based heavily on the original code base, which is still a submodule of the GitHub repository.
The rest of the work (building the code base, training the agents, building the interactive app and writing this post) was done by Joseph Bloom, from whose perspective this write-up is written.
Feedback
Please feel free to provide feedback below or email me at jbloomaus@gmail.com. The GitHub repository is a reasonable place for technical feedback. Feedback on the app/codebase would be appreciated too!
Glossary
I highly recommend Neel’s glossary on Mechanistic Interpretability.