This is a linkpost for

From Twitter

Transformers become more ‘tree-like’ over the course of training, representing their inputs in a more hierarchical way. The authors find this by projecting transformers into the space of tree-structured networks. [Stanford, MIT] 


When trained on language data, do transformers learn some arbitrary computation that utilizes the full capacity of the architecture or do they learn a simpler, tree-like computation, hypothesized to underlie compositional meaning systems like human languages? There is an apparent tension between compositional accounts of human language understanding, which are based on a restricted bottom-up computational process, and the enormous success of neural models like transformers, which can route information arbitrarily between different parts of their input. One possibility is that these models, while extremely flexible in principle, in practice learn to interpret language hierarchically, ultimately building sentence representations close to those predictable by a bottom-up, tree-structured model. To evaluate this possibility, we describe an unsupervised and parameter-free method to functionally project the behavior of any transformer into the space of tree-structured networks. Given an input sentence, we produce a binary tree that approximates the transformer's representation-building process and a score that captures how "tree-like" the transformer's behavior is on the input. While calculation of this score does not require training any additional models, it provably upper-bounds the fit between a transformer and any tree-structured approximation. Using this method, we show that transformers for three different tasks become more tree-like over the course of training, in some cases unsupervisedly recovering the same trees as supervised parsers. These trees, in turn, are predictive of model behavior, with more tree-like models generalizing better on tests of compositional generalization.
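To make "producing a binary tree for an input sentence" more concrete, here is a minimal sketch of one way such a search could be done. It is not the paper's implementation. It assumes you already have some per-span cost (lower meaning the span behaves more like a self-contained unit) and picks the binary bracketing with the lowest total cost by dynamic programming over spans. The names `best_binary_tree` and `span_cost` are hypothetical, and the toy costs are made up purely for illustration.

```python
from functools import lru_cache
from typing import Callable, Tuple, Union

# A tree is either a token index (leaf) or a pair of subtrees.
Tree = Union[int, Tuple["Tree", "Tree"]]


def best_binary_tree(
    n: int, span_cost: Callable[[int, int], float]
) -> Tuple[float, Tree]:
    """Return (total cost, tree) for the binary tree over n tokens that
    minimizes the summed cost of its multi-token spans.

    Spans are half-open intervals [i, j). Single-token spans cost nothing.
    `span_cost` is an assumed, user-supplied function, not the paper's score.
    """

    @lru_cache(maxsize=None)
    def solve(i: int, j: int) -> Tuple[float, Tree]:
        if j - i == 1:
            return 0.0, i  # leaf: just the token index
        best_total = None
        best_tree = None
        # Try every split point and keep the cheapest pair of subtrees.
        for k in range(i + 1, j):
            left_cost, left = solve(i, k)
            right_cost, right = solve(k, j)
            total = left_cost + right_cost
            if best_total is None or total < best_total:
                best_total, best_tree = total, (left, right)
        # The span [i, j) itself is a node in the tree, so add its cost.
        return best_total + span_cost(i, j), best_tree

    return solve(0, n)


# Toy example: spans [0, 2) and [2, 4) are cheap, every other
# multi-token span (apart from the root) is expensive.
toy_costs = {(0, 2): 0.1, (2, 4): 0.1, (0, 4): 0.5}
cost, tree = best_binary_tree(4, lambda i, j: toy_costs.get((i, j), 1.0))
print(tree)  # ((0, 1), (2, 3))
```

The search considers every bracketing, so it runs in cubic time in the sentence length, which is fine for short inputs. How the per-span cost is actually derived from a transformer's representations is the substantive part of the method and is deliberately left out here.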


Note: The datasets they use aren't typical human text. See page 5.

I'm interested in discussions and takes from people more familiar with LMs. How surprising/interesting is this?


I'm not surprised that if you investigate context-free grammars with two-to-six-layer transformers you learn something very much like a tree. I also don't expect this result to generalize to larger models or more complex tasks, and so personally I find the paper plausible but uninteresting.

I do think the paper adds onto the pile of "neural networks do learn a generalizing algorithm" results.

if you investigate context-free grammars

Notably, on Geoquery (the non-context-free grammar task), the goal is still to predict a parse tree of the natural sentence:

Given that neural networks generalize, it's not surprising that a tree-like internal structure emerges on tasks that require a tree-like internal structure. 

I also don't expect this result to generalize to larger models or more complex tasks, and so personally I find the paper plausible but uninteresting.

Since we don't observe this in transformers trained on toy tasks without inherent tree structure (e.g. Redwood's paren balancer) or on specific behaviors of medium LMs (induction, IOI, skip sequence counting), my guess is that this depends heavily on how "tree-like" the actual task is. I expect some parts of language modeling are indeed tree-like, so we'll see a bit of this, but it won't explain a large fraction of how the network's internals work.

In terms of other evidence on medium-sized language models, there's also the speculative model from this post, which suggests that insofar as this tree structure is being built, it's being built in the earlier layers (1-8) and not later layers.