Neural net / decision tree hybrids: a potential path toward bridging the interpretability gap

[Meta: I feel pretty confident that these are valuable ideas to explore, but low confidence that the best implementation details would be exactly what's in these linked papers.]

[Background on me: I'm a machine learning engineer with a background in neuroscience. I'm much better at building my ideas than I am at explaining my ideas. I've been a fan of Less Wrong and Eliezer's work for a long time and more recently also of Paul Christiano's. I enjoy reading Rohin Shah and Cody Wild's paper summaries, which is how I came across many of these concepts.]


Eliezer Yudkowsky: ...I'd be hopeful about this line of work primarily because I think it points to a bigger problem with the inscrutable matrices of floating-point numbers, namely, we have no idea what the hell GPT-3 is thinking and cannot tell it to think anything else...

Evan Hubinger: +1 I continue to think that language model transparency research is the single most valuable current research direction within the class of standard ML research, for similar reasons to what Eliezer said above.


As an ML engineer, I have the interest and ability to work on currently testable hypotheses for AI safety. My hope is to find an angle of approach that can be tackled on current models but that still makes progress toward the goal of reducing existential risk from AGI.

To give you some idea of the motivation behind the ideas in this proposal, let’s explore a rough thought experiment about progress in AI over the next few years. I expect that, given that the lottery ticket hypothesis ( ) is correct, it could be valuable to train models starting with something like 1e16 or more parameters, if there is sufficient economic backing. Then, using some yet-to-be-developed analysis technique, one would detect the best architecture found within that possibility space very early in training and delete the unnecessary neurons. This would bring us down to perhaps something roughly in the range of the estimated 3e14 parameters from Ajeya's report ( ). This idea is basically about doing a broad search of an enormous space of possible architectures and then narrowing in on the best thing that is found, hoping for the 'winning lottery ticket' of an architecture capable of general logical reasoning. This would potentially be an unexpectedly powerful model with unprecedented abilities far outstripping anything that had come before.

The trouble is, if this is the direction research takes, it puts us in the position of having an effective and valuable, yet complex and poorly understood model. The creators would have a large financial incentive to use it in practical applications even if they don't fully trust it, because they would want to recoup the significant upfront investment. In this scenario, it would be very appealing to the funding entity to be able to analyze and edit this model without a significant further investment of compute and without substantially decreasing the performance. The question I'm trying to answer in this post is whether there is a tenable path forward from having an enormous black box model to having something comprehensible, editable, and trustworthy.

Given that our current level of interpretability lets us understand around 5e4 parameters well enough to reimplement them ( ), and assuming that we continue to make good progress on this over the next couple of years, I think understanding 5e5 parameters is a reasonable estimate.
As you can see, this leaves us with an 'interpretability gap' between this hypothetical large black box model's 3e14 parameters and the rough estimate of our ability to comprehend 5e5 parameters. I have two ideas which I hope in combination can bridge this gap.

In terms of what can specifically be developed with existing tech given a year or two of research, I think two of the most promising approaches are:
To break the big black boxes of models into smaller, more comprehensible pieces and connect them up into a modular network which can still deliver approximately the same performance.
To constrain the outputs of models to reduce the degrees of freedom of the agent and make the output more easily auditable by humans, both at the output end of the modular network and between the components within the network.

Using just the idea of breaking the network into understandable modules, we'd split our 3e14 parameter model into 6e8 pieces of size 5e5. Since 6e8 is still far too many to individually analyze, we would need to force abstraction / reduce complexity. Even though this process wouldn't necessarily be applied directly to the raw 3e14 parameter model, for the purposes of this estimate we can consider it as reducing the parameter total. Let's say for this thought experiment that we wanted our total number of modules to be around the same number as the parameters within each module, 5e5. If we had 5e5 modules of size 5e5, we'd have roughly 3e11 total parameters. So that leaves our complexity reduction techniques responsible for bringing us down three orders of magnitude from 3e14 to 3e11. That seems pretty plausible to me.
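To make the arithmetic in this estimate easy to check, here is a small worked calculation. The variable names are mine, and the inputs are just the rough figures from the thought experiment above, not measurements.

```python
# Back-of-the-envelope numbers from the thought experiment above.
big_model_params = 3e14   # hypothetical large black-box model
module_params = 5e5       # assumed size that careful inspection can understand

# Splitting the raw model into understandable modules, with no compression.
naive_module_count = big_model_params / module_params

# Target: about as many modules as there are parameters per module.
target_module_count = module_params
target_total_params = target_module_count * module_params

# Reduction the abstraction / compression techniques would have to deliver.
required_reduction = big_model_params / target_total_params

print(f"modules without compression: {naive_module_count:.1e}")
print(f"target total parameters:     {target_total_params:.1e}")
print(f"required reduction factor:   {required_reduction:.0f}x")
```

The target total comes out at 2.5e11, which I round to 3e11 in the text, and the required reduction is a factor of about 1200, i.e. roughly three orders of magnitude.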




Part 1: Breaking down the black boxes

The main thrust of my idea comes from these two papers:

ANTs: Adaptive Neural Trees 

NBDTs: Neural Backed Decision Trees
website w demo:


This idea has been explored in various forms in recent works. Notably, Google's Sparsely-Gated Mixture-of-Experts [ ] architecture seems like a step in this direction, but this idea would take the breaking-down even further.
Ideally, we could break the big black box of a unified model into small enough pieces that they could be individually understood and audited. What's different here from most existing research is that the goal would be to push this breaking-down past the point of peak performance to a point of minor performance loss. Ideally we could get to a point where the black boxes of the 'experts' in the mixture of experts were small enough to audit without too great a drop in performance. If it could be shown that this was possible without too great a loss of performance, then even a very short-term, profit-motivated company might be willing to make that trade for the risk reduction from greatly improved interpretability. This could perhaps be considered an 'interpretability tax'. I think there's quite possibly a good compromise where we can make a modular directed acyclic graph of neural nets which have sufficiently constrained information exchange between modules that it makes sense to conceptualize them as separate. The size of the modules and the complexity of information exchange should be kept to a limit where we can understand and verify individual modules and their communications.

Part 1a: Tree it after training

The Neural Backed Decision Tree method is developed specifically to add interpretability to an existing neural-net-based model.
The basic gist of the idea is to take the last layer of the neural net and reconfigure that layer into a decision tree, then observe how the tree responds to sample inputs in order to label the nodes.
This is very useful considering that new language models substantially larger than GPT-3 have recently been announced, and this trend toward larger, more expensive models may continue if the results are economically rewarding.
It's not cheap to produce a new model from scratch, and organizations are unlikely to choose an unproven type of algorithm to base their expensive model on. Much more tractable is the ability to offer these organizations a tool with which they can gain insight into and control over the model they already have.

I think the value in exploring this alternate architecture is that it seems fundamentally more tractable in terms of legibility and editability for the output. On the downside, there's still quite a large black box since you've only broken apart the last layer. On the gripping hand, this paper shows some great results in terms of matching performance of the original model, and I like what they do with the sklearn agglomerative clustering to reshape the existing neural net into a hierarchy.
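As a rough illustration of the "reshape the last layer into a hierarchy" step, here is a minimal sketch. The paper's code uses sklearn's agglomerative clustering; to keep this dependency-free I substitute a tiny hand-rolled centroid-linkage merge, which is my own simplification and not the paper's exact procedure. The class names and 2-d weight rows are made up for the example.

```python
import math

def centroid(vectors):
    """Element-wise mean of a list of equal-length vectors."""
    n = len(vectors)
    return [sum(col) / n for col in zip(*vectors)]

def euclidean(a, b):
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))

def build_hierarchy(class_weights):
    """Greedily merge the two closest clusters until one root remains.

    class_weights maps class name -> that class's weight row in the final
    linear layer. Returns a nested tuple tree whose leaves are class names.
    """
    # Each cluster is (subtree, list of leaf weight vectors it contains).
    clusters = [(name, [w]) for name, w in class_weights.items()]
    while len(clusters) > 1:
        best = None
        for i in range(len(clusters)):
            for j in range(i + 1, len(clusters)):
                d = euclidean(centroid(clusters[i][1]), centroid(clusters[j][1]))
                if best is None or d < best[0]:
                    best = (d, i, j)
        _, i, j = best
        merged = ((clusters[i][0], clusters[j][0]), clusters[i][1] + clusters[j][1])
        clusters = [c for k, c in enumerate(clusters) if k not in (i, j)] + [merged]
    return clusters[0][0]

# Hypothetical final-layer weight rows for four classes.
weights = {
    "cat": [1.0, 0.9],
    "dog": [0.9, 1.0],
    "car": [-1.0, 0.1],
    "truck": [-0.9, -0.1],
}
tree = build_hierarchy(weights)
print(tree)  # (('cat', 'dog'), ('car', 'truck'))
```

Classes whose weight rows point in similar directions end up under a shared inner node, which is what makes the inner nodes candidates for human-readable labels such as "animal" or "vehicle".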


Example from Neural Backed Decision Trees:



Part 1b: A tree from the start

On the other hand, I think organizing the entire model architecture as a tree has some potential nice properties. If your goal is to make sure there is no unified 'black box' doing incomprehensible untrustworthy computation, then it's not enough to decompose just the last layer of the neural net. You need to deconstruct the entire thing into modular comprehensible pieces. A literal tree may be overly restrictive because you may find that you need some crossing of branches, so perhaps expand this concept to a more general directed acyclic graph.

This could potentially be done with an existing pre-trained model. You could copy the existing layers into the tree architecture with random splits and then fine tune the model to train the splits.

The idea would be to think of the tree as modular. As you identified problematic sections, you could remove that section and replace that particular branch with a different model or a dead end. So it's a way to give humans insight and control into small sections of the model. There are still black boxes, but they are much smaller and interspersed with relatively easily interpretable nodes.
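Here is a minimal sketch of what "modular, with removable branches" could look like, assuming soft (probabilistic) routing at each inner node. The `Node` class, the `dead` flag, and the toy router and leaf functions are all hypothetical names I chose for illustration; real modules would be small neural nets.

```python
import math
from dataclasses import dataclass, field
from typing import Callable, List, Optional

@dataclass
class Node:
    name: str
    # Leaves carry a module; inner nodes carry a router that scores each child.
    module: Optional[Callable[[List[float]], float]] = None
    router: Optional[Callable[[List[float]], List[float]]] = None
    children: List["Node"] = field(default_factory=list)
    dead: bool = False  # set by a human auditor to cut off this branch

def softmax(scores):
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def forward(node, x):
    """Probability-weighted output of the live leaves under `node`.

    Returns None if the whole subtree has been dead-ended. Routing
    probabilities are renormalized over the children that are still live.
    """
    if node.dead:
        return None
    if not node.children:
        return node.module(x)
    probs = softmax(node.router(x))
    live = []
    for p, child in zip(probs, node.children):
        out = forward(child, x)
        if out is not None:
            live.append((p, out))
    if not live:
        return None
    total = sum(p for p, _ in live)
    return sum(p * out for p, out in live) / total

# Toy tree: one router, two leaf "experts".
left = Node("left", module=lambda x: sum(x))
right = Node("right", module=lambda x: -sum(x))
root = Node("root", router=lambda x: [x[0], -x[0]], children=[left, right])

x = [1.0, 2.0]
before = forward(root, x)   # blend of both leaves
right.dead = True           # auditor removes the problematic branch
after = forward(root, x)    # only the left leaf contributes now
print(before, after)
```

Replacing a branch with a different model would amount to swapping the `module` (or the whole child `Node`) rather than setting `dead`.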

The goal would not be to break it down too far, since having to analyze more pieces also introduces its own complexity and cognitive burden and furthermore over-constriction likely will impair the model's peak competence. Rather the goal is to find the right balance where each 'chunk' of neural net is within the range of expert analysis. One possible estimate of the right size for "careful human inspection can understand this" is the 50k parameters mentioned in the curve-detector reimplementation as summarized in this alignment newsletter. (I haven't read the original work, just this summary.)
In practice, we could test our understanding of a given module by attempting to reimplement and substitute just that piece and then testing the overall network to see if we can get it back to a very similar level of performance after more training.

This could also give you more control over the training process itself. For instance, if one particular branch is performing poorly you could direct more training resources to it specifically. Use the preceding nodes to the target parent node in relatively cheap 'inference mode' to gather data out of a large noisy dataset which ends up at the target parent node. Then train that branch on this new selected dataset. This gets at the need to put extra training resources (data and compute) into chosen target regions. This could be chosen automatically by the model (places where loss was high), or chosen by a human operator (places where value was high, such as e.g. factual medical knowledge). Used in conjunction with the human operator's ability to 'dead end' non-valued branches (e.g. troll comments), this would make the model quite 'steerable' during the training process as well as 'editable' afterwards.
If you had a trained model with a branch which you knew corresponded to a specific expert knowledge area you wanted to query, you could query just that specific branch by forcing the model to do inference using just that branch and its parent nodes.
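The data-selection step described above can be sketched as follows, assuming hard (argmax) routing for simplicity. The function names, the `path` representation, and the toy routers are my own assumptions for illustration.

```python
def reaches_target(sample, path):
    """True if `sample` is routed along `path`.

    `path` is a list of (router, expected_child_index) pairs, ordered from
    the root down to the parent of the branch we want to train. Each router
    is run in cheap inference mode and returns one score per child.
    """
    for router, expected in path:
        scores = router(sample)
        if scores.index(max(scores)) != expected:
            return False
    return True

def select_for_branch(dataset, path):
    """Filter a large noisy dataset down to the samples that reach the branch."""
    return [s for s in dataset if reaches_target(s, path)]

# Toy routers: the first sends samples with x[0] > 0 to child 0,
# the second sends samples with x[1] > 0 to child 1.
root_router = lambda x: [x[0], -x[0]]
mid_router = lambda x: [-x[1], x[1]]
path_to_branch = [(root_router, 0), (mid_router, 1)]

noisy_dataset = [[1, 1], [1, -1], [-1, 1], [2, 3]]
branch_data = select_for_branch(noisy_dataset, path_to_branch)
print(branch_data)  # [[1, 1], [2, 3]]
```

The resulting subset could then be used to train only the chosen branch; the same path could also be forced at inference time to query just that branch.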

Also, the ability to selectively identify and delete abilities from the model could be very useful if you had a reason to constrain the model in a particular domain.

Eliezer Yudkowsky: (I think you want an AGI that is superhuman in engineering domains and infrahuman in human-modeling-and-manipulation if such a thing is at all possible.)


Each node essentially has just a segment of the dataset, probabilistically assigned but with most of the probability mass concentrated in some way. You can have another separate algorithm summarize / describe this subset of data in a human readable way. For a language model, this description could be something like a word cloud, or some other sort of example summary.
As each parcel of information gets processed through the model, the original untransformed data is kept associated to facilitate the generation of the human-readable description.

If you needed to remove a bad section of the tree, you would already have the dataset you need to restart the training: just use the segment of training data which ended up on this path. You would then have broken a large problem down into a smaller one. You can do things like add a cleaning filter specific to this subset of the data, or adjust the regularization or other hyperparameters, and then retrain just that branch. Once you have a new branch, you can compare its performance to the previous one to make sure the problem has been fixed.

In order to grow without limit to enable scaling to large complex datasets and potentially superhuman capability at fuzzy tasks, I believe it would be necessary to have a way for the nodes to split off new branches as needed. As each node becomes 'information saturated' it would automatically grow a new branch to handle the extra info. In a very large model trained on a large and complex dataset, I imagine there might be 50+ layers in some sections and thousands or millions of terminal leaves.

Information saturation of a node could be measured by a regression in performance that is consistent over several batches. This node could directly be a terminal leaf or it could be further up. If it is further up, its performance drop will need to be inferred from the pattern of performance drop across its terminal leaves. When this occurs, the branch can be rolled back to its best known performance state, a new branch started, and training resumed. Defining the split for this new branch can be done by searching for a factor which differentiates the recent data from the older data that was trained on before the best epoch was recorded. This process could be thought of as similar to the role that surprise plays in the brain.
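A minimal sketch of the saturation check, assuming we track one loss value per batch for a node. The class name, the `patience` and `tolerance` parameters, and the loss values are hypothetical; how to infer an inner node's loss from its leaves is left out.

```python
class SaturationMonitor:
    """Flags a node whose loss has been worse than its best for `patience`
    batches in a row. Losses within `tolerance` of the best leave the
    streak unchanged."""

    def __init__(self, patience=3, tolerance=0.0):
        self.patience = patience
        self.tolerance = tolerance
        self.best_loss = float("inf")
        self.best_step = None   # checkpoint to roll back to
        self.bad_streak = 0

    def update(self, step, loss):
        """Record one batch; return True when the node looks saturated."""
        if loss < self.best_loss - self.tolerance:
            self.best_loss = loss
            self.best_step = step
            self.bad_streak = 0
        elif loss > self.best_loss + self.tolerance:
            self.bad_streak += 1
        return self.bad_streak >= self.patience

# Made-up per-batch losses for one node: improving, then regressing.
losses = [1.0, 0.8, 0.7, 0.75, 0.9, 0.95, 1.0]

monitor = SaturationMonitor(patience=3)
saturated_at = None
for step, loss in enumerate(losses):
    if monitor.update(step, loss):
        saturated_at = step
        break

print(saturated_at, monitor.best_step)  # 5 2
```

When the monitor fires, the branch would be restored from the checkpoint at `best_step`, and the batches seen after that step would be the "recent data" used to define the new split.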

What about the size? Won't it be incredibly inefficient to take what was a stack of x - x - x - x … x layers and split those out into a tree? Each level of the tree would have at least twice as many parameters as the level above! This would quickly become computationally intractable to train. One possible solution is that, since you are splitting the data into organized subgroups, each layer is dealing with a somewhat simpler set of data. Thus, it should be able to use fewer parameters to learn it to the same degree of efficacy as the layer above. In other words, the set of information is more compressible. There are a variety of techniques for compressing information, such as sparsity or autoencoders. The other benefit of forcing each layer to compress is that it would encourage abstraction. By the time you got to the terminal leaf, you could have very refined and hopefully interpretable concepts derived from much more limited subsets of the data.
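To put rough numbers on this, here is a small calculation under simplifying assumptions of my own: a binary tree in which every node is a single dense layer without biases, and a per-level `shrink` factor applied to the layer width. None of these choices come from the ANT or NBDT papers.

```python
def level_params(width, depth, branching=2, shrink=1.0):
    """Parameter count per tree level.

    Level k has branching**k nodes; each node is a dense layer mapping its
    parent's output width to its own width, which is width * shrink**k.
    """
    counts = []
    in_width = width
    for level in range(depth):
        out_width = max(1, round(width * shrink ** level))
        counts.append(branching ** level * in_width * out_width)
        in_width = out_width
    return counts

width, depth = 1024, 6
stack_total = depth * width * width                    # plain x - x - ... - x stack
tree_fixed = sum(level_params(width, depth))           # tree, no shrinking
tree_halved = sum(level_params(width, depth, shrink=0.5))

print(f"plain stack:         {stack_total:,}")
print(f"tree, fixed width:   {tree_fixed:,}")
print(f"tree, halving width: {tree_halved:,}")
```

With a fixed width the tree needs about ten times the parameters of the plain stack at this depth, and the gap doubles with every extra level. With widths halving at each level the tree ends up smaller than the stack. More generally, from the second level down each level has 2 * shrink**2 times the parameters of the one above, so under these assumptions a shrink factor of about 0.71 keeps the per-level cost flat.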

Alright, but what if you want to train just one branch or even just one leaf node? As it is, you'd have to also train all the parent nodes up to the beginning. 
I believe there are options worth exploring for fine tuning subsets of the network. For instance, I think it would be worth exploring Greedy Infomax as a loss function for this purpose.

Their system allows for training nodes based on how well they supply their immediate child nodes with data according to a measure of information maximization. This is why their paper is subtitled "putting an end to end-to-end" since it removes the requirement to back-propagate errors through the entire network.  This would allow more selective training of subsets of the directed acyclic graph.


Part 2: Constraining the outputs to force abstraction


Generally, the idea is to constrain the outputs of models to reduce the degrees of freedom of the agent and make the output more easily auditable by humans, both at the output end of the modular network and between the components within the network.
My hope is that this will be possible to do with little loss in performance because the natural abstraction hypothesis will turn out to be more or less true [ ]. The reason I think this hypothesis being true would help is that it would mean that the world could effectively be understood and reasoned about in an effective way with relatively simple abstractions as building blocks of communication and thought. Like the syntax of a coding language or the operations of formal logic, these fundamental concepts and patterns of logical thought could be a relatively simple and human-comprehensible set but yet be sufficient for effectively describing and manipulating the universe.

The two papers which have most inspired this idea are:

Compressing the output into limited bandwidth
Leveraging Sparse Linear Layers for Debuggable Deep Networks
Wong et al.
alignment newsletter summary:


Constraining the output to be abstracted code
Ellis et al.

These two papers seem quite different, but I think they get at a common idea of enforcing abstraction by bottlenecking the output of a model. The process of training a machine learning algorithm ultimately boils down to observing statistical patterns in data and then generalizing these patterns into rules. This direction of research aims to get the ML algorithm to express these learned rules as sufficiently clear and concise abstractions that we can understand. I believe this constraint-to-abstraction is promising because I suspect the 'natural abstraction' hypothesis ( ) is correct and thus forcing abstractions, if done correctly, can result in models that are highly performant as well as highly interpretable.


Part 2a

Paper 1 here has particular conceptual overlap with the NBDT concept. You are constraining the final layer of the network in order to force abstraction, which improves interpretability. What I think is interesting here is that this could be a way of making the information exchanged between nodes in the modular network more legible.
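To show the general flavor of a sparse final layer, here is a minimal sketch that fits a linear readout on fixed features with an L1 penalty, using proximal gradient descent (ISTA). This is a generic lasso fit that I am using as a stand-in; the paper's own regularizer and solver may differ, and the toy features, function names, and hyperparameters are all my assumptions.

```python
from itertools import product

def soft_threshold(v, t):
    """Proximal operator of t * |v|: shrink v toward zero by t."""
    if v > t:
        return v - t
    if v < -t:
        return v + t
    return 0.0

def fit_sparse_linear(features, targets, l1=0.1, lr=0.1, steps=500):
    """Minimize mean squared error / 2 + l1 * sum(|w|) with ISTA."""
    n, d = len(features), len(features[0])
    w = [0.0] * d
    for _ in range(steps):
        preds = [sum(wj * xj for wj, xj in zip(w, x)) for x in features]
        grad = [
            sum((p - y) * x[j] for p, y, x in zip(preds, targets, features)) / n
            for j in range(d)
        ]
        # Gradient step on the squared error, then shrinkage for the L1 term.
        w = [soft_threshold(wj - lr * gj, lr * l1) for wj, gj in zip(w, grad)]
    return w

# Toy "deep features": all sign patterns of three features.
# Only feature 0 actually matters for the target.
features = [list(map(float, combo)) for combo in product([-1, 1], repeat=3)]
targets = [2.0 * x[0] for x in features]

weights = fit_sparse_linear(features, targets)
print(weights)
```

The irrelevant features end up with weights of exactly zero rather than merely small values, which is the property that makes the layer easier for a human to audit: only a handful of inputs to each output need to be inspected.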


Part 2b

Paper 2 is something interesting and different. Imagine you had a powerful complex abstruse optimizer that wrote code to solve problems. Say you had this optimizer create you an agent or set of agents written in code you could understand. You could enforce the constraint that you must be able to fully understand these simpler agents in order to trust them enough to use them in the real world. Once you are happy with the performance of the agents in the test environment, you can copy them, edit them as you see fit, and use them for real world tasks.
The complex 'black box' AI itself never gets to directly interact with the world. You may not fully understand all the computation going on in the code creator, but maybe that's ok if you fully understand the programs it writes. The computer science field already has a lot of well-defined techniques for analyzing untrusted code and validating its safety. There are certainly limits to the efficacy we might maximally expect from code which we could guarantee to be safe, but I think there's a lot of value we could derive from code below our max-complexity-for-validation threshold. Importantly, it should be doable to determine whether a given set of code can be validated or not. Knowing what we can trust and what we can't trust would be a big step forward.

Meanwhile, you could potentially apply the other techniques mentioned to the DreamCoder to make its components more transparent.

An experiment I think would be interesting to try with the DreamCoder would be to take the version trained to reproduce turtle art, and put it in a human-feedback framework where the human trainer started by defining an arbitrary adjective such as 'uplifting' and then selected from pairs of generated images for whichever better fit the adjective. This process seems like it would be easy and entertaining enough that you could get volunteers to do it on a website. This seems interesting to me because it seems like a very simple example of having a model produce a constrained output (human-readable code) which seeks to satisfy a fuzzy human objective (uplifting art vs gloomy art, etc).

I think it would also be worth thinking about if the ideas in DreamCoder could be adapted to work as a more general sort of abstraction for the communication between modules in the network proposed in part 1.

Proposed Research Agenda

The ANTs paper and the NBDT paper focus on vision models. I'm interested in applying this to NLP models. So far, I've checked out and familiarized myself with the code published for these papers. I updated the ANT code to work with Python 3.x and added the ability for it to train on 1d vectors. I then tested its performance on some simple tabular datasets, and it worked fine.

Step 1: Using the small pre-trained GPT-2 model distributed by HuggingFace ( ), copy the layers into a random-split tree.
Step 2: Gather some benchmarks for the unmodified small GPT-2 and the random-split tree. These should be nearly identical at this point.

Step 3: Fine tune the random-split tree so that it has better splits. Also fine-tune the original GPT-2 on the same data. In this, I expect to spend a lot more compute on fine-tuning the tree version. That's ok, I'm willing to put off working on the training efficiency aspect to a different task.

Step 4: retest the fine-tuned models on the benchmarks. They should both do at least as well as the non-fine-tuned versions.

Step 5: Try using TF-IDF on the subsets of data which pass each node in the tree. Generate word clouds on the over-represented words. Create a visualization of this.

Step 6: If all has gone well so far, work on compressing the information between nodes.

Step 7: Work on using smaller layers as the tree is descended, with the goals of improving legibility, forcing reliance on abstraction, and reducing compute needs for training and inference.
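Step 5 could be prototyped roughly like this, before any visualization work. This is a minimal sketch that treats all the text routed through a node as one "document" and ranks words by TF-IDF across nodes. The function name, the whitespace tokenization, and the example texts are my own assumptions.

```python
import math
from collections import Counter

def overrepresented_words(node_docs, top_k=3):
    """For each node, return the top_k words by TF-IDF.

    node_docs maps node name -> list of text samples that passed through it.
    Each node's samples are pooled into a single document.
    """
    term_counts = {
        node: Counter(" ".join(docs).lower().split())
        for node, docs in node_docs.items()
    }
    n_nodes = len(term_counts)
    doc_freq = Counter()
    for counts in term_counts.values():
        doc_freq.update(counts.keys())

    result = {}
    for node, counts in term_counts.items():
        total = sum(counts.values())
        scores = {
            word: (count / total) * math.log(n_nodes / doc_freq[word])
            for word, count in counts.items()
        }
        # Highest score first; ties broken alphabetically for stable output.
        ranked = sorted(scores, key=lambda w: (-scores[w], w))
        result[node] = ranked[:top_k]
    return result

# Made-up samples for two nodes of the tree.
node_docs = {
    "node_a": ["the patient took the medicine", "the doctor saw the patient"],
    "node_b": ["the team won the match", "the match went to the team"],
}
top_words = overrepresented_words(node_docs, top_k=2)
print(top_words)  # {'node_a': ['patient', 'doctor'], 'node_b': ['match', 'team']}
```

Words that appear at every node (like "the") get a score of zero, so the lists that feed the word clouds contain only words that distinguish one node's data from another's.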


2 comments

Sounds like a cool project, I'm looking forward to seeing the results!

Now that I got a grant from the Long Term Future Fund and quit my job to do interpretability research full time, I'm actually making progress on some of my ideas!