[Figure: A visualization of a sparse computational graph pruned from an RNN. Square nodes represent neurons and circles are states from the previous timestep. Nodes and edges are colored according to their current output, with blue being negative and red positive.]

This is a linkpost for a writeup on my personal website: https://cprimozic.net/blog/growing-sparse-computational-graphs-with-rnns/

Here's a summary:

This post contains an overview of my research and experiments on growing sparse computational graphs I'm calling "Bonsai Networks" by training small RNNs. It describes the architecture, training process, and pruning methods used to create the graphs and then examines some of the learned solutions to a variety of objectives.

Its main theme is mechanistic interpretability, but it also goes into significant detail on the technical side of the implementation for the training stack, a custom activation function, bespoke sparsity-promoting regularizer, and more.

The site contains a variety of interactive visualizations and other embeds that are important to its content.  That's why I chose to make this a linkpost rather than copy its content here directly.

I'd love to receive any feedback you might have on this work. This topic is something I'm very interested in, and I'm eager to hear people's thoughts on it.

gwern (9mo)

Given all your problems with making differentiability/backprop work, interest in exotic activation functions and sparsity-encouraging regularization, optimization difficulties, 'grokking' fragility, etc. (where momentum is potentially helping tunnel out of local optima), have you considered dropping backprop entirely and using blackbox methods like evolutionary computing? For such tiny RNNs with infinite datasets, the advantages of backprop aren't that big (and you don't seem to be really exploiting any of the advantages differentiability gives you beyond mere training), and it comes with many drawbacks you have been discovering the hard way, while evolutionary search would give you tremendous freedom and probably find better and sparser topologies. (It would also probably do a better job avoiding deceptive gradients/local-optima, and you could more easily generate many different solutions to the problem to compare for interpretable sub-structure.)
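(To make that concrete, here is a minimal sketch of the kind of blackbox setup being suggested: a simple (mu, lambda) evolution strategy over the flattened weights of a tiny RNN. The network sizes, the tanh activation, the toy objective, and the sparsity term are all placeholder assumptions for illustration, not anything from the post.)

```python
import numpy as np

rng = np.random.default_rng(0)

# Placeholder sizes for a tiny RNN: input -> hidden (recurrent) -> output.
N_IN, N_HID, N_OUT = 2, 8, 1
N_PARAMS = N_HID * (N_IN + N_HID + 1) + N_OUT * (N_HID + 1)

def unpack(theta):
    """Split a flat parameter vector into the RNN's weight matrices."""
    i = 0
    def take(shape):
        nonlocal i
        size = int(np.prod(shape))
        out = theta[i:i + size].reshape(shape)
        i += size
        return out
    return take((N_HID, N_IN)), take((N_HID, N_HID)), take((N_HID,)), take((N_OUT, N_HID)), take((N_OUT,))

def run_rnn(theta, xs):
    """Run the RNN over a sequence of inputs; tanh is only a stand-in activation."""
    W_in, W_rec, b_h, W_out, b_o = unpack(theta)
    h = np.zeros(N_HID)
    outs = []
    for x in xs:
        h = np.tanh(W_in @ x + W_rec @ h + b_h)
        outs.append(W_out @ h + b_o)
    return np.array(outs)

def fitness(theta):
    """Placeholder objective: mean-squared error on a freshly-sampled toy sequence
    task, plus an L1 term to encourage sparsity (negated so higher is better)."""
    xs = rng.uniform(-1, 1, size=(16, N_IN))
    targets = xs[:, :1]           # toy target: echo the first input channel
    preds = run_rnn(theta, xs)
    mse = np.mean((preds - targets) ** 2)
    return -(mse + 1e-3 * np.abs(theta).sum())

# Simple (mu, lambda) evolution strategy over the flat parameter vector.
POP, ELITE, SIGMA = 64, 8, 0.1
mean = np.zeros(N_PARAMS)
for _ in range(200):
    pop = mean + SIGMA * rng.standard_normal((POP, N_PARAMS))
    scores = np.array([fitness(p) for p in pop])
    elite = pop[np.argsort(scores)[-ELITE:]]
    mean = elite.mean(axis=0)     # move the search distribution toward the elites
```

Nothing in a loop like this needs to be differentiable, so the fitness function can fold in hard sparsity penalties, discrete weight constraints, or the custom activation unchanged.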

Another thing to try is to avoid regularization early on, and instead train large networks and then prune them down. This usually works better than trying to train small nets directly - easier to optimize, more lottery tickets, etc.
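(A rough sketch of the train-large-then-prune idea, assuming a PyTorch-style model; the keep fraction and the one-shot magnitude criterion are placeholder choices rather than anything prescribed here.)

```python
import torch
import torch.nn as nn

def magnitude_prune(model: nn.Module, keep_fraction: float = 0.1) -> dict:
    """Zero out all but the largest-magnitude weights, returning the masks so the
    sparsity pattern can be enforced during later fine-tuning."""
    masks = {}
    with torch.no_grad():
        for name, param in model.named_parameters():
            if param.dim() < 2:          # skip biases
                continue
            k = max(1, int(keep_fraction * param.numel()))
            # k-th largest magnitude = (numel - k + 1)-th smallest
            threshold = param.abs().flatten().kthvalue(param.numel() - k + 1).values
            mask = (param.abs() >= threshold).float()
            param.mul_(mask)
            masks[name] = mask
    return masks

# Usage sketch: train a deliberately-oversized RNN first, prune, then fine-tune
# with the masks re-applied after every optimizer step so pruned weights stay zero.
```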

I'm a little concerned about the need for large weights. That requires a lot of information. It might be good to switch to thinking more in bits and penalize the total bits of the network in search of simplicity.
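(One way to read that suggestion is a description-length-style regularizer: a rough, differentiable surrogate for the bits needed to encode each weight at some fixed quantization step. The step size and the log2(1 + |w|/step) surrogate below are assumptions for illustration only.)

```python
import torch

def bit_penalty(weights: torch.Tensor, step: float = 0.25) -> torch.Tensor:
    """Rough, differentiable surrogate for the bits needed to encode each weight
    when quantized to multiples of `step`: ~log2(1 + |w| / step) bits per weight."""
    return torch.log2(1.0 + weights.abs() / step).sum()

# Added to the task loss in place of (or alongside) an L1 term, this pushes weights
# toward zero and toward small magnitudes that are cheap to encode.
```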

(Also, the top of one figure is slightly cut off in my Firefox: the img needs a bit more padding, or something.)

ameo (9mo)

Thanks for your detailed notes!

have you considered dropping backprop entirely and using blackbox methods like evolutionary computing

This is a really neat idea that I'd love to explore more. I've tried some brief experiments in that area in the past, using Z3 to find valid parameter combinations for different logic gates built from that custom activation function. I didn't have any luck, though; the optimizer ran for hours without finding any solutions, and I fell back to a brute-force search instead.

A big part of the issue for me is that I'm just very unfamiliar with the whole domain. There's probably a lot I did wrong in that experiment that caused it to fail, but I find that there are far fewer resources for those kinds of tools and methods than for backprop and neural network-focused techniques.

I know you have done work in a huge variety of topics. Do you know of any particular black-box optimizers or techniques that might be a good starting point for further exploration of this space? I know of one called Nevergrad, but I think it's more designed to work for stuff like hyperparam optimization with ~dozens of variables or less rather than hundreds/thousands of network weights with complex, multi-objective optimization problems. I could be wrong though!

Another thing to try is to avoid regularization early on

This is an interesting idea. I actually do the opposite: cut the regularization intensity over time. I end up re-running training many times until I get a good initialization instead.

I'm a little concerned about the need for large weights.

"Large" is kinda relative in this situation. Given the kinds of weights I see after training, I consider anything >1 to be "large".

For representing logic gates, almost all of them can be represented with integer combinations of weights and biases just due to the way the activation function works. If the models stuck to using just integer values, it would be possible to store the params very efficiently and turn the whole thing from continuous to discrete (which to be honest was the ultimate goal of this work).

However, I wasn't able to successfully get the models to do that outside of very simple examples. A lot of neurons develop integer weights by themselves, though.
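(For concreteness, a brute-force search over small integer weights for a single 2-input gate might look something like the sketch below. The tanh is only a stand-in for the post's custom activation, and the 0.5 output threshold is an assumed convention.)

```python
import itertools
import numpy as np

def activation(x):
    # Placeholder for the post's custom activation function.
    return np.tanh(x)

def implements_gate(w1, w2, b, truth_table, inputs=(0.0, 1.0)):
    """Check whether a single neuron with weights (w1, w2) and bias b reproduces a
    2-input gate, treating activation outputs above 0.5 as logical 1."""
    for (x1, x2), target in zip(itertools.product(inputs, repeat=2), truth_table):
        out = activation(w1 * x1 + w2 * x2 + b)
        if (out > 0.5) != bool(target):
            return False
    return True

# Brute-force all small integer parameter combinations for, e.g., OR.
OR_TABLE = (0, 1, 1, 1)   # ordered to match itertools.product((0, 1), repeat=2)
solutions = [
    (w1, w2, b)
    for w1, w2, b in itertools.product(range(-3, 4), repeat=3)
    if implements_gate(w1, w2, b, OR_TABLE)
]
print(solutions[:5])
```

With the real activation substituted in, the same loop enumerates every integer solution in the search range, which is roughly the discrete representation described above.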

the top of one figure is slightly cut off

Ty for letting me know! I'll look into that.


Ty again for taking the time to read the post and for the detailed feedback - I truly appreciate it!

I know of one called Nevergrad, but I think it's more designed to work for stuff like hyperparam optimization with ~dozens of variables or less rather than hundreds/thousands of network weights with complex, multi-objective optimization problems.

Nevergrad is more of a library than a single specific algorithm:

Nevergrad provides implementations of Covariance-Matrix-Adaptation, Particle Swarm Optimization, evolutionary algorithms, Differential Evolution, Bayesian optimization, HyperOpt, Powell, Cobyla, LHS, quasi-random point constructions, NSGA-II, ...

Offhand, several of these should work on NNs (Ha is particularly fond of CMA-ES), and several of these are multi-objective algorithms (they highlight PDE and DEMO, but you can also just define a fitness function which does a weighted sum of your objectives into a single index, or bypass the issue entirely by using novelty search-style approaches to build up a big library of diverse agents and only then start doing any curation/selection/optimization).
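(As a concrete starting point, a minimal Nevergrad/CMA-ES run over a flat weight vector might look like the sketch below; the parameter count, the stand-in loss, and the weighted-sum combination of objectives are placeholder assumptions.)

```python
import nevergrad as ng
import numpy as np

N_PARAMS = 200   # placeholder: total number of RNN weights

def task_loss(theta: np.ndarray) -> float:
    """Placeholder for the actual training objective (e.g. sequence error)."""
    return float(np.sum(theta ** 2))   # stand-in only

def combined_loss(theta: np.ndarray) -> float:
    # A weighted sum collapses the multi-objective problem into a single index.
    sparsity = float(np.abs(theta).sum())
    return task_loss(theta) + 0.01 * sparsity

optimizer = ng.optimizers.CMA(
    parametrization=ng.p.Array(shape=(N_PARAMS,)),
    budget=5_000,          # number of loss evaluations
)
recommendation = optimizer.minimize(combined_loss)
best_theta = recommendation.value
```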

As Pearce says, you should go skim all of David Ha's papers as he's fiddled around a lot with very small networks with unusual activations or topologies etc. For example, weight-agnostic NNs where the irregular connectivity learns the algorithm so randomized weights still compute the right answer.

He has a new library up, EvoJax, for highly-optimized evolutionary algorithms on TPUs, which might be useful. (If you need TPU resources, the TPU Research Cloud is apparently still around and claiming to have plenty of TPUs.)

For your purposes, I think NEAT/HyperNEAT would be worth looking at: such approaches would let you evolve the topology and also use a variety of activation functions and whatever other changes you wanted to experiment with. I agree with Pearce that I'm a bit dubious about hand-engineering such a fancy activation. (It may work but does it really give you more interpretability or other important properties?)

You can also combine evolution with gradients*, and population+novelty search would be of interest. Population search can help with hyperparameter search as well, and would go well with some big runs on TPUs using EvoJax.

Pruning NNs is an old idea dating back to the 1980s, so there's a deep literature on it with a lot of ideas to try.

This is an interesting idea. I actually do the opposite: cut the regularization intensity over time. I end up re-running training many times until I get a good initialization instead.

That sounds like you're kinda cheating/bruteforcing a bad strategy... I'd be surprised if having very high regularization at the beginning and then removing it turned out to be optimal. In general, the idea that you should explore and only then exploit is a commonplace in reinforcement learning and in evolutionary computing in particular - that's one of Stanley & Lehman's big ideas.
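(The explore-then-exploit version of that schedule could be as simple as ramping the regularization weight up from zero over the first part of training; the maximum strength and warmup fraction below are placeholders to tune.)

```python
def sparsity_weight(step: int, total_steps: int,
                    max_weight: float = 1e-2, warmup_frac: float = 0.5) -> float:
    """Regularization strength that starts at zero and ramps up linearly, so early
    training can explore freely before sparsity pressure kicks in."""
    warmup_steps = int(warmup_frac * total_steps)
    if step >= warmup_steps:
        return max_weight
    return max_weight * step / warmup_steps

# In the training loop:
#   loss = task_loss + sparsity_weight(step, total_steps) * l1_term
```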

"Large" is kinda relative in this situation. Given the kinds of weights I see after training, I consider anything >1 to be "large".

Why not use binary weights, then? (You can even backprop them in several ways.)
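(One common way to backprop through binary weights is the straight-through estimator: binarize on the forward pass and let gradients flow to the underlying real-valued weights as if the binarization were the identity. A minimal PyTorch-style sketch, with the layer shape and init as assumptions:)

```python
import torch
import torch.nn as nn

class BinaryLinear(nn.Module):
    """Linear layer whose weights are binarized to {-1, +1} on the forward pass,
    with gradients reaching the underlying real-valued weights via the
    straight-through estimator."""
    def __init__(self, in_features: int, out_features: int):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(out_features, in_features) * 0.1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        w_bin = torch.sign(self.weight)
        # Straight-through: forward uses w_bin, backward treats it as identity.
        w_ste = self.weight + (w_bin - self.weight).detach()
        return x @ w_ste.t()
```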

* I don't suggest actually trying to do this - either go evolution or go backprop - I'm just linking it because I think surrogate gradients are neat.

ameo (9mo)

Wow, I appreciate this list! I've heard of a few of the things you mention, like the weight-agnostic NNs, but most of it is entirely new to me.

Tyvm for taking the time to put it together.

  • The optimization section of Learning Transformer Programs might work with your task/model
  • You've probably seen David Ha's work, but something like https://es-clip.github.io/ could be a good starting point for dropping backprop.
  • The exotic activation function almost feels like cheating? Like, I want the model to discover these useful structures, then try to understand them. But trying to do everything at once may be too hard.
  • Incredibly minor, but changing from onchange to oninput and dropping the animation will make the slider feel much slicker.