Crossposted from the AI Alignment Forum. May contain more technical jargon than usual.


Last week's paper roundup (more or less by accident) focused mostly on path dependence of deep learning and the order of feature learning. Going forward, I've decided to give each week's roundup an explicit focus. This week's focus is the structure and redundancy of trained models, as well as linear interpolations through parameter space. 

I've also decided to publish each roundup on Monday morning (Edit: or try to, at any rate).


Residual Networks Behave Like Ensembles of Relatively Shallow Networks

In this work we propose a novel interpretation of residual networks showing that they can be seen as a collection of many paths of differing length. Moreover, residual networks seem to enable very deep networks by leveraging only the short paths during training. To support this observation, we rewrite residual networks as an explicit collection of paths. Unlike traditional models, paths through residual networks vary in length. Further, a lesion study reveals that these paths show ensemble-like behavior in the sense that they do not strongly depend on each other. Finally, and most surprising, most paths are shorter than one might expect, and only the short paths are needed during training, as longer paths do not contribute any gradient. For example, most of the gradient in a residual network with 110 layers comes from paths that are only 10-34 layers deep. Our results reveal one of the key characteristics that seem to enable the training of very deep networks: Residual networks avoid the vanishing gradient problem by introducing short paths which can carry gradient throughout the extent of very deep networks. 

My opinion:

This paper suggests that neural nets are redundant by default. That gives some intuition for why it's often possible to prune large fractions of a network's parameters without much impact on test performance. It also suggests the mechanism by which residual connections allow deeper networks to be trained: residual connections let shallow subnetworks communicate directly with the input and output spaces, so a deep net can be built as an ensemble of shallow nets. 

I think it also points away from neural nets implementing a Kolmogorov or circuit simplicity prior.
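
The unraveled view from the abstract can be checked on a toy model. The sketch below assumes each residual block is a scalar linear map f_i(h) = a_i·h. Real blocks are nonlinear, so for them the expansion is a way of describing paths rather than an exact sum of independent terms. With linear blocks the expansion is exact: a network with n blocks equals a sum over 2^n paths, one per subset of blocks, and C(n, k) of those paths pass through exactly k blocks. The usage lines also assume the 110-layer network has 54 residual blocks, which is the standard CIFAR ResNet-110 layout. Note that the abstract's 10-34 range describes where the *gradient* comes from, not how many paths there are; this sketch only counts paths.

```python
from itertools import combinations
from math import comb, prod


def residual_forward(x, gains):
    """Toy residual net: block i is the scalar linear map f_i(h) = gains[i] * h."""
    h = x
    for a in gains:
        h = h + a * h  # identity skip connection plus residual branch
    return h


def unraveled_forward(x, gains):
    """The same network written as a sum over all 2**n paths.

    Each path is a subset of blocks: it multiplies x by the gains of the blocks
    it passes through and skips every other block.
    """
    n = len(gains)
    total = 0.0
    for k in range(n + 1):  # k = path length (number of blocks traversed)
        for subset in combinations(range(n), k):
            total += prod(gains[i] for i in subset) * x
    return total


def path_length_counts(n_blocks):
    """Number of paths through exactly k blocks, for k = 0..n_blocks."""
    return [comb(n_blocks, k) for k in range(n_blocks + 1)]


if __name__ == "__main__":
    gains = [0.5, -0.2, 0.1]  # hypothetical block gains
    print(residual_forward(2.0, gains), unraveled_forward(2.0, gains))

    counts = path_length_counts(54)  # assumed: 54 residual blocks in ResNet-110
    share = sum(counts[19:36]) / 2 ** 54
    print(f"fraction of paths through 19-35 blocks: {share:.3f}")
```

Path lengths follow a Binomial(n, 1/2) distribution, so most paths cluster around n/2 blocks. Gradients shrink as they pass through more blocks, which shifts the gradient-weighted picture toward shorter paths.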

On the Effect of Dropping Layers of Pre-trained Transformer Models

Transformer-based NLP models are trained using hundreds of millions or even billions of parameters, limiting their applicability in computationally constrained environments. While the number of parameters generally correlates with performance, it is not clear whether the entire network is required for a downstream task. Motivated by the recent work on pruning and distilling pre-trained models, we explore strategies to drop layers in pre-trained models, and observe the effect of pruning on downstream GLUE tasks. We were able to prune BERT, RoBERTa and XLNet models up to 40%, while maintaining up to 98% of their original performance. Additionally we show that our pruned models are on par with those built using knowledge distillation, both in terms of size and performance. Our experiments yield interesting observations such as, (i) the lower layers are most critical to maintain downstream task performance, (ii) some tasks such as paraphrase detection and sentence similarity are more robust to the dropping of layers, and (iii) models trained using a different objective function exhibit different learning patterns and w.r.t the layer dropping. 

My opinion:

(see below)

Of Non-Linearity and Commutativity in BERT

In this work we provide new insights into the transformer architecture, and in particular, its best-known variant, BERT. First, we propose a method to measure the degree of non-linearity of different elements of transformers. Next, we focus our investigation on the feed-forward networks (FFN) inside transformers, which contain 2/3 of the model parameters and have so far not received much attention. We find that FFNs are an inefficient yet important architectural element and that they cannot simply be replaced by attention blocks without a degradation in performance. Moreover, we study the interactions between layers in BERT and show that, while the layers exhibit some hierarchical structure, they extract features in a fuzzy manner. Our results suggest that BERT has an inductive bias towards layer commutativity, which we find is mainly due to the skip connections. This provides a justification for the strong performance of recurrent and weight-shared transformer models. 

My opinion:

These two papers tell us similar things about the function of transformer layers. The layers are fairly redundant, especially the later ones, and adjacent layers are more similar to each other than to more distant layers. This suggests transformers are also fairly ensemble-like. 

That said, the ROME paper implies an asymmetry between what the early layers do (store and look up key-value memories) and what the later layers do (query the value outputs of early layers through their self-attention connections). Perhaps that's why pruning the later layers causes a smaller drop in performance?

Code here.
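
The commutativity abstract attributes BERT's inductive bias toward layer commutativity mainly to skip connections. A minimal way to see why uses linear layers with small weights, which is a simplification of BERT's nonlinear blocks. For residual layers (I + A) and (I + B), swapping their order changes the output by (AB − BA)x. That change is second order in the weights. For plain layers A and B, the change is as large as the output itself. The matrices below are arbitrary illustrative values, not anything measured from BERT.

```python
def matmul(p, q):
    """Product of two 2x2 matrices stored as nested lists."""
    return [[sum(p[i][k] * q[k][j] for k in range(2)) for j in range(2)] for i in range(2)]


def matvec(m, v):
    """Apply a 2x2 matrix to a 2-vector."""
    return [m[0][0] * v[0] + m[0][1] * v[1], m[1][0] * v[0] + m[1][1] * v[1]]


def with_skip(w):
    """Residual layer h -> h + W h, i.e. the matrix I + W."""
    return [[w[0][0] + 1.0, w[0][1]], [w[1][0], w[1][1] + 1.0]]


def norm(v):
    return (v[0] ** 2 + v[1] ** 2) ** 0.5


def order_sensitivity(a, b, x, skip):
    """Relative change in the output when layers a and b are applied in swapped order."""
    if skip:
        a, b = with_skip(a), with_skip(b)
    b_then_a = matvec(matmul(a, b), x)  # apply b first, then a
    a_then_b = matvec(matmul(b, a), x)  # apply a first, then b
    diff = [b_then_a[0] - a_then_b[0], b_then_a[1] - a_then_b[1]]
    return norm(diff) / norm(b_then_a)


if __name__ == "__main__":
    eps = 0.05  # small weights, loosely mimicking small residual updates
    a = [[eps, 2 * eps], [0.0, eps]]
    b = [[eps, 0.0], [3 * eps, eps]]
    x = [1.0, 1.0]
    print("plain layers:   ", order_sensitivity(a, b, x, skip=False))
    print("residual layers:", order_sensitivity(a, b, x, skip=True))
```

Shrinking `eps` makes the residual case more commutative still. The plain case stays order-sensitive at any scale.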

The Role of Permutation Invariance in Linear Mode Connectivity of Neural Networks

In this paper, we conjecture that if the permutation invariance of neural networks is taken into account, SGD solutions will likely have no barrier in the linear interpolation between them. Although it is a bold conjecture, we show how extensive empirical attempts fall short of refuting it. We further provide a preliminary theoretical result to support our conjecture. Our conjecture has implications for lottery ticket hypothesis, distributed training, and ensemble methods. 

My opinion:

Past work demonstrated that image models with different random initializations converged to minima that have simple, but not linear, paths of constant loss between them. This paper argues that these paths become linear if you account for possible permutations of the model's weights. I think this is plausible for the sorts of wide image models trained on very limited data that people typically do mode connectivity experiments on, but I doubt it will hold for large language models or other systems that could plausibly scale to AGI. 
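
The permutation argument can be made concrete with a toy sketch; all weights and data here are hypothetical. Permuting the hidden units of a one-hidden-layer ReLU MLP (together with the matching input and output weights) leaves its function unchanged. Even so, naively interpolating between a network and its permuted copy passes through networks with much higher loss. The barrier is measured as the largest gap between the loss along the path and the straight line joining the endpoint losses, which is one common definition. Once the permutation is undone, the endpoints coincide and the barrier vanishes.

```python
def relu(z):
    return max(0.0, z)


def mlp(params, x):
    """One-hidden-layer ReLU MLP, scalar in and out: params = (w1, b1, w2) per hidden unit."""
    w1, b1, w2 = params
    return sum(w2[j] * relu(w1[j] * x + b1[j]) for j in range(len(w1)))


def permute(params, perm):
    """Reorder hidden units; the network computes the same function."""
    w1, b1, w2 = params
    return ([w1[p] for p in perm], [b1[p] for p in perm], [w2[p] for p in perm])


def interpolate(pa, pb, alpha):
    """Parameters (1 - alpha) * pa + alpha * pb, tensor by tensor."""
    return tuple([(1 - alpha) * a + alpha * b for a, b in zip(ta, tb)] for ta, tb in zip(pa, pb))


def mse(params, xs, ys):
    return sum((mlp(params, x) - y) ** 2 for x, y in zip(xs, ys)) / len(xs)


def loss_barrier(pa, pb, xs, ys, steps=20):
    """Max over the path of loss(interpolant) minus the linear blend of endpoint losses."""
    la, lb = mse(pa, xs, ys), mse(pb, xs, ys)
    gaps = []
    for i in range(steps + 1):
        alpha = i / steps
        gaps.append(mse(interpolate(pa, pb, alpha), xs, ys) - ((1 - alpha) * la + alpha * lb))
    return max(gaps)


# Usage with hypothetical weights: net_b is net_a with its hidden units reordered.
net_a = ([1.0, -1.0, 2.0], [0.0, 0.5, -1.0], [1.0, 2.0, -0.5])
net_b = permute(net_a, [2, 0, 1])
xs = [-2.0, -1.0, 0.0, 1.0, 2.0]
ys = [mlp(net_a, x) for x in xs]  # targets that net_a (and net_b) fit exactly
print(loss_barrier(net_a, net_b, xs, ys))                      # naive interpolation
print(loss_barrier(net_a, permute(net_b, [1, 2, 0]), xs, ys))  # permutation undone
```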

The best parts of this paper are the empirical experiments. They find that the loss barrier on linear interpolations between trained models decreases with width and increases with both depth and the difficulty of the training data. 

These results make sense to me. Linear mode connectivity requires that every point on the straight line between two networks' parameters (every convex combination of the two) also be an effective network. The wider the network, the more degrees of freedom its hidden representations have, and so the more room there is to avoid destructive interference between the hidden representations of the two networks. The harder the dataset, the more of that representational capacity each network's hidden representations will take up.
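
One way to put a number on the "more room in wider networks" intuition is a loose analogy, not a model of real hidden representations. Independent random directions become closer to orthogonal as dimension grows, and their typical |cosine similarity| shrinks roughly like 1/sqrt(d). The Gaussian random vectors below are an assumption standing in for hidden representations.

```python
import math
import random


def mean_abs_cosine(dim, trials=200, seed=0):
    """Average |cos| between pairs of independent Gaussian random vectors in `dim` dimensions."""
    rng = random.Random(seed)
    total = 0.0
    for _ in range(trials):
        u = [rng.gauss(0.0, 1.0) for _ in range(dim)]
        v = [rng.gauss(0.0, 1.0) for _ in range(dim)]
        dot = sum(a * b for a, b in zip(u, v))
        total += abs(dot) / (math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v)))
    return total / trials


if __name__ == "__main__":
    for d in (4, 64, 1024):
        # empirical mean |cos| next to the 1/sqrt(d) reference scale
        print(d, round(mean_abs_cosine(d), 3), round(1 / math.sqrt(d), 3))
```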

A network's forward pass essentially sends a "message" from the input layer to the final layer, using an encoding defined by the weights. Linear mode connectivity between two networks implies that their encoding schemes interfere little, so that both "messages" can be sent simultaneously through the same channel. Note that word embeddings are often very anisotropic (they lie in a narrow cone rather than spreading across the full representation space), and that neural net embeddings often have large first principal components that vary across initializations. Speculatively, these two patterns might contribute to the relative lack of interference between models during interpolation.

This paper's "theoretical result" is for MLPs with one hidden layer at random initialization (as in, no training at all). I was expecting them to use the NTK to approximate the training process, but the result is just a probabilistic consequence of the number of possible permutations growing rapidly (factorially) with layer width.

If anything, the fact that there's linear mode connectivity between untrained MLPs suggests that two models having linear mode connectivity tells you relatively little about their true degree of functional similarity.

This paper also investigates the impact of weight permutations on linear mode connectivity. Its empirical investigations are less extensive, but it provides an effective algorithm for finding weight permutations that support linear interpolation. 
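
Here is a brute-force sketch of weight matching for a single hidden layer. It assumes each hidden unit is summarized by concatenating its incoming and outgoing weights. It then picks the permutation of network B's units that maximizes the summed inner product with network A's units. This is a generic illustration and not necessarily the linked paper's algorithm. Practical methods solve the same assignment problem with a linear-assignment solver rather than enumerating all n! permutations, which is only feasible at tiny widths.

```python
from itertools import permutations


def unit_features(w_in, w_out):
    """Hypothetical per-unit summary: incoming weights followed by outgoing weights."""
    return [list(a) + list(b) for a, b in zip(w_in, w_out)]


def match_units(feats_a, feats_b):
    """Return perm maximizing sum_j <feats_a[j], feats_b[perm[j]]> by exhaustive search."""
    n = len(feats_a)

    def score(perm):
        return sum(sum(x * y for x, y in zip(feats_a[j], feats_b[perm[j]])) for j in range(n))

    return max(permutations(range(n)), key=score)


# Usage with hypothetical weights: network B is network A with its hidden units shuffled.
w_in_a = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [-1.0, 0.5]]
w_out_a = [[0.5], [-1.0], [2.0], [0.3]]
shuffle = [2, 0, 3, 1]  # B's unit j is A's unit shuffle[j]
w_in_b = [w_in_a[s] for s in shuffle]
w_out_b = [w_out_a[s] for s in shuffle]
perm = match_units(unit_features(w_in_a, w_out_a), unit_features(w_in_b, w_out_b))
print(perm)  # (1, 3, 0, 2): B's unit perm[j] lines up with A's unit j
```

When B is an exact permutation of A and all unit features are distinct, the aligned assignment is the unique maximizer. This follows because <a_j, a_σ(j)> ≤ (|a_j|² + |a_σ(j)|²)/2, and summing that over a permutation σ gives at most Σ|a_j|².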

Note that neither paper uses Adam, which Analyzing Monotonic Linear Interpolation in Neural Network Loss Landscapes (below) finds can break a different type of linear interpolation property.

Code here.

Loss Surface Simplexes for Mode Connecting Volumes and Fast Ensembling

With a better understanding of the loss surfaces for multilayer networks, we can build more robust and accurate training procedures. Recently it was discovered that independently trained SGD solutions can be connected along one-dimensional paths of near-constant training loss. In this paper, we show that there are mode-connecting simplicial complexes that form multi-dimensional manifolds of low loss, connecting many independently trained models. Inspired by this discovery, we show how to efficiently build simplicial complexes for fast ensembling, outperforming independently trained deep ensembles in accuracy, calibration, and robustness to dataset shift. Notably, our approach only requires a few training epochs to discover a low-loss simplex, starting from a pre-trained solution.

My opinion:

This paper presents an interesting method for estimating the geometry of the low-loss solution manifold found by SGD. Starting from a solution found by SGD, they essentially grow a simplex of maximal dimension whose vertices are points in parameter space, constrained so that points throughout the simplex have low loss. They then repeat this with many SGD solutions to build a collection of simplexes that approximates the low-loss manifold's geometry. This lets them show that the low-loss manifold has at least 10 dimensions, though the authors were unable to create simplexes of more than 10 dimensions with non-trivial hypervolume.

The paper also finds that averaging over the vertices of simplexes can improve model robustness, which is further evidence that mode-connected networks can still implement different functions with different generalization behaviors.

Code here.
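
The simplex construction depends on sampling points inside a simplex of parameter vectors. Below is a minimal sketch that assumes vertices are flat lists of parameters. It uses the standard trick that normalizing i.i.d. exponential draws gives a uniform sample from the probability simplex (a Dirichlet(1, ..., 1) draw). The ensemble step simply averages predictions from several sampled models. `predict` is a hypothetical stand-in for a real network's forward pass, and this is not the paper's training procedure for the vertices.

```python
import random


def sample_simplex_weights(k, rng):
    """Uniform sample from the probability simplex over k vertices (Dirichlet(1, ..., 1))."""
    draws = [rng.expovariate(1.0) for _ in range(k)]
    total = sum(draws)
    return [d / total for d in draws]


def sample_point(vertices, rng):
    """Convex combination of vertex parameter vectors with uniformly random weights."""
    weights = sample_simplex_weights(len(vertices), rng)
    dim = len(vertices[0])
    return [sum(w * v[i] for w, v in zip(weights, vertices)) for i in range(dim)]


def simplex_ensemble(vertices, predict, x, n_samples=10, seed=0):
    """Average the predictions of models sampled uniformly from the simplex."""
    rng = random.Random(seed)
    preds = [predict(sample_point(vertices, rng), x) for _ in range(n_samples)]
    return sum(preds) / n_samples


# Usage: a hypothetical 2-parameter "model" p -> p[0] * x + p[1], with a triangle of vertices.
vertices = [[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]]
print(simplex_ensemble(vertices, lambda p, x: p[0] * x + p[1], x=2.0))
```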

Linear Connectivity Reveals Generalization Strategies

It is widely accepted in the mode connectivity literature that when two neural networks are trained similarly on the same data, they are connected by a path through parameter space over which test set accuracy is maintained. Under some circumstances, including transfer learning from pretrained models, these paths are presumed to be linear. In contrast to existing results, we find that among text classifiers (trained on MNLI, QQP, and CoLA), some pairs of finetuned models have large barriers of increasing loss on the linear paths between them. On each task, we find distinct clusters of models which are linearly connected on the test loss surface, but are disconnected from models outside the cluster -- models that occupy separate basins on the surface. By measuring performance on specially-crafted diagnostic datasets, we find that these clusters correspond to different generalization strategies: one cluster behaves like a bag of words model under domain shift, while another cluster uses syntactic heuristics. Our work demonstrates how the geometry of the loss surface can guide models towards different heuristic functions.

My opinion:

The previous papers indicated that linearly connected basins could contain models implementing different functions with different generalization behavior. In contrast, this paper finetunes BERT models and finds that they end up in two basins with no linear mode connectivity between them, and that these basins correspond to very different strategies for solving the training data.

This paper does not check for linear mode connectivity under weight permutations, but I wouldn't be surprised if the two basins remained unconnected even after allowing for permutations.

Code here.
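
The clusters in this paper are groups of models that are linearly connected to one another. One simple way to turn pairwise barriers into clusters is to take connected components of the "barrier below threshold" graph. This assumes a barrier matrix has already been computed and uses a hypothetical connectivity threshold. The numbers in the usage example are made up for illustration, and the paper's own clustering procedure may differ.

```python
def barrier_clusters(barriers, threshold):
    """Connected components of the graph with an edge wherever barrier < threshold.

    barriers: symmetric n x n list of lists of loss barriers between model pairs.
    """
    n = len(barriers)
    parent = list(range(n))

    def find(i):
        while parent[i] != i:
            parent[i] = parent[parent[i]]  # path halving
            i = parent[i]
        return i

    for i in range(n):
        for j in range(i + 1, n):
            if barriers[i][j] < threshold:
                parent[find(i)] = find(j)  # union the two components

    clusters = {}
    for i in range(n):
        clusters.setdefault(find(i), []).append(i)
    return sorted(clusters.values())


# Usage with made-up barriers: models 0-1 and 2-3 are connected, the pairs are not.
barriers = [
    [0.0, 0.02, 0.9, 0.8],
    [0.02, 0.0, 0.85, 0.95],
    [0.9, 0.85, 0.0, 0.03],
    [0.8, 0.95, 0.03, 0.0],
]
print(barrier_clusters(barriers, threshold=0.1))  # [[0, 1], [2, 3]]
```

Connected components make connectivity transitive, which real barriers need not be. A stricter variant would require every pair within a cluster to fall below the threshold.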

Analyzing Monotonic Linear Interpolation in Neural Network Loss Landscapes

Linear interpolation between initial neural network parameters and converged parameters after training with stochastic gradient descent (SGD) typically leads to a monotonic decrease in the training objective. This Monotonic Linear Interpolation (MLI) property, first observed by Goodfellow et al. (2014) persists in spite of the non-convex objectives and highly non-linear training dynamics of neural networks. Extending this work, we evaluate several hypotheses for this property that, to our knowledge, have not yet been explored. Using tools from differential geometry, we draw connections between the interpolated paths in function space and the monotonicity of the network - providing sufficient conditions for the MLI property under mean squared error. While the MLI property holds under various settings (e.g. network architectures and learning problems), we show in practice that networks violating the MLI property can be produced systematically, by encouraging the weights to move far from initialization. The MLI property raises important questions about the loss landscape geometry of neural networks and highlights the need to further study their global properties. 

My opinion:

Prior work indicates that the linear path from initialization to the converged solution typically has monotonically decreasing loss. This paper tests this monotonic linear interpolation (MLI) property across various training configurations and finds that it often fails to hold for networks trained with Adam or with large SGD learning rates. It also finds that networks that move further from initialization tend to have more curved optimization trajectories, and that the MLI property is less likely to hold for these networks.

Code here.
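
The MLI property amounts to a simple check on the loss at θ(α) = (1 − α)θ₀ + αθ₁ for α from 0 to 1. The sketch below uses a hypothetical one-parameter double-well loss L(θ) = (θ² − 1)², which has minima at ±1 and a bump at θ = 0. It shows how a run that travels far from its initialization and crosses the bump violates MLI, while a short run does not. Real checks evaluate the network's training loss at each α instead.

```python
def double_well(theta):
    """Toy nonconvex loss with minima at theta = -1 and theta = +1 and a bump at theta = 0."""
    return (theta ** 2 - 1) ** 2


def interpolation_losses(loss, theta_init, theta_final, steps=50):
    """Loss at evenly spaced points on the straight line from theta_init to theta_final."""
    return [loss((1 - i / steps) * theta_init + (i / steps) * theta_final) for i in range(steps + 1)]


def satisfies_mli(losses, tol=0.0):
    """True if the loss never increases (by more than tol) along the path."""
    return all(b <= a + tol for a, b in zip(losses, losses[1:]))


if __name__ == "__main__":
    print(satisfies_mli(interpolation_losses(double_well, 0.2, 1.0)))   # short run, one basin
    print(satisfies_mli(interpolation_losses(double_well, -1.5, 1.0)))  # long run, crosses the bump
```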

Revisiting Model Stitching to Compare Neural Representations

We revisit and extend model stitching (Lenc & Vedaldi 2015) as a methodology to study the internal representations of neural networks. Given two trained and frozen models A and B, we consider a "stitched model" formed by connecting the bottom-layers of A to the top-layers of B, with a simple trainable layer between them. We argue that model stitching is a powerful and perhaps under-appreciated tool, which reveals aspects of representations that measures such as centered kernel alignment (CKA) cannot. Through extensive experiments, we use model stitching to obtain quantitative verifications for intuitive statements such as "good networks learn similar representations", by demonstrating that good networks of the same architecture, but trained in very different ways (e.g.: supervised vs. self-supervised learning), can be stitched to each other without drop in performance. We also give evidence for the intuition that "more is better" by showing that representations learnt with (1) more data, (2) bigger width, or (3) more training time can be "plugged in" to weaker models to improve performance. Finally, our experiments reveal a new structural property of SGD which we call "stitching connectivity", akin to mode-connectivity: typical minima reached by SGD can all be stitched to each other with minimal change in accuracy.

My opinion:

This paper presents an interesting tool for comparing the features extracted by different models. They take the bottom layers of one trained model and the top layers of another, freeze their parameters, and then "stitch" them together by learning a linear transformation between the two models' embedding spaces. They find that model stitching works well across architectures, datasets and training processes.
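
Below is a minimal sketch of the stitching idea. It makes two simplifications that are assumptions, not the paper's setup. First, representations are 2-dimensional. Second, the linear stitching layer is fit by least squares to map model A's activations onto model B's activations, whereas the paper trains the stitching layer on the downstream task loss. If B's representation is an invertible linear transform of A's (here a hypothetical matrix `true_m`), the fitted layer recovers that transform exactly.

```python
def fit_linear_stitch(xs, ys):
    """Least-squares 2x2 matrix M with M @ x ≈ y for paired activations (x from A, y from B)."""
    # Accumulate S_xx = sum x x^T and S_yx = sum y x^T.
    sxx = [[0.0, 0.0], [0.0, 0.0]]
    syx = [[0.0, 0.0], [0.0, 0.0]]
    for x, y in zip(xs, ys):
        for i in range(2):
            for j in range(2):
                sxx[i][j] += x[i] * x[j]
                syx[i][j] += y[i] * x[j]
    det = sxx[0][0] * sxx[1][1] - sxx[0][1] * sxx[1][0]
    inv = [[sxx[1][1] / det, -sxx[0][1] / det], [-sxx[1][0] / det, sxx[0][0] / det]]
    # Normal equations: M S_xx = S_yx, so M = S_yx S_xx^{-1}.
    return [[sum(syx[i][k] * inv[k][j] for k in range(2)) for j in range(2)] for i in range(2)]


def apply_linear(m, x):
    """Apply a 2x2 matrix to a 2-vector."""
    return [m[0][0] * x[0] + m[0][1] * x[1], m[1][0] * x[0] + m[1][1] * x[1]]


# Usage: hypothetical activations where B's representation is a linear transform of A's.
true_m = [[0.0, -2.0], [1.5, 0.5]]
xs = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [2.0, -1.0], [-0.5, 3.0]]
ys = [apply_linear(true_m, x) for x in xs]
m = fit_linear_stitch(xs, ys)
print(m)
```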

The fact that so many models can be stitched together lends support to feature universality and natural abstractions, and suggests different architectures are reasonably consistent in the types of features they extract. Similarly, the fact that models at different points in the training process can be stitched together suggests a degree of stability in the model's representations throughout training. 

Also, this paper suggests GPT-2 word embeddings and human neurological activations during language processing can be similarly stitched together with a linear transform.

BERTs of a feather do not generalize together: Large variability in generalization across models with similar test set performance

If the same neural network architecture is trained multiple times on the same dataset, will it make similar linguistic generalizations across runs? To study this question, we fine-tuned 100 instances of BERT on the Multi-genre Natural Language Inference (MNLI) dataset and evaluated them on the HANS dataset, which evaluates syntactic generalization in natural language inference. On the MNLI development set, the behavior of all instances was remarkably consistent, with accuracy ranging between 83.6% and 84.8%. In stark contrast, the same models varied widely in their generalization performance. For example, on the simple case of subject-object swap (e.g., determining that "the doctor visited the lawyer" does not entail "the lawyer visited the doctor"), accuracy ranged from 0.00% to 66.2%. Such variation is likely due to the presence of many local minima that are equally attractive to a low-bias learner such as a neural network; decreasing the variability may therefore require models with stronger inductive biases. 

My opinion:

This paper compares the generalization behavior of different BERT finetuning runs, where only the classification head initialization and the training data order vary between runs (though they backprop through the entire BERT model during finetuning). They find that different runs have very different generalization behavior when evaluated on probing data specifically crafted to reveal which linguistic structures the models use to make classifications, but very similar in-distribution behavior. 

This implies that test data from the same distribution as the training data is not enough to pick up on differences in the generalizations that different trained networks learn. If the authors had only used in-distribution test data to evaluate the finetuned models, they'd have concluded that there was very little variability between them.

(Thanks to Zac Hatfield-Dodds for bringing this paper to my attention)


My impression after reading these and similar papers is that results from "toy" settings often do not reflect those found in more realistic settings. Many of these "optimization geometry" investigations use very wide models that are massively undertrained on relatively easy datasets (usually of images). As a consequence, they operate in regimes where the neural tangent kernel (NTK) and neural network Gaussian process (NNGP) approximations describe network training dynamics well. However, these approximations do not hold well for larger models solving harder problems. 

This raises the question of whether the empirical results we get from large networks such as GPT-3 will extend to whatever networks eventually implement AGI-level capabilities. I am hopeful that the differences in learning trajectories between toy networks and more powerful systems represent a single "phase transition", primarily caused by moving out of the NTK regime. 

I do expect another "phase transition" in the inductive biases of AGI learning trajectories once the AI system becomes capable of actively refining its own abstractions to improve its future thinking. That said, I think we can still study such dynamics in lower-capability networks by studying active learning in GPT-3-level systems.

Anyway, I hope readers find these papers useful for their own research. Please feel free to discuss the listed papers in the comments or to recommend additional papers to me.


For next week's roundup, I'm thinking the focus will be on using interpretability tools to guide a neural net's learning process. There's apparently a fair bit of work in this space; see, e.g., this review paper.

My other candidate focuses are:

  • Diffusion models
  • SGD inductive biases
  • Controllable text generation
  • Techniques for chain of thought language models
  • Structure and content of language model internal representations

Let me know if there are any topics you're particularly interested in.

Comments:

Really cool. I read some of these kinds of papers last week, but this is better context on the topic. Redundancy seems like evidence in favor of a narrow loss basin, but e.g. the fact that fine-tuned BERT models generalize very differently is evidence of multiple local minima. Your guess that linear mode connectivity works in simple image classification domains but not in language models seems like the most likely answer to me, but I would be interested to see it tested.

Very useful, thank you!