TL;DR: Neuroscientists face the same interpretability problem as AI safety researchers: complex, inscrutable systems with thousands of parameters that transform inputs to outputs. I worked on a systematic method to find the minimal features that capture the input-output computation under specific conditions. For cortical neurons with thousands of morphological/biophysical parameters, just three features (spatial input distribution, temporal integration window, recent activation history) predicted responses with 97% accuracy. This approach of systematically searching for a sufficient, interpretable set of features that captures the input-output transformation under a given condition seems transferable to mechanistic interpretability of artificial neural networks.
Epistemic status: Quite confident about the neuroscience methodology (it's part of my PhD thesis work and is published in a peer-reviewed journal). Uncertain about direct applicability to AI interpretability. This is "here's a tool that worked in a related domain", not "here's the solution to interpretability."
Wait, we're solving the same problem
As I neared the end of my PhD and started looking into AI safety research as something I might want to do next, I was surprised to find that neuroscientists and AI interpretability researchers are working on really similar problems, but we rarely talk to each other.
Both of us have complex, multilayered systems that do something interesting when you give them inputs, and we would really like to know what underlying computation they're actually performing. However, both of us have way too many interacting parameters to reason about all of them simultaneously.
A common approach in neuroscience has been to build very detailed (sometimes billion-dollar) models which are very realistic, and then... stare at them really hard and hope that understanding falls out? This lack of meaningful methods for interpreting such models is starting to be discussed in neuroscience, and I think AI might have a head start here simply by having a field explicitly called "interpretability".
What if we're asking the wrong question?
Neuroscientists spend a lot of time trying to understand everything about how cortical neurons compute. We want to know how every dendritic branch contributes, how calcium spikes in the dendrites interact with sodium spikes at the soma, and how NMDA receptors enable nonlinear integration.
What if most of that complexity doesn't matter for the specific behaviour I care about?
Not "doesn't matter" in the sense that it's not happening: neurons definitely have calcium spikes and NMDA nonlinearities. But "doesn't matter" in the sense that, in some cases, you could predict the neuron's output just fine without modelling all that detail.
This led to a different question: What is the minimal set of features that can predict the system's behaviour under the conditions I actually care about?
This is the question that I worked on together with my colleague Arco Bast, first during my master thesis, and then continued to develop during my PhD.
The methodology: systematic reduction
Quick neuroscience introduction
Neurons in the cerebral cortex receive thousands of inputs per second from thousands of other neurons. They receive these inputs onto their “dendrites”, which branch off from the cell body ("soma"), via “synapses”, the connection points between two neurons. Cortical neurons use discrete signals: they either produce an output spike or they don’t. Revealing how synaptic inputs drive spiking output remains one of the major challenges in neuroscience research.
1. Narrow things down to a specific condition
There's a temptation to want general interpretability: to understand the model in all contexts. The problem is, you tend to face some kind of trade-off between accuracy, interpretability and generalisability (pick two).
For this reason, we chose the condition of sensory processing of a passive whisker touch in anaesthetised rats. This is a well-characterised condition for which lots of experimental data exists, and for which we have built a highly detailed multi-scale model from that data. (We need to use a model here because we need to quantify the synaptic input activity to a neuron, which is not currently feasible experimentally; another point where AI interpretability has an advantage!)
2. Don't formulate hypotheses
We didn’t make any top-down assumptions or hypotheses about what the input-output computation of the neurons could look like. We started with biophysically detailed multi-compartmental neuron models embedded in an anatomically realistic network model. These models can reproduce calcium spikes, backpropagating action potentials, bursting, the whole repertoire of cortical neuron activity. They've been validated against experimental data, and when we simulate sensory responses, they match what we see experimentally in actual rat brains.
3. Let the data tell you what's important (search for predictive features)
Instead of hypothesising which features of the input might be important for predicting the neuron’s output, we systematically searched for them in the data. We spent quite some time systematically and iteratively trying different ways of grouping and weighting synaptic inputs, and then comparing the prediction accuracy of the resulting reduced models, eventually deciding to group by:
Time of activation: was this synapse active 1ms ago? 5ms ago? 50ms ago?
Distance from soma: is this synapse close to the cell body, where the output spike can be initialised, or way out in the dendrites?
Excitatory vs inhibitory: you can generally think of excitatory synapses as positively weighted connections, which make the receiving neuron more likely to produce an output spike, and inhibitory synapses as the opposite.
Then we used optimisation to find weights for each group that maximised prediction accuracy. Basically: "How much should I weight an excitatory synapse that's 300μm from the soma and was active 4ms ago to predict if the neuron spikes right now?"
This gave us spatiotemporal filters, which in this case are continuous functions describing how synaptic inputs at different times and locations contribute to output.
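To make that step concrete, here is a minimal sketch of what "group synapses by time, distance and E/I, then fit one weight per group" can look like. The bin sizes, the synthetic data, and the use of an off-the-shelf logistic regression are all illustrative assumptions, not our actual pipeline:

```python
# Minimal sketch (not our actual pipeline): bin each trial's synaptic
# activations into coarse (E/I x time x distance) groups, then fit one
# weight per group with logistic regression. The fitted weights, reshaped,
# play the role of discretised spatial/temporal filters. Bin sizes and the
# synthetic data are illustrative assumptions.
import numpy as np
from sklearn.linear_model import LogisticRegression

time_bins = np.arange(0, 55, 5)      # 0-50 ms before the prediction point
dist_bins = np.arange(0, 700, 100)   # 0-600 um from the soma

def featurise(syn_events):
    """Count active synapses per (E/I, time bin, distance bin) group."""
    counts = np.zeros((2, len(time_bins) - 1, len(dist_bins) - 1))
    for time_ago_ms, dist_um, is_inhibitory in syn_events:
        ti = np.searchsorted(time_bins, time_ago_ms, side="right") - 1
        di = np.searchsorted(dist_bins, dist_um, side="right") - 1
        if 0 <= ti < counts.shape[1] and 0 <= di < counts.shape[2]:
            counts[int(is_inhibitory), ti, di] += 1
    return counts.ravel()

# Synthetic stand-in data, just so the sketch runs end to end.
rng = np.random.default_rng(0)
trials = [[(rng.uniform(0, 50), rng.uniform(0, 600), rng.integers(0, 2))
           for _ in range(rng.integers(50, 200))] for _ in range(2000)]
X = np.stack([featurise(tr) for tr in trials])
y = rng.integers(0, 2, size=len(trials))          # placeholder spike labels

model = LogisticRegression(max_iter=1000).fit(X, y)
filters = model.coef_.reshape(2, len(time_bins) - 1, len(dist_bins) - 1)
# Inspecting `filters` (excitatory/inhibitory x time x distance) is what
# recovers interpretable filter shapes from the fitted group weights.
```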
We took those filters and built generalised linear models (GLMs). In testing, it turned out that we also needed to include the spike history of our neuron, because real neurons can’t just fire arbitrarily fast. Basically:
weighted_net_input = Σ over synapses [ spatial_filter(distance) × temporal_filter(time_ago) ]
P(spike) = nonlinearity(weighted_net_input − post_spike_penalty)
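As a sketch of the reduced model's forward pass, assuming the spatial and temporal filters have already been fitted. The logistic nonlinearity, the exponential decay of the post-spike penalty, and all parameter values below are assumptions for illustration, not the published parametrisation:

```python
# Minimal sketch of the reduced model's forward pass, assuming the filters
# above are already fitted. Variable names, the logistic nonlinearity and
# the exponential post-spike penalty are illustrative assumptions.
import numpy as np

def predict_spike_prob(active_synapses, time_since_last_spike_ms,
                       spatial_filter, temporal_filter,
                       bias=-3.0, penalty_scale=5.0, penalty_tau_ms=3.0):
    """P(spike now) from recent synaptic input plus the neuron's own history.

    active_synapses: iterable of (distance_um, time_ago_ms, sign) tuples,
                     with sign = +1 (excitatory) or -1 (inhibitory).
    """
    weighted_net_input = sum(
        sign * spatial_filter(dist) * temporal_filter(dt)
        for dist, dt, sign in active_synapses
    )
    # Refractoriness: a penalty that decays with time since the last spike,
    # so the model cannot fire arbitrarily fast.
    post_spike_penalty = penalty_scale * np.exp(
        -time_since_last_spike_ms / penalty_tau_ms
    )
    drive = bias + weighted_net_input - post_spike_penalty
    return 1.0 / (1.0 + np.exp(-drive))  # logistic nonlinearity

# Toy usage with made-up filters: close, recent, excitatory input counts most.
spatial = lambda d: np.exp(-d / 200.0)    # weight decays with distance from soma
temporal = lambda dt: np.exp(-dt / 10.0)  # weight decays with time since input
p = predict_spike_prob([(50.0, 2.0, +1), (300.0, 4.0, +1), (100.0, 1.0, -1)],
                       time_since_last_spike_ms=20.0,
                       spatial_filter=spatial, temporal_filter=temporal)
```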
What the reduced model told us about neuronal computation
That's it. Despite all the complexity in the original system, all you need to do to predict spiking output under this condition is count active synapses, weight them by location and timing, subtract a penalty if the neuron just fired, and pass that through a nonlinearity.
The reduced model predicted action potential output with 97% accuracy.
And here's the really surprising part: We tested this across seven different neuron models with very different dendritic morphologies, ion channel densities and distributions. They all performed qualitatively the same computation. The filters had slightly different shapes (e.g. scaling with dendrite thickness), but the core input-output transformation was the same.
[Figure: reduced models for 7 different neuron models]
The insights that might be useful for AI interpretability
1. Focus on a specific condition
In neuroscience, other approaches have tried to build models that captured neuronal responses in all possible experimental conditions (e.g. Beniaguev et al. (2021), who used an 8-layer deep neural network to represent a single neuron). These models end up being so complex that they aren't interpretable. When we constrained to one specific condition, we could actually understand what was happening.
For AI safety: it might be better to prioritise deeply understanding behaviour in safety-critical conditions than shallowly understanding behaviour in general.
If you want to prevent deceptive alignment, you don't need to understand everything GPT-4 does; you mainly need to understand what it does when deception would be instrumentally useful. Figure out the input-output transformation in that condition, and it might be simple enough to reason about.
2. Focus on computation, not implementation
When I analysed what drives response variability (i.e., why different neurons respond differently to the same stimulus), I found that network input patterns (which synapses are active when) were the primary determinant of response differences, while morphological diversity and biophysical properties had only a minor influence.
What does this mean? Two neurons with completely different "architectures" perform the same computation. The variability in their outputs comes almost entirely from variability in their inputs, not their internal structure.
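As an illustration of the kind of comparison behind that claim (my own toy framing with synthetic numbers, not the analysis from the paper): given responses of several neuron models to many shared input patterns, you can ask how much of the total response variance is explained by which input pattern was delivered versus which neuron model received it.

```python
# Toy variance decomposition (synthetic data, illustrative only): how much of
# the response variance across (neuron model, input pattern) pairs is carried
# by the input pattern vs. the neuron model? The effect sizes below are made
# up so that inputs dominate, mirroring the qualitative finding.
import numpy as np

rng = np.random.default_rng(1)
n_models, n_inputs = 7, 500

input_effect = rng.normal(0.0, 1.0, size=n_inputs)   # strong input dependence
model_effect = rng.normal(0.0, 0.2, size=n_models)   # weak model dependence
responses = (input_effect[None, :] + model_effect[:, None]
             + rng.normal(0.0, 0.3, size=(n_models, n_inputs)))

var_total = responses.var()
var_inputs = responses.mean(axis=0).var()   # variance across input patterns
var_models = responses.mean(axis=1).var()   # variance across neuron models

print(f"share explained by input pattern: {var_inputs / var_total:.2f}")
print(f"share explained by neuron model:  {var_models / var_total:.2f}")
```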
This suggests a plausible general approach: try focusing interpretability on input patterns and their transformation, not on cataloguing implementation details.
Maybe instead of trying to understand every circuit in GPT-4, we could ask: what input patterns lead to concerning behaviours? What's the minimal transformation from inputs to those behaviours, and can that help us to understand what's going on in the model?
Important Caveats
This worked for one condition: We explicitly focused on passive single-whisker deflections in anaesthetised rats. This was a deliberate choice; we traded generality for interpretability. But it means more complex conditions might need more complex reduced models, and you might need multiple models to cover multiple conditions.
When is simple reduction possible? Some behaviours might not admit simple reduced descriptions. For neurons, active whisking (vs passive touch) requires additional features. For LLMs, some behaviours might be irreducibly complex.
Scale: I worked with single neurons receiving thousands of inputs. LLMs have billions of parameters, and context windows keep getting longer.
Wild Speculation Section
Some half-baked ideas that might be interesting:
Compositional models: Neuroscience has found that the same neuron can perform different computations under different conditions (passive touch vs. active exploration, anaesthetised vs. awake). Could the same be true of LLMs, and can we find different minimal input-output computations for different contexts that get flexibly combined?
Training dynamics: I reduced neurons at one point in time. What if you tracked how the reduced model changes during an LLM’s training? Could you see a phase transition when the model suddenly learns a new feature or strategy?
Universality: I found the same computation across morphologically and biophysically diverse neurons. Is there universality in artificial neural networks? Do different architectures or training runs converge to the same reduced model for the same task?
Neuroscience has been forced to develop systematic approaches to interpretability because we have been struggling to understand biological neural networks with their many interacting parts (we can’t even measure everything at the same time; AI research should have an advantage here!). AI safety is hitting the same constraint with large language models, so maybe sharing some ideas could help.
Background: I just finished my PhD in neuroscience at the Max Planck Institute for Neurobiology of Behavior. My thesis focused on modelling structure-function relationships in neurons and biological neural networks. Now I'm trying to pivot into AI safety because, honestly, I think preventing AGI from nefariously taking over the world is more urgent than understanding rat whisker processing, and I think transferring established methods and approaches from neuroscience to AI makes sense.