In this experiment, I study the behaviour of a Deep Q-Network agent with an attention-based architecture driving a vehicle in a simulated traffic environment. The agent learns to cross an intersection without traffic lights, under a dense traffic setting, while avoiding collisions with other vehicles.
In the first part of this experiment, I train the agent using the attention-based architecture. I then study the behaviour of the agent by applying some interpretability techniques to the trained Q-network, and find evidence that the network comprises three groups of layers serving specific functions: sensory (embedding layers), processing (attention layer) and motor (output layers).
The purpose of this experiment is to gain a deeper understanding of the agent from an interpretability perspective, which may be useful for developing safer agents in real-world applications.
With the increasing usage and deployment of autonomous vehicles on roads, it is important that the behaviour of these autonomous agents is thoroughly tested, understood and analysed from a safety perspective.
To measure safety, one of the traditional approaches involves running a large number of experiments in a simulated environment and collecting statistics on the number of failure cases. While this is certainly useful and gives an overall perspective on the failure modes of the agent, it says little about the specifics of those failures.
One may prefer to delve deeper and study why a particular agent failed, and whether that behaviour was a consequence of the agent's actions or of the environment's conditions. This calls for applying interpretability techniques to the agent's behaviour and deriving more specific conclusions on what features the agent senses from the environment and what decisions it takes.
For this experiment, I study the behaviour of a trained agent by applying interpretability techniques to the policy network of the model, and share my observations and the conclusions derived from the experiment.
The agent under study is trained and deployed only in a simulated environment (with many simplifications) and is far from a real-world setting and its complexities. While this does not fully represent the behaviour of an agent in the real world, I still think a study like this can be worthwhile in providing insights into the kind of decision-making process the agent learns and how that can be used to make agents safer.
Now, to give some context on the problem, let's first understand how the environment is set up and how the agent interacts with it. Then I will share my observations and insights.
The environment used in this experiment is intersection-env, a customised gymnasium environment with the standard agent-environment loop. The environment contains a total of N = 15 vehicles at any given point in time.
The agent controls the green vehicle, while the blue vehicles are simulated by a traffic-flow model, in this case the Intelligent Driver Model. The Intelligent Driver Model is less nuanced and lacks complex behaviour compared to the learned agent. The blue vehicles are initially spawned at random positions.
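Below is a minimal sketch of how such an agent-environment loop can be set up, assuming the environment is the intersection-v0 environment from the highway-env package (the configuration used for training is listed in the Appendix); the random action here is only a stand-in for the trained policy.

```python
import gymnasium
import highway_env  # noqa: F401 -- importing registers the intersection-v0 environment

# Create the intersection environment used in this experiment.
env = gymnasium.make("intersection-v0", render_mode="rgb_array")
obs, info = env.reset(seed=0)

done = truncated = False
while not (done or truncated):
    # A trained agent would pick SLOWER / NO-OP / FASTER here;
    # a random action keeps this sketch self-contained.
    action = env.action_space.sample()
    obs, reward, done, truncated, info = env.step(action)
env.close()
```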
Below is an animation of a trained agent crossing the intersection.
The joint observation of the road traffic, with the agent's vehicle denoted s_0 and the other vehicles denoted s_1 … s_N, is described by a combined list of individual vehicle states:

S = [s_0, s_1, …, s_N]

where each s_i is the feature vector of vehicle i. The individual state variables are described as follows:
Feature | Description |
---|---|
presence | Disambiguates vehicles at zero offset from non-existent vehicles. |
x | World offset of the ego vehicle, or offset to the ego vehicle, on the x axis. |
y | World offset of the ego vehicle, or offset to the ego vehicle, on the y axis. |
vx | Velocity of the vehicle on the x axis. |
vy | Velocity of the vehicle on the y axis. |
heading | Heading of the vehicle in radians. |
cosh | Cosine of the vehicle's heading. |
sinh | Sine of the vehicle's heading. |
The vehicle kinematics are described by the Kinematic Bicycle Model. More on this topic can be found here.
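For intuition, here is a minimal sketch of one update step of a kinematic bicycle model (a simplified, illustrative version; the variable names, wheelbase and time step are not taken from the environment's implementation):

```python
import numpy as np

def bicycle_step(x, y, heading, speed, steering, acceleration,
                 wheelbase=2.5, dt=0.1):
    """One Euler step of a simplified kinematic bicycle model."""
    # Slip angle at the centre of gravity induced by the front-wheel steering.
    beta = np.arctan(0.5 * np.tan(steering))
    x += speed * np.cos(heading + beta) * dt
    y += speed * np.sin(heading + beta) * dt
    heading += (speed / wheelbase) * np.sin(beta) * dt
    speed += acceleration * dt
    return x, y, heading, speed
```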
The agent drives the vehicle by controlling its speed, chosen from a finite set of actions A = {SLOWER, NO-OP, FASTER}.
Rewards:
Reward | Condition |
---|---|
1 | Agent driving at maximum velocity |
-5 | On collision |
0 | Otherwise |
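The reward scheme in the table can be sketched as a simple function (illustrative only; the environment computes this internally):

```python
def compute_reward(speed: float, max_speed: float, collided: bool) -> float:
    """Reward from the table above: -5 on collision, +1 at maximum velocity, 0 otherwise."""
    if collided:
        return -5.0
    return 1.0 if speed >= max_speed else 0.0
```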
The agent used in this experiment is trained with a DQN algorithm and an attention-based architecture, first proposed in the paper Social Attention for Autonomous Decision-Making in Dense Traffic [1].
For this experiment, I delve deeper into the agent's policy network, since that network encodes the decision making of the agent.
Here is what the network looks like:
Layer Name | Dimensions |
---|---|
Ego & Others embedding layer - 0 | 7x64 |
Ego & Others embedding layer - 1 | 64x64 |
Attention Layer Query | 64x64 |
Attention Layer Key | 64x64 |
Attention Layer Value | 64x64 |
Output Layer - 0 | 64x64 |
Output Layer - 1 | 64x64 |
Output Layer - predict | 64x3 |
Embedding layers:
The network is composed of several identical linear encoders, a stack of attention heads, and a linear decoder.
There are two embedding layers:
1. Ego embedding - dedicated to encoding the features of the vehicle driven by the agent itself.
2. Others embedding - dedicated to encoding the features of the vehicles driven by other agents.
Attention layers:
Essentially, a single query Q = [q_0], emitted by the ego vehicle, and a set of keys K = [k_0, …, k_N], one per vehicle (including the agent's own vehicle), are computed by linear projections of the embedded environment state.
The outputs from all heads are finally combined with a linear layer, and the resulting tensor is then added back to the ego encoding, as in residual networks.
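The computation can be sketched as follows (a simplified single-head version, assuming 64-dimensional embeddings and the projection layers listed in the Appendix; the actual module splits these projections into 4 heads before recombining them):

```python
import torch
import torch.nn.functional as F

def ego_attention(ego_emb, others_emb, W_q, W_k, W_v, W_c):
    """Single-query attention: the ego embedding emits the query,
    all vehicles (ego + others) emit keys and values.

    ego_emb:    (batch, 1, 64) embedding of the agent's vehicle
    others_emb: (batch, N, 64) embeddings of the other vehicles
    W_q, W_k, W_v, W_c: linear projections (64 -> 64, no bias)
    """
    all_emb = torch.cat([ego_emb, others_emb], dim=1)        # (batch, N+1, 64)
    q = W_q(ego_emb)                                         # (batch, 1, 64)
    k = W_k(all_emb)                                         # (batch, N+1, 64)
    v = W_v(all_emb)                                         # (batch, N+1, 64)
    scores = q @ k.transpose(-2, -1) / (q.shape[-1] ** 0.5)  # (batch, 1, N+1)
    attn = F.softmax(scores, dim=-1)
    out = W_c(attn @ v)                                      # (batch, 1, 64)
    return ego_emb + out, attn                               # residual connection
```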
The authors of the paper claim that an agent with the proposed attention architecture shows performance gains in autonomous decision making under a dense traffic setting. Their study compared the performance of the agent against common architectures such as FCN and CNN, and the social interaction patterns with other vehicles were visualised and studied qualitatively.
As part of this experiment, I first replicate the behaviour of the agent as described by the authors of the paper. Then, I proceed to study the agent's behaviour by borrowing some well-recognised interpretability techniques from the literature, such as Understanding RL Vision [2] and A Mathematical Framework for Transformer Circuits [3].
After collecting the observations, I derive some key insights on the behaviour of the agent and mention some interesting directions for future work.
The techniques applied in this experiment are as follows:
- Visualisation of learned attention patterns during evaluation
- Analysis of the embedding-layer weights
- Analysis of QK and OV circuits in the attention layer
- Per-time-step activation analysis over full episodes
- Analysis of feature importance with respect to the output layer
Source code for training the agent and details on choice of hyper-parameters, model architecture parameters and extended results are in the Appendix section.
The first step is to replicate the results of the paper by training and evaluating the agent in a dense traffic setting with a single road intersection and no traffic lights.
From my observations, I confirm that the agent successfully learns to cross the intersection while avoiding collisions with other vehicles in most scenarios.
The following animation shows the trained agent navigating through the intersection, along with its attention patterns at the time step when the agent decides to slow down after noticing another vehicle in its way.
Notes:
The above results confirm that the agent learns to navigate through the crossing, avoiding collisions in most scenarios, by paying attention to the other vehicles on the crossing. Attention to other vehicles at every time step of the episode is highlighted by thick coloured lines from the green vehicle to the blue vehicles.
First, I analyse what insights can be found in the weights of the embedding matrices. The embedding layers each consist of two hidden layers: W_0, the first layer with dimension 7x64, and W_1, the second layer with dimension 64x64.
If the combined embedding matrix is computed as W_emb = W_0 · W_1, then its dimension is 7x64.
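A sketch of how this combined matrix can be read off a trained checkpoint (the parameter names follow the architecture printed in the Appendix; the file path is hypothetical, and the product ignores biases and the non-linearity between the two layers, so it is only a linear approximation):

```python
import torch

# Hypothetical path; assumes the checkpoint stores the network's state_dict.
state_dict = torch.load("checkpoint.pt", map_location="cpu")

# First (7x64) and second (64x64) linear layers of the ego embedding.
# nn.Linear stores weights as (out_features, in_features), hence the transposes.
W0 = state_dict["ego_embedding.layers.0.weight"].T   # 7 x 64
W1 = state_dict["ego_embedding.layers.1.weight"].T   # 64 x 64

W_emb = W0 @ W1                                      # 7 x 64 combined embedding matrix
print(W_emb.shape)
```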
Inspecting the embedding matrices of both layers individually reveals the following:
Notes:
- The Ego embedding shows its strongest weights on the y and cosh features.
- The Others embedding shows its strongest weights on the x, y, vy and sinh features.

This shows that the sensory function of the model is picking up some interesting signals from the environment.
Next I analyse how these features interact with the attention layers of the network.
In their research, the Anthropic team outlines a mathematical framework for understanding attention layers [3].
Quoting from the paper:
Attention heads can be understood as having two largely independent computations: a QK (“query-key”) circuit which computes the attention pattern, and an OV (“output-value”) circuit which computes how each token affects the output if attended to.
Here, I study emerging QK and OV circuit patterns in the attention layer. To study the emergence of learned structure, I compare the untrained and trained networks and compute their squared differences.
As seen from the code, the agent in this experiment has only one attention layer with 4 heads, whose projection vectors are first sliced into 4 parts, computed per head and later concatenated. Hence, for simplicity, I computed the QK and OV circuits for the combined matrices instead of slicing them into 4 parts.
In feature space, the combined circuits can be written as

W_QK = W_emb_ego · W_Q · (W_emb_others · W_K)^T

where W_emb_ego and W_emb_others are the combined embedding matrices and W_Q, W_K are the query and key projections, and

W_OV = W_emb_others · W_V · W_C

where W_V is the value projection and W_C is the attention-combine layer.
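A sketch of how these combined-matrix circuits can be computed from the checkpoint (same caveats as above: biases and non-linearities are ignored, and the parameter names follow the Appendix):

```python
import torch

def feature_space_circuits(state_dict):
    """Project the QK and OV circuits back into the 7-dimensional feature space."""
    # Combined (linearised) embedding matrices, 7 x 64 each.
    W_ego = (state_dict["ego_embedding.layers.0.weight"].T
             @ state_dict["ego_embedding.layers.1.weight"].T)
    W_oth = (state_dict["others_embedding.layers.0.weight"].T
             @ state_dict["others_embedding.layers.1.weight"].T)

    W_q = state_dict["attention_layer.query_ego.weight"].T          # 64 x 64
    W_k = state_dict["attention_layer.key_all.weight"].T            # 64 x 64
    W_v = state_dict["attention_layer.value_all.weight"].T          # 64 x 64
    W_c = state_dict["attention_layer.attention_combine.weight"].T  # 64 x 64

    # QK circuit: how ego features (rows) attend to other-vehicle features (columns).
    qk = W_ego @ W_q @ (W_oth @ W_k).T      # 7 x 7
    # OV circuit: how attended features are written back towards the output.
    ov = W_oth @ W_v @ W_c                  # 7 x 64
    return qk, ov
```

Applying the same function to an untrained and a trained checkpoint and taking the element-wise squared difference gives the comparison discussed above.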
QK and OV circuits:
Furthermore, the attention score and output value matrices, shown in Fig 9a and 9b (in the Appendix), display some interesting learned structural patterns.
The above figures show distinctions between the two scenarios: one where the agent is untrained and the other where it is trained. The QK and OV circuits show high activations for certain rows and areas of the heat map, indicating a learned structure arising from the interplay between the agent, other vehicles and the environment.
Notes:
- The y-coordinate is attended to with a strong positive correlation.
- vy is attended to with a negative correlation.
- There are interactions between the ego y coordinate and the other vehicles' x, y coordinates and vy velocity, hinting at the possibility of computing a distance metric.
- There are high activations between the y coordinate feature and the Others embedding presence feature.
- The ego velocity (vx, vy) and heading (cos_h, sin_h) show high values, meaning the model values its own speed and direction when choosing actions.
- The positions (x, y) of other vehicles become crucial. The agent learns to consider the positions of other vehicles when deciding whether to slow down, idle, or speed up.

Let's study the agent's activations per time step and the intermediate attention matrices collected over a full episode run.
Here, I extract the environment frames and the attention-head vs. vehicle activations for each time step of the evaluation, which shows the following:
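These per-step attention matrices can be collected with a forward hook on the attention layer, sketched below (assumes model is the loaded EgoAttentionNetwork from the Appendix, env is the evaluation environment, and that the attention module returns its attention matrix as the second element of its output):

```python
import torch

attention_trace = []  # one (heads, vehicles) attention matrix per time step

def save_attention(module, inputs, output):
    # Assumption: the attention module returns (combined_output, attention_matrix).
    attention_trace.append(output[1].detach().cpu().squeeze(0))

hook = model.attention_layer.register_forward_hook(save_attention)

obs, info = env.reset(seed=0)
done = truncated = False
while not (done or truncated):
    with torch.no_grad():
        q_values = model(torch.as_tensor(obs, dtype=torch.float32).unsqueeze(0))
    action = int(q_values.argmax(dim=-1))  # greedy action from the Q-values
    obs, reward, done, truncated, info = env.step(action)

hook.remove()
```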
Notes:
- Vehicle_0 is the green vehicle controlled by the agent, and hence is always attended to by all 4 attention heads at every time step.

Next, I study the output layers further to understand what insights can be drawn from them.
In this section I try to understand which features the model learns to extract from the environment state. One of the common techniques for finding feature importance is computing integrated gradients, which give an understanding of the overall importance of each input feature.
For this scenario, I compare the integrated gradients of the untrained and trained networks, averaged over 30 episodes, and find the following.
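A minimal sketch of integrated gradients over the Q-value of a chosen action (a manual implementation; the zero baseline and the number of interpolation steps are illustrative choices, not necessarily those used in this experiment):

```python
import torch

def integrated_gradients(model, obs, action, steps=50):
    """Approximate integrated gradients of Q(obs, action) w.r.t. the input features.

    obs: tensor of shape (vehicles, features); a zero observation is used as the baseline.
    """
    baseline = torch.zeros_like(obs)
    total_grads = torch.zeros_like(obs)
    for alpha in torch.linspace(0.0, 1.0, steps):
        # Interpolate between the baseline and the actual observation.
        point = (baseline + alpha * (obs - baseline)).unsqueeze(0).requires_grad_(True)
        q = model(point)[0, action]
        total_grads += torch.autograd.grad(q, point)[0].squeeze(0)
    # Average gradient along the path, scaled by the input difference.
    return (obs - baseline) * total_grads / steps
```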
Notes:
- The Q-net's output layer makes its final decision by looking at the presence, x-y coordinate and sinh features.

All the above graphs add more evidence to the notes and observations gathered earlier.
In the following section, I make some speculative interpretations about the agent. I would love to validate some of these interpretations by conducting more thorough experiments in the future. The claims I am not confident about are marked inline.
The following is a walkthrough of the agent in action, along with interpretations of the key observations and notes collected so far.
Step 1: Input Features
The agent observes its environment using the following features:
- vx, vy: Ego vehicle's velocity (speed in the x and y directions).
- cos_h, sin_h: Ego vehicle's heading direction.
- x, y: Ego vehicle's position.
- presence: Indicates whether another vehicle is nearby.
- vx, vy: Other vehicle's velocity.
- x, y: Other vehicle's position.

Step 2: Attention Mechanism (QK Circuits)
The Query-Key (QK) circuits determine which features should be attended to when making a decision.
- The ego vehicle attends to its own vx, vy, cos_h, and sin_h to understand its motion.
- It attends to its position (mainly the y coordinate) to determine where it is in the intersection. (Why does the x coordinate not have high activations? Something that I would like to investigate later.)
- It attends to presence to check if another vehicle is nearby.
- It attends to the vx, vy, x, y of other vehicles to predict their movement.
- Attention between presence and vx increases. (Unverified claim: is the model computing some distance metric? Can we verify this?)
- Attention is high when another vehicle's presence (the presence feature) is high.
- Attention is high when another vehicle's velocity (vx, vy) suggests a collision risk.
- cos_h and sin_h influence the decision, meaning the agent aligns with the road orientation.
- Acceleration correlates with vx and sin_h, meaning the model prefers accelerating when aligned with the road.

Step 3: Q-Value Calculation (OV Circuits)
The Output-Value (OV) circuits determine how much each feature contributes to the Q-values for each action.
Feature Contributions to Actions:
- When presence is high, the Slow action gets higher weight.
- When the vx of the other car is high, meaning it is moving fast toward the intersection, the agent reduces the Q-value for the Fast action.
- When the ego vx is high but a collision is possible, the Q-value for Idle is also reduced.

This experiment shows that a DQN agent with an attention-based mechanism can learn to cross a road intersection under a dense traffic setting with reasonable levels of safety.
Additionally, analysis of the attention layer of the agent's Q-network provides evidence that it learns some high-level policies that drive the agent's decision making. Although it was possible to find some of these high-level policies, more work is needed to understand how the different policies combine to form a concrete algorithm.
It was shown qualitatively, with some level of confidence, that the agent learns to delegate different functions to its embedding, attention and output layers, which serve sensory, processing and motor functions respectively.
This experiment was limited in scope and time (about 4 weeks). For this reason, I chose to focus on replicating the behaviour of the agent and running various interpretability techniques to narrow down on a promising approach for characterising the exact behaviour of the agent in further research.
The following are some areas that can be explored in the future:
- How does the agent compute distances from the (x, y) coordinates of the other vehicles?
- The agent appears to learn policies around the presence feature and high activations of the vx feature of other vehicles. Do these two policies correlate highly for the agent or largely stay independent?

My sincere thanks to this amazing community, which has made interpretability research easily accessible to the general public. I hope that my experiments bring some value to others and to this community. I look forward to delving deeper into this topic; any support and guidance is highly appreciated.
I would also like to thank BlueDot Impact for running a 12-week online course on AI Safety Fundamentals. I conducted this experiment as part of the project submission phase of this course, and I am grateful to the course facilitators and their team for conducting amazing sessions and providing a comprehensive list of resources on the key topics.
I'm looking forward to collaborating. Reach out to me on
My portfolio
LinkedIn
Glossary:
DQN: Deep Q-Network
FCN: Fully Convolutional Net
CNN: Convolutional Neural Net
QK Circuit: Query-Key Circuit
OV Circuit: Output-Value Circuit
Model architecture:
EgoAttentionNetwork(
(ego_embedding): MultiLayerPerceptron(
(layers): ModuleList(
(0): Linear(in_features=7, out_features=64, bias=True)
(1): Linear(in_features=64, out_features=64, bias=True)
)
)
(others_embedding): MultiLayerPerceptron(
(layers): ModuleList(
(0): Linear(in_features=7, out_features=64, bias=True)
(1): Linear(in_features=64, out_features=64, bias=True)
)
)
(attention_layer): EgoAttention(
(value_all): Linear(in_features=64, out_features=64, bias=False)
(key_all): Linear(in_features=64, out_features=64, bias=False)
(query_ego): Linear(in_features=64, out_features=64, bias=False)
(attention_combine): Linear(in_features=64, out_features=64, bias=False)
)
(output_layer): MultiLayerPerceptron(
(layers): ModuleList(
(0-1): 2 x Linear(in_features=64, out_features=64, bias=True)
)
(predict): Linear(in_features=64, out_features=3, bias=True)
)
)
Environment configuration:
N: number of vehicles (15)
Observation type: Kinematics
Observation space: 7 features per vehicle (presence, x, y, vx, vy, cos_h, sin_h)
Action space: 3 {SLOWER, NO-OP, FASTER}
Hyper-parameters:
Gamma: 0.95
Replay buffer size: 15000
Batch size: 64
Exploration strategy: Epsilon greedy
Tau: 15000
Initial temperature: 1.0
Final temperature: 0.05
Evaluation:
Running evaluation over 10 episodes with display enabled shows high scores and successful navigation through the intersection.
[INFO] Episode 0 score: 8.6
[INFO] Episode 1 score: 5.5
[INFO] Episode 2 score: 3.0
[INFO] Episode 3 score: 9.6
[INFO] Episode 4 score: 8.5
[INFO] Episode 5 score: -1.0
[INFO] Episode 6 score: 9.0
[INFO] Episode 7 score: -1.0
[INFO] Episode 8 score: 6.5
[INFO] Episode 9 score: 7.6
Learned Attention scores: