TL;DR: We report our intermediate results from the AI Safety Camp project “Mechanistic Interpretability Via Learning Differential Equations”. Our goal was to explore transformers that deal with numerical time-series data (either inferring the governing differential equation or predicting the next number). As the task is well formalized, this seems to be an easier problem than interpreting a transformer that deals with language. During the project, we applied various interpretability methods to the problem at hand. We also obtained some preliminary results (e.g., we observed a pattern similar to a numerical computation of the derivative of the input data). We plan to continue working on it to validate and extend these preliminary results.

Introduction

Mechanistic interpretability tries to understand the algorithms implemented by a neural network. This requires identifying the features in the transformer's activation patterns that correspond to particular patterns in the data it learns from. Often this approach is quite successful. Perhaps the most popularized example is Golden Gate Claude, where the Anthropic team succeeded in finding a representation of the Golden Gate Bridge in Claude’s “mind” and tuned it so that the model became “obsessed” with the Golden Gate Bridge. Despite such impressive results, mechanistic interpretability is far from being a solved field. One may notice that the problem has two levels of difficulty: we need to figure out not only how a transformer represents the features of the data it learns from, but also what these features are, bearing in mind that human language is complex and not well formalized.

In this project, we tackle these two problems separately. Instead of dealing with the complexity of the human language by using LLMs, we study mathematical transformers. We leveraged transformers that deal with time series: the ODEFormer trained to predict the symbolic form of the ordinary differential equation based on the data points from its solution, and Hugging Face Time-Series Transformer that predicts the next numerical value in the sequence. Since we are acquainted with the underlying mathematical problem, we need only figure out how the solution process is represented in the transformer's activation pattern, which seems to be a much more tractable problem.

One may wonder how understanding the representation of ordinary differential equations can help advance mechanistic interpretability for LLMs. There are three potential benefits. First, studying a toy model can help advance understanding of a more complicated model, as they often share the same features. Second, if the natural abstraction argument is valid, we can expect abstractions that transformers learn from the real world to correspond to particular mathematical patterns. Finally, a fundamental understanding of the underlying gears of the transformer will likely be useful for interpretability in the long term, even if not now, just as the fundamental understanding of electromagnetism did not immediately lead to practical benefits, but did later.

During the three months of the project, we succeeded in setting up the interpretability tools for both ODEFormer and the Hugging Face Time Series Transformer, and obtained a few preliminary results for the ODEFormer (see the next section).

Results 

Hypothesis

For most of our work, we used the ODEFormer, an 86M parameter encoder-decoder transformer with a beam search that predicts a symbolic form of a first-order autonomous ordinary differential equation (i.e., an equation of type $dx/dt = f(x)$, where the function on the right-hand side does not have explicit time dependence) from the numerical solution of this equation.
How does ODEFormer solve the problem of inferring a differential equation? We hypothesized that it does so in a few steps. First, it takes a numerical derivative $dx/dt$ of the data $x(t)$, presumably in the encoder part of the transformer. Second, it applies various analytical functions $f$ to the data $x$ and compares the result $f(x)$ to the derivative during the beam search process in the decoder. When $f(x)$ is close enough to $dx/dt$ (e.g. R^2 is larger than a threshold), the beam search ends and ODEFormer outputs the solution, finishing with an end-of-statement token.
An alternative to this hypothesis would be that ODEFormer explores the data holistically, observing multiple different patterns (like periodicity, asymptotics, higher-order-derivatives patterns, or potentially even patterns that do not correspond to something humans usually pay attention to) and classifying equations based on these patterns.
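
The derivative hypothesis can be sketched in a few lines of code. This is only an illustration of the hypothesized procedure, not ODEFormer's actual mechanism: the candidate set `CANDIDATES`, the acceptance threshold of 0.99, the helper names and the synthetic sigmoid input are all assumptions made for this example.

```python
import math

def finite_difference(xs, dt):
    """Central differences in the interior, one-sided differences at the ends."""
    n = len(xs)
    d = [0.0] * n
    d[0] = (xs[1] - xs[0]) / dt
    d[-1] = (xs[-1] - xs[-2]) / dt
    for i in range(1, n - 1):
        d[i] = (xs[i + 1] - xs[i - 1]) / (2 * dt)
    return d

def r2_score(y_true, y_pred):
    """Coefficient of determination: 1 means a perfect fit."""
    mean = sum(y_true) / len(y_true)
    ss_res = sum((a - b) ** 2 for a, b in zip(y_true, y_pred))
    ss_tot = sum((a - mean) ** 2 for a in y_true)
    return 1.0 - ss_res / ss_tot

# Hypothetical candidate right-hand sides f(x); a real search space is far larger.
CANDIDATES = {
    "x": lambda x: x,
    "-x": lambda x: -x,
    "x**2": lambda x: x * x,
    "x*(1-x)": lambda x: x * (1 - x),
}

def best_candidate(xs, dt, threshold=0.99):
    """Step 1: numerical derivative. Step 2: score each f(x) against it."""
    dxdt = finite_difference(xs, dt)
    scores = {
        name: r2_score(dxdt, [f(x) for x in xs])
        for name, f in CANDIDATES.items()
    }
    name = max(scores, key=scores.get)
    return name, scores[name], scores[name] >= threshold

# Synthetic sigmoid trajectory sampled on t in [-5, 5].
dt = 0.1
ts = [i * dt for i in range(-50, 51)]
xs = [1.0 / (1.0 + math.exp(-t)) for t in ts]
name, score, accepted = best_candidate(xs, dt)
print(name, round(score, 4), accepted)
```

On this trajectory only the logistic right-hand side passes the threshold, which is the behaviour the hypothesis attributes to the decoder's search.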

The way to validate the derivative hypothesis is to find this algorithm explicitly, namely the derivative calculation in the encoder and the comparison with different functions via an R^2 computation in the decoder. In our preliminary results, we find indications of derivative calculation in simple cases, but we have not yet observed the R^2 calculation.

Derivative inference

To study the effect and calculation of the derivative, we chose as a model system data produced by the logistic equation, whose solution is the well-known sigmoid function. This allows us to pay special attention to the region of high derivative, its width and location.
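
For reference, the textbook form of the logistic equation with unit carrying capacity is written out below. The symbols $r$ (growth rate) and $t_0$ (midpoint) are assumptions introduced here for illustration; the exact parametrization used in the experiments may differ.

```latex
\dot{x} = r\,x\,(1 - x), \qquad
x(t) = \frac{1}{1 + e^{-r (t - t_0)}}, \qquad
\max_t \dot{x}(t) = \dot{x}(t_0) = \frac{r}{4}.
```

In this form the derivative peaks at $t = t_0$, where $x = 1/2$, and the width of the transition region scales as $1/r$, so location, width and peak value of the derivative are each controlled by a known parameter.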

First, we explored the info-weighted attention matrices of the heads in the encoder layers (see the Appendix for the details and motivation of adding info-weighted attention). We observed that one of the heads (namely layer 2 head 7) seemed to pick up on differentiation, or on the portion of the curve where the slope is not 0: its attention is mostly directed to the region of high derivative (see Fig. 1).

Fig. 1: Layer 2 head 7 attends to the region with high derivative, while layer 0 head 8 attends along the trajectory.

 

Second, we used linear probes (see Appendix) to predict, from the activations of the encoder layers, both the time-point of the maximum derivative of sigmoid functions and the value of the maximum derivative. The probes predicted the time-point of the maximum derivative successfully, which indicates that this information is linearly represented in the activations (see Fig. 2). However, somewhat surprisingly, the probes were not able to predict the value of the maximum derivative well (Fig. 2). This may mean that we did not train on enough data, which we may address in future work. Another option is that the absolute value is stored in a complicated way in the transformer (e.g. mantissa and exponent separately), which obstructs successful inference using a linear probe.

 


While these preliminary results seem to support our derivative hypothesis, they are not conclusive enough to eliminate alternative explanations. Indeed, we tried only the logistic equation in this analysis, and so far we were not able to infer the value of the derivative, only the time-point of the maximal derivative and the width of the transition region, and these can be obtained without explicit derivative calculation everywhere. So we need to perform more tests to validate this hypothesis.

 

Inferring R^2 score

To investigate the second part of the suggested algorithm (comparing the derivative with various functional forms), we use linear probes to predict the R^2 score of the trajectory under the predicted equation. However, the results offer little evidence that the R^2 score is linearly encoded in the model activations in any of the layers.
We expected the final decoder layer to perform best, because this is the layer 'closest' to the output equation, from whose trajectory we compute the R^2 score we want to predict. However, this turned out not to be the case (see layer index 15). In fact, layer 4 of the decoder (i.e. layer index 7) shows the best performance, which is somewhat surprising.
While the results are somewhat negative, part of the issue could be that this experiment only considered 1D ODEs generated using the same method as the ODEFormer's pre-training data. The ODEFormer typically handles such equations quite well, usually achieving R^2 scores close to 1. When a trajectory that is not fit well does appear, the probe is unable to predict the correct value, since it has seen almost only R^2 values close to 1.

R^2 score prediction vs ground truth for a probe trained using a numerical solver on activations from layer index 7. Note that most of the data is clustered around (1,1), and that the lowest probe prediction is around 0.68. The Spearman coefficient for this plot is 0.75.
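
The Spearman coefficient is a rank correlation: it only asks whether the probe orders the samples correctly, which makes it less sensitive to the clustering near (1,1) than a plain linear fit. A minimal sketch of how it can be computed follows; the helper names and the example numbers are made up for illustration and are not our experimental data.

```python
def ranks(values):
    """1-based ranks; tied values share the average of their positions."""
    order = sorted(range(len(values)), key=lambda i: values[i])
    result = [0.0] * len(values)
    i = 0
    while i < len(order):
        j = i
        while j + 1 < len(order) and values[order[j + 1]] == values[order[i]]:
            j += 1
        average_rank = (i + j) / 2 + 1
        for k in range(i, j + 1):
            result[order[k]] = average_rank
        i = j + 1
    return result

def pearson(a, b):
    """Ordinary linear correlation coefficient."""
    mean_a, mean_b = sum(a) / len(a), sum(b) / len(b)
    cov = sum((x - mean_a) * (y - mean_b) for x, y in zip(a, b))
    var_a = sum((x - mean_a) ** 2 for x in a)
    var_b = sum((y - mean_b) ** 2 for y in b)
    return cov / (var_a * var_b) ** 0.5

def spearman(a, b):
    """Spearman coefficient = Pearson correlation of the ranks."""
    return pearson(ranks(a), ranks(b))

# Made-up ground-truth R^2 values and probe predictions.
true_r2 = [0.10, 0.50, 0.90, 0.99]
probe_r2 = [0.70, 0.80, 0.95, 0.97]
print(spearman(true_r2, probe_r2))
```

In the made-up example the probe compresses all predictions into a narrow band yet preserves the ordering, so the rank correlation is perfect even though the predicted values are far from the true ones.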

 

Other results

In addition to the validation of the derivative hypothesis, we explored how ODEFormer encodes equation type classification. We observe high performance of linear probes at the decoder layers for classifying between linear and hyperbolic equations, as well as between equations in different dimensions. Using sparse autoencoders, we see that there are features that are active only for certain equation families, but we have not yet figured out the general rule for them. For a more specific task, following Logit Lens, we observed at which layer the sign of the term in a linear differential equation gets predicted, and how this prediction is inherited through the layers of the transformer. Finally, we also explored the performance of probes at inferring eigenvalues for 2D linear differential equations, to test the hypothesis that ODEFormer first infers eigenvalues from the data and then constructs the equation from these eigenvalues. If this were the case, we would expect higher accuracy for inferring eigenvalues than for the equation coefficients at some of the lower layers of the ODEFormer. So far we do not observe this effect, which leaves the eigenvalue hypothesis unsupported. See the Appendix for more details on these results.

Artifacts

As part of this research, we developed and open-sourced a set of tools to support reproducibility and further exploration.

Future plans 

As we have our interpretability tools finally working and producing preliminary results, we would like to validate them and do more exploration. Do we indeed see the computation of the derivative? Can we observe it in the activation pattern? Can we pinpoint the exact mechanism for how it is calculated (e.g. is it a simple finite difference scheme, or some higher-order differential scheme)?  
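
To make the distinction between a simple and a higher-order scheme concrete, the sketch below compares a two-point forward difference with the standard five-point central stencil on a sigmoid. The function names, the evaluation point and the step size are illustrative assumptions; nothing here is taken from ODEFormer itself.

```python
import math

def sigmoid(t):
    return 1.0 / (1.0 + math.exp(-t))

def forward_difference(f, t, h):
    """First-order scheme: the error shrinks proportionally to h."""
    return (f(t + h) - f(t)) / h

def five_point_stencil(f, t, h):
    """Fourth-order central scheme: the error shrinks proportionally to h**4."""
    return (-f(t + 2 * h) + 8 * f(t + h) - 8 * f(t - h) + f(t - 2 * h)) / (12 * h)

t0, h = 0.5, 0.1
exact = sigmoid(t0) * (1.0 - sigmoid(t0))  # analytic derivative of the sigmoid
err_forward = abs(forward_difference(sigmoid, t0, h) - exact)
err_stencil = abs(five_point_stencil(sigmoid, t0, h) - exact)
print(err_forward, err_stencil)
```

The two schemes use different numbers of neighbouring points with different weights, so in principle they would leave different signatures in an attention pattern: a forward difference needs only the adjacent point, while the stencil needs two points on each side.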

Would we see the same pattern - the calculation of the derivative - for the time-series transformer? 

Moreover, we hope to interpret the other potential features found by the SAE that distinguish different types of equations, as well as to find other features.

After we understand the mechanisms of the transformers above, we would like to explore whether there are any similarities in activation patterns between mathematically trained transformers and LLMs (like Llama, for example). It has been demonstrated that LLMs (Llama and GPT-3) can perform a time-series prediction task without any additional fitting or prompts. Do LLMs use the same algorithm to perform the time-series prediction task as a specialized transformer?

Finally, understanding what patterns in LLMs are activated during the time-series prediction task may shed some light on language prediction, if in certain circumstances we observe activation of these patterns even in language tasks.

As the project is still in its intermediate stage, we are quite flexible and looking forward to your feedback.

Research Team

  • Valentin Slepukhin - Research Lead
  • Ayo Akinkugbe - Research Co-Lead
  • Probing Team
    • Dylan Ponsford - Lead
    • Helen Saville
    • Axel Ahlqvist
    • Melwina Albuquerque
  • Sparse Auto-encoder Team
    • Eduard Kovalets - Lead
    • Tommaso Mencattini
    • Georgios Nikolaou
  • Logit Lens and Attention Lens Team
    • Soumyadeep Bose
    • Joep Storm
    • Mufti Taha
  • Time Series Transformer Team
    • Murshed Al Amin - Lead
    • Varun Piram
    • Syed Irtiza
    • Abhik Rana
  • Activation Maximization Team
    • Utkarsh Priyadarshi

Appendix

In our project, each team leveraged a different method of mechanistic interpretability to understand the inner workings of the ODEFormer and the time series transformer. Our goal with these methods was to better understand how and why these toy models work by uncovering patterns, structures or circuits within them. We had a team devoted to studying the time series transformer, and also looked into the following methods specifically for the ODEFormer: Attention Lens, Logit Lens, Probing and Sparse Autoencoders. In earlier efforts, we also looked into SHAP but did not find this method fruitful.

Time-Series transformer

We studied Hugging Face’s Time Series Transformer model, using univariate and multivariate datasets for training. Rather than inferring a governing law that produces the numerical values at each time point, this model predicts the next value. We focused mainly on using sparse autoencoders on activations to understand what the time-series model learned after training.

Results for time-series transformer

We succeeded in training the univariate and multivariate time-series transformer to predict data generated by simple functions (linear combinations of trigonometric functions, polynomials, exponential functions and hyperbolic functions). We used sparse autoencoders to infer the features corresponding to different types of functions, but have not interpreted our results yet.

ODEFormer

For most of our work, we used the ODEFormer, an 86M parameter encoder-decoder transformer that predicts a symbolic form of a first-order autonomous ordinary differential equation (i.e., an equation of type $dx/dt = f(x)$, where the function on the right-hand side does not have explicit time dependence) from the numerical solution of this equation.

Other results for ODEFormer

This section gives a brief overview of each method used for interpreting each model.

Probing 

Probing is a method that involves training small auxiliary machine learning models (probes) to predict specific properties from the internal representations of a larger model. The idea is to isolate whether particular features are encoded in the hidden layers of a model. A simple probe (e.g. a logistic regression classifier) is trained on representations from various layers to determine how linearly accessible these properties are. We leverage probing to track the flow of information through a model’s layers and identify where specific knowledge is stored or transformed.
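
A minimal sketch of a linear probe is shown below. It trains a logistic-regression classifier by gradient descent on synthetic vectors that stand in for layer activations; the data generator, the dimensions and the hyperparameters are all assumptions for illustration, and no real ODEFormer activations are involved.

```python
import math
import random

def make_synthetic_activations(n, dim, rng, separation=3.0):
    """Hypothetical stand-in for activations: the class label is written
    along one fixed direction and buried in unit Gaussian noise."""
    direction = [1.0 if i % 2 == 0 else -1.0 for i in range(dim)]
    norm = math.sqrt(sum(v * v for v in direction))
    direction = [v / norm for v in direction]
    xs, ys = [], []
    for _ in range(n):
        y = rng.randint(0, 1)
        sign = 1.0 if y == 1 else -1.0
        xs.append([rng.gauss(0.0, 1.0) + sign * separation * d for d in direction])
        ys.append(y)
    return xs, ys

def train_probe(xs, ys, epochs=200, lr=0.1):
    """Full-batch gradient descent on the logistic loss."""
    dim, n = len(xs[0]), len(xs)
    w, b = [0.0] * dim, 0.0
    for _ in range(epochs):
        grad_w, grad_b = [0.0] * dim, 0.0
        for x, y in zip(xs, ys):
            z = sum(wi * xi for wi, xi in zip(w, x)) + b
            z = max(-30.0, min(30.0, z))  # clamp to avoid overflow in exp
            err = 1.0 / (1.0 + math.exp(-z)) - y
            for i in range(dim):
                grad_w[i] += err * x[i]
            grad_b += err
        w = [wi - lr * g / n for wi, g in zip(w, grad_w)]
        b -= lr * grad_b / n
    return w, b

def accuracy(w, b, xs, ys):
    hits = 0
    for x, y in zip(xs, ys):
        z = sum(wi * xi for wi, xi in zip(w, x)) + b
        hits += int((z > 0) == (y == 1))
    return hits / len(xs)

rng = random.Random(0)
train_x, train_y = make_synthetic_activations(200, 8, rng)
test_x, test_y = make_synthetic_activations(100, 8, rng)
w, b = train_probe(train_x, train_y)
test_acc = accuracy(w, b, test_x, test_y)
print(test_acc)
```

High held-out accuracy of such a probe indicates that the property is linearly accessible at that layer; it does not by itself show that the model uses the property.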

Intermediate Results

Equation type classification

We used binary probes for a simple classification task between two types of equations (exponential and hyperbolic).

Classification accuracy between exponential and hyperbolic equations

Layers indexed 0-3 represent the encoder while layers 4-15 represent the decoder. The classification accuracy increases through the encoder and is basically perfect throughout the decoder. This is despite the fact that ODEFormer has some performance issues on the exponential samples, and only reconstructs a trajectory with R^2 >= 0.5 for 40% of the samples.

We also trained a probe to classify one- and two-dimensional systems. The probe accuracy was fairly high, after an initial decrease in the first three decoder layers; however we had expected near-perfect accuracy, as we expected ODEFormer would easily detect the system dimension from the dimension of the input trajectory. It may be that, by the decoder layers, such information is not required to be linearly represented in the activations. Also surprisingly, ODEFormer itself occasionally gets the dimensionality of predicted ODE systems wrong, but it is unclear why this happens.

Classification accuracy between one- and two-dimensional systems (decoder layers only)

 

Eigenvalues prediction

We also explored 2D linear systems of the form

$dx/dt = \alpha x + \beta y, \quad dy/dt = \gamma x + \delta y.$

One hypothesis is that the ODEFormer uses the eigenvalues of the matrix

$\begin{pmatrix} \alpha & \beta \\ \gamma & \delta \end{pmatrix},$

as those determine the behaviour of the system. As a baseline for comparison, we also predicted the coefficients α, β, γ and δ, as these must be represented in order to produce the right-hand side of the system.

The probe performance for the eigenvalues is worse than for the coefficients. This is some evidence against the hypothesis that the ODEFormer calculates eigenvalues in order to solve ODEs. We generated 10k samples with different combinations of the coefficients α, β, γ, δ ∈ [-2,2]. We then kept only the 2732 samples on which the ODEFormer performed well (R^2 >= 0.9), as we are particularly interested in the ODEFormer's behaviour when it is successful.
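
For concreteness, the eigenvalue targets for such a probe can be computed from the four coefficients in closed form. The sketch below assumes the coefficients fill the 2x2 matrix row by row (α, β in the first row, γ, δ in the second); this layout and the helper names are assumptions for illustration.

```python
import cmath

def eigenvalues_2x2(alpha, beta, gamma, delta):
    """Roots of lambda^2 - trace * lambda + det = 0 for [[alpha, beta], [gamma, delta]]."""
    trace = alpha + delta
    det = alpha * delta - beta * gamma
    root = cmath.sqrt(trace * trace - 4.0 * det)
    return (trace + root) / 2.0, (trace - root) / 2.0

def describe(eigs, tol=1e-12):
    """Qualitative behaviour: real parts set growth, imaginary parts set oscillation."""
    oscillates = any(abs(lam.imag) > tol for lam in eigs)
    max_real = max(lam.real for lam in eigs)
    if max_real < -tol:
        growth = "decaying"
    elif max_real > tol:
        growth = "growing"
    else:
        growth = "neutral"
    return growth + (" oscillation" if oscillates else "")

print(describe(eigenvalues_2x2(0.0, 1.0, -1.0, 0.0)))   # harmonic oscillator
print(describe(eigenvalues_2x2(-1.0, 0.0, 0.0, -2.0)))  # two decaying modes
```

The eigenvalues are a nonlinear function of the coefficients (they involve a square root), which is one reason a linear probe might recover the coefficients more easily than the eigenvalues.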

 

In summary, the probing experiments give negative results for the hypothesis that the network searches through different expressions while evaluating their performance. We also got negative results for the network calculating eigenvalues of the coefficient matrix. Finally, we got some positive results for the network representing the location of the maximum derivative. This was further explored with Attention Lens.

Sparse Autoencoder (SAE) 

A sparse autoencoder is an unsupervised neural network architecture designed to learn compressed representations of input data while enforcing sparsity in the hidden layer. To interpret the ODEFormer, we trained SAEs on the internal activations of the ODEFormer model so that the encoder learned a small number of active neurons (features) that could reconstruct the original activations. This sparsity constraint encourages the network to learn disentangled and interpretable features, rather than dense, overlapping representations.

Once trained, these sparse features were analyzed to understand what kind of patterns or concepts the original model encodes internally.
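
The objective described above can be written down as a single forward pass: a ReLU encoder, a linear decoder, and a loss that adds an L1 sparsity penalty to the reconstruction error. The weights, the toy activation vector and the penalty coefficient below are hand-picked assumptions for illustration; this is not our training code.

```python
def matvec(matrix, vector):
    return [sum(a * b for a, b in zip(row, vector)) for row in matrix]

def sae_forward(x, w_enc, b_enc, w_dec, b_dec):
    """Encode to non-negative features with a ReLU, then decode linearly."""
    pre = [z + b for z, b in zip(matvec(w_enc, x), b_enc)]
    features = [max(0.0, p) for p in pre]
    recon = [z + b for z, b in zip(matvec(w_dec, features), b_dec)]
    return features, recon

def sae_loss(x, features, recon, l1_coeff):
    """Mean squared reconstruction error plus an L1 penalty on the features."""
    mse = sum((a - b) ** 2 for a, b in zip(x, recon)) / len(x)
    l1 = sum(abs(f) for f in features)
    return mse + l1_coeff * l1

# Toy 2-dimensional "activation" and a 4-feature dictionary (hand-picked).
x = [1.0, -2.0]
w_enc = [[1.0, 0.0], [-1.0, 0.0], [0.0, 1.0], [0.0, -1.0]]
b_enc = [0.0, 0.0, 0.0, 0.0]
w_dec = [[1.0, -1.0, 0.0, 0.0], [0.0, 0.0, 1.0, -1.0]]
b_dec = [0.0, 0.0]

features, recon = sae_forward(x, w_enc, b_enc, w_dec, b_dec)
loss = sae_loss(x, features, recon, l1_coeff=0.01)
print(features, recon, loss)
```

Here only two of the four features are active and the input is reconstructed exactly, so the whole loss comes from the sparsity penalty.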

Intermediate Results 

We succeeded in training sparse autoencoders to analyze the features corresponding to particular equation types. Our analysis of ODEFormer using sparse autoencoders revealed several interesting and unexpected findings about how the model processes differential equations:

  • Non-interpretable feature activations: Despite our initial hypothesis, plotting top feature activations along trajectories did not yield clearly interpretable patterns. The features extracted by the SAE did not correspond to obvious mathematical concepts like derivatives or specific dynamics patterns as we had anticipated.
  • Lack of structured feature organization: The t-SNE visualizations of the latent feature space showed no clear structure organized by equations, features, timesteps, activation values, or dynamical system properties. This suggests that the representation scheme used by ODEFormer is more complex and distributed than we had expected.
  • Feature stability across system variations: Top activating features (for example, features 971, 786, and 595) remained remarkably stable not only under significant parameter changes but also across entirely different dynamical system types. This unexpected stability suggests these features may capture fundamental properties that transcend specific equation forms.
  • Boundary-defining features: These stable features appeared to be active on different sides of the manifold in the t-SNE visualization, essentially "bounding it with itself." This indicates they may play a structural role in organizing the representation space rather than encoding specific mathematical properties.
  • Resilience to feature modification: The ablation and modification experiments revealed that ODEFormer's solution generation remained stable even under severe modifications to activations in both residual and MLP layers. This resilience suggests that critical processing may occur elsewhere in the model architecture.
  • Processing allocation hypothesis: Our findings point to the possibility that the actual processing of dynamical information happens primarily in attention layers or in the decoder layer, which we had not yet studied with SAE.

 

Logit Lens & Attention Lens

  1. Logit Lens (nostalgebraist, 2020) is a method that directly accesses each layer's outputs and allows us to study what the model predicts at each processing step. It is based on the insight that the hidden vectors passed between the intermediate layers of a model have the same dimension as the output vectors, and applying the same unembedding to these intermediate vectors, we obtain the predictions made by these layers. Thus, we are able to 'see-through' the model's operations and study how the output is formed across the layers. We can for example see how much each layer contributes to the final output, how the predictions change across the layers, what the top predicted tokens at any stage are as well as how likely the model considers them to be.
  2. Attention Lens refers to the method where we plot the attention patterns of various attention heads. Based on these patterns, we experiment with systems that have predetermined mathematical aspects/qualities, hoping to find some representation of those aspects/qualities in the heads simply by observing the attention patterns. We also experiment with “value-weighted” attention patterns and observe some interesting phenomena that remain unexplained. Basically, this method lets us see how or what the model is “thinking” when figuring out the symbolic expression of a certain n-dimensional trajectory.
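
The Logit Lens idea can be illustrated on a toy example: take the hidden vector after each layer, apply the same unembedding matrix to it, and read off the top token. The vocabulary, the unembedding matrix and the hidden states below are made up for illustration, and the final normalization layer that real implementations usually apply before unembedding is omitted for simplicity.

```python
import math

def softmax(logits):
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]

def logit_lens(hidden_states, unembed, vocab):
    """Apply the shared unembedding to every layer's hidden vector and
    return the (top token, probability) pair for each layer."""
    readout = []
    for h in hidden_states:
        logits = [sum(w * x for w, x in zip(row, h)) for row in unembed]
        probs = softmax(logits)
        k = max(range(len(vocab)), key=probs.__getitem__)
        readout.append((vocab[k], probs[k]))
    return readout

# Hypothetical 3-token vocabulary and 2-dimensional hidden states.
vocab = ["+", "-", "<EOS>"]
unembed = [[1.0, 0.0], [0.0, 1.0], [-1.0, -1.0]]        # one row per token
hidden_states = [[0.2, 0.1], [0.5, 1.0], [0.1, 3.0]]     # one vector per layer

readout = logit_lens(hidden_states, unembed, vocab)
for layer, (token, prob) in enumerate(readout):
    print(layer, token, round(prob, 3))
```

In this toy readout the early layer weakly prefers one token and later layers switch to another with growing confidence, which is the kind of layer-by-layer evolution the token charts display.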

We developed some key functions shared in this library that allowed us to directly plot token charts, and obtain intermediate tokens/logits, and also attentions (value-weighted as well as the usual ones). 

Intermediate Results

  • Sign Prediction in Simple 1D Systems

We have several interesting discoveries related to using Logit Lens on a sign prediction example. We consider a one-dimensional decreasing exponential function, and show the results using Logit Lens in the figure below (Fig. LALens1).

Fig. LALens1: A compact chart showing token evolution across layers and beams.

The figure shows how tokens and the corresponding confidence scores vary across beams and decoder layers. For this experiment, we used a simple 1-dimensional decreasing exponential system, primarily because we had observed earlier that this system results in a simple 6-token output. We also opted for a low beam size (2) and temperature (0.1) for simplicity. A general observation is that in the initial layers, across all beams, the model still “predicts” the previous token, as a result of the autoregressive nature of the model. However, this does not seem to hold when predicting a token that follows a constant. In the subsequent layers the probability is spread across several tokens, which is why we see a more uniformly blue distribution, and in the last few layers the model becomes more confident in the right token. However, in the token_3 subplot, the prediction of the constant does not become very confident even in the last few layers, unlike in the other subplots. This is expected, because these constants are of the order of 10^-4 and small changes in their magnitude do not affect the overall expression significantly.

Next, we zoom in a bit on the sign prediction. We have observed that for both decreasing and increasing exponential systems, decoder layer 6 predicts ‘+’, which is corrected to ‘-’ for a decreasing system only in decoder layer 8. This is a bias towards ‘+’, presumably because most numbers in the training data were positive. We have also examined the attention patterns in the various heads and found some heads that capture the direction from maximum to minimum magnitude, or vice versa, in the trajectory, which we believe is key in determining the sign.

Fig. LALens2: The second row from the top shows interesting patterns.

In Fig. LALens2, we can see that the rows corresponding to the value 1 on the y-axes show that pattern of minimum to maximum or vice versa. Row 1 represents the token ‘+’ or ‘-’. These three patterns are from the attention heads of layers in which the ‘+’ or ‘-’ was being predicted for the first time. The single dominant column of attention will be discussed later. We have also looked at the logits of the ‘+’ and ‘-’ tokens and their evolution through the layers.

Fig. LALens3: Left one is for the increasing system, while the right one is for the decreasing system.

There is no immediately visible difference between the two plots. However, ‘+’ is predicted from layer 6 onwards, and logit 389 also becomes more activated from layer 6 onwards (shown as 7 in the plots because numbering begins from 1). This suggests that most of the information about the sign is carried by logit 389 alone.

Now, let’s discuss the attention patterns in more detail. Broadly, we found clear evidence of three things: null attentions, or attention sinks, exist in the ODEFormer; there is a point that is mostly attended to but does not seem special at first glance; and there are attention heads that capture various mathematical qualities of the input trajectory that help the model predict the symbolic expression.

  • N1010 as “Default” Token 

During our experiments, we have seen (oftentimes with higher beam sizes) that after we get an “<EOS>” token in one beam, but other beams have not yet yielded an “<EOS>” token, the beam gives an “add” token or another “<EOS>” token as the next token. If all other beams have still not yielded “<EOS>”, then the beam that has finished predicting an expression will continue to give the “N1010” constant token as a kind of “default” token. For example, consider the following three plots in Fig. LALens4 below for predicting three consecutive tokens. 

Fig. LALens4: N1010 is predicted after expression prediction ends in some beams.

In the figure’s first subplot, we can see that beams 1-4 and 7 and 9 gave the “<EOS>” token as output, but other beams did not. Due to this, when the other beams carry on with the next token prediction, these 6 beams either give the “<EOS>” or “add” token as output. And if the other beams still have not finished, these 6 beams continue giving the “N1010” token as a “default” token till all other beams give the “<EOS>” token hence completing the process of token generation, after which post-processing takes place.

  • Existence of Null Attentions

We have observed that attention sinks are present in the attention plots. These are input trajectory points that are attended to more than other points, but that do not transfer much information/value, as seen from the value-weighted attention plots. What is value-weighted attention? Taken from “A Mathematical Framework for Transformer Circuits” (Elhage et al.), this is a method where we scale the attention in the plots by how much information is transferred, which is obtained from the value vectors. Value-weighted attention shows how much value is being transferred by attending to the input points. So, if a certain point is heavily attended to by all or most tokens (or by the points themselves, in the case of encoder self-attention) in the usual attention plots but not in the value-weighted ones, that point is not contributing much to the prediction despite being heavily attended to. We call such a point an attention sink.

Even more interesting, however, is the discovery of a different class of points, which we call MAT points (mostly-attended-to points).
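
One common way to compute value-weighted attention is to scale each attention weight by the norm of the value vector at the attended position. The sketch below uses that choice on made-up numbers; whether our plots use exactly this scaling is an assumption of the example, and the helper names are hypothetical.

```python
import math

def value_weighted_attention(attn, values):
    """Scale attn[query][source] by the L2 norm of the source's value vector."""
    norms = [math.sqrt(sum(x * x for x in v)) for v in values]
    return [[a * n for a, n in zip(row, norms)] for row in attn]

def column_totals(matrix):
    """Total weight each source position receives across all queries."""
    return [sum(row[j] for row in matrix) for j in range(len(matrix[0]))]

def argmax(values):
    return max(range(len(values)), key=values.__getitem__)

# Made-up attention pattern (rows: queries, columns: source points).
attn = [
    [0.8, 0.1, 0.1],
    [0.7, 0.2, 0.1],
    [0.6, 0.1, 0.3],
]
# Made-up value vectors: source 0 carries almost no value.
values = [[0.01, 0.0], [1.0, 1.0], [0.0, 2.0]]

raw_totals = column_totals(attn)
weighted_totals = column_totals(value_weighted_attention(attn, values))
print(argmax(raw_totals), argmax(weighted_totals))
```

In this toy pattern source 0 receives the most raw attention but almost none after weighting, which is the signature of an attention sink as defined above.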

  • The MAT Saga

These are points that are mostly attended to by all (or most) tokens, as seen in the decoder’s value-weighted cross-attention plots. The fact that these points light up almost as strongly in the value-weighted plot as in the normal ones shows that these points are indeed transferring information that is needed by the model (specifically the decoder layers) to correctly predict the symbolic expression. For example, let’s consider the attention plot in Fig. LALens5. In the top subplot, we can clearly see that there are a few input points/columns that are more lit up than the others. These points are also transferring information, so the attention on them is not completely useless.

Fig. LALens5: The above one is the normalized summation (a combined view) of all 16 heads.

Lighting up is fun, but now you might be wondering what the one defining mathematical quality of the MAT point/column is. We plotted the norm of all of the points of the harmonic sine-cosine system, and from the plot in Fig. LALens6 we see that the MAT point has a very low norm.

Fig. LALens6: MAT column has a lower norm value.

Let’s also take an example now to show the difference between value-weighted attentions and normal attentions. In Fig. LALens7, we clearly see more information transferred by other points but the MAT point is still clearly visible. The attention plots in the figure correspond to the system in the previous figure.

   

Fig. LALens7: The bottom one is the value-weighted attention plot.

We initially observed that a few points, or most often just one point, were mostly attended to by all input points or tokens, but which point this was appeared random. The ODEFormer randomly subsamples input points when the original input is very long. Since we were experimenting with very short trajectories, we did not need this randomisation, and on turning it off we observed that the MAT point is constant. What do we mean? Take ‘x’ input points from various systems across dimensionalities and feed them to the ODEFormer to get the predicted symbolic expression. In all cases, the MAT point remains the same. Change the number of input points to ‘y’ and the MAT status will shift. This shift is also predictable according to what we have seen in our experiments: roughly the 60% point of the input trajectory gets the MAT status. So if we have 25 input points in the trajectory, point 15 becomes the MAT point.

One might think that this points to some internal split of the trajectory in the ratio 60:40, with the MAT point acting as the attention sink of the latter 40% of the trajectory, since sinks are mostly the first point. This is an appealing idea, but in some cases we observe no abrupt change in the attention pattern before and after the MAT point. Instead, attention increases gradually as we approach the MAT point from the left, and then decreases in the same gradual manner once we move past it. If the hypothesis of some kind of internal training and testing “split” were true, we should not see that gradual attention build-up around the MAT point. This pattern is only visible in the value-weighted plots, suggesting that even though points around the MAT point are not attended to as much, they do transfer helpful information.

Additionally, we conducted a few tests to get more insights on the nature of the MAT point. Please note that the attention in the following plots are not value-weighted to clearly differentiate the MAT point.

Fig. LALens8: The MAT status seems to depend on a certain time step and not trajectory value.

Before the experiment, we believed that the MAT status was due to the value of the point in the trajectory. However, when we modified the trajectory to have several different values at the same time step (i.e. made the slope infinite for some portion of the input trajectory), the model attended to all of those points more than to the others. As seen in the second subplot of Fig. LALens8, the model tries to give the MAT status to all of the points at that one particular time step. However, it does not fully do so, and instead we see a gradual decay in the attention the points receive as we move right. Now, let’s move to the third experiment and subplot, where we changed the MAT point’s value to something entirely different. In this case, some other point gets the MAT status. But wait a second! It is what we call the second MAT point that gets the MAT status after we shift the original MAT point by some amount. What is the second MAT point, you ask? Looking back at the first subplot, we see not only the MAT point receiving higher attention than the others, but also two more points that receive somewhat more attention than the rest. We call these the MAT candidates. As we saw, changing the MAT point’s value sufficiently results in one of the candidates getting the MAT status instead. A natural question is by how much the MAT point must be shifted.

Fig. LALens9: The MAT status seems to depend on a certain time step and not trajectory value.

Let’s take a look at Fig. LALens9. Changing the MAT point’s value by a magnitude of approximately 0.05 or less retains most of the usual cross-attention and does not shift the MAT status to a candidate. Changing it by more than 0.05 in magnitude, however, produces a significant dip in the averaged cross-attention of the MAT point, and the MAT status is given to some other candidate point.

One last thing that we observed in some value-weighted attention heads, in several systems, is that there exist heads in which the MAT point is attended to less than the other points. These heads mostly capture some mathematical aspect of the input trajectory, but even when the MAT point should logically receive some attention (perhaps it is a maximum or minimum or something else), it does not.

  • Attention Heads capture Mathematical Features

Just as low-level and high-level features develop in the attention heads of vision transformers, we observe a similar phenomenon in the ODEFormer. The ODEFormer too has specialized attention heads that capture some mathematical aspects/qualities/features of the n-dimensional input system. We are going to look at the self-attention of the encoder blocks only, since it has input points on both axes and we are interested in seeing what data is “encoded” into the trajectory in the encoder layers, before the decoder layers engage in cross-attention while predicting tokens. We treat the ODEFormer as a multi-modal transformer, because it takes input in the trajectory space and gives output in a well-defined token space.

Now, let’s take a closer look at some of the features in some systems.

 

Some context first: In Fig. LALens10, Sigmoid_a_b means the input trajectory was that of a sigmoid following the equation: 

When changing the values of a and b, i.e. changing the position and steepness of the sigmoid curve, we noticed that one head, encoder layer 2 head 7, seemed to pick up on differentiation, i.e. the portion of the curve where the slope is nonzero. Besides this, we also noticed that encoder layer 0 head 8 seems, at first glance, to trace the input sigmoid trajectory itself. On changing the sigmoid curve to go from 0 to 1 (instead of 1 to 0 as in the figure), the tracing head no longer traces the trajectory perfectly, but the pattern does shift with shifts in the transition portion.
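To make the "portion of the curve where the slope is nonzero" concrete, here is a minimal sketch that builds a decreasing sigmoid trajectory, differentiates it numerically, and marks the transition region. The logistic form, the parameter names `center` and `steepness`, and the 5% threshold are assumptions chosen for illustration; they are not necessarily the parametrization behind the Sigmoid_a_b trajectories in the figure.

```python
import numpy as np

def sigmoid_trajectory(t, center, steepness):
    """Decreasing logistic curve going from 1 to 0 (assumed form)."""
    return 1.0 / (1.0 + np.exp(steepness * (t - center)))

# Hypothetical sampling grid and parameters.
t = np.linspace(0.0, 10.0, 50)
x = sigmoid_trajectory(t, center=5.0, steepness=3.0)

# Finite-difference estimate of dx/dt on the sampling grid.
dx = np.gradient(x, t)

# Points where the slope is noticeably nonzero (5% of peak is an arbitrary cutoff).
transition = np.abs(dx) > 0.05 * np.abs(dx).max()
print("transition spans t in", (t[transition].min(), t[transition].max()))
```

A head that picks up on differentiation would be expected to concentrate its attention on the positions where `transition` is true, so a mask like this could be compared against the attention received per position.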

In Sigmoid_3_20 you might notice that the prediction doesn’t really look like a sigmoid curve at all. This is true, and we have found that the ODEFormer can only predict the sigmoid curve accurately up to a certain steepness. Beyond that, the ODEFormer derails quickly and badly, suggesting that the transformer is unable to handle quick, non-periodic transitions in the input trajectory. Though the model managed to capture significant sigmoid-like features in the case of Sigmoid_3_20, it failed completely in the case of Sigmoid_8_25 (which is much steeper). In that case there are no meaningful patterns in the attention heads, just noise.

While differentiation is a high-level concept, tracing is a relatively low-level, low-effort concept, and we hypothesize that this is why a head in layer 0 traces the trajectory while differentiation is captured by a head in layer 2. In support of this, we also found heads in layer 0 that capture the maximum and minimum of a trajectory. These findings might not hold for all systems, but from what we have observed they hold for some simple systems.

While observing the cross-attentions after feeding the ODEFormer the trajectory in Fig. LALens6, we observed a few more attention heads that seemed to capture the minima and maxima.

Fig. LALens11: Heads capturing maxima and minima with/without dimension separation.

The left attention head in Fig. LALens11, decoder layer 5 head 12, captures the minima and maxima but shows no separation of dimensions, even though the input trajectory belongs to a 2D system. On the right, however, decoder layer 6 head 4 captures the minima and maxima with the dimension separation. The pattern in the left head is rare and only appears before we see the pattern in the right head, around the middle layers. The pattern with the dimension separation is more common in these kinds of systems.

  • Self-Inhibiting Behaviour of the Encoders

We plotted the OV matrix as obtained by multiplying the output weight matrix with the value weight matrix for each encoder layer, and the results are in Fig. LALens12. The most striking feature in all four encoder layers is the presence of a leading diagonal that is significantly more negative than the other values. In addition to this, the diagonal seems to become more and more negative as we move from layer 0 to the final layer of the encoder.

Fig. LALens12: Presence of a significantly more negative diagonal in the OV matrices.

This coincides with the formation of the MAT point, which begins forming from the second encoder layer and only becomes more pronounced as we go towards the final encoder layers. We think this negative diagonal is some form of self-inhibition, something like decreasing its own effect, and we have yet to ascertain the importance of this self-inhibiting behaviour in the formation of the MAT point (if any).

This strikingly negative leading diagonal is present in the OV plots of the encoder layers, across different systems.
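For readers who want to reproduce this kind of plot, the sketch below shows one way to form an OV matrix and compare its diagonal to its off-diagonal entries. The weights are random toy matrices constructed to have a negative diagonal, not ODEFormer weights. We also assume the convention that weights are stored as (out_features, in_features) and act as y = W x, so the composite map is `w_o @ w_v`; the helper names are hypothetical.

```python
import numpy as np

def ov_matrix(w_o, w_v):
    """Composite output-value map, assuming y = W_O (W_V x)."""
    return w_o @ w_v

def diagonal_summary(ov):
    """Mean of the diagonal entries and mean of the off-diagonal entries."""
    d = ov.shape[0]
    off_diagonal = ov[~np.eye(d, dtype=bool)]
    return float(np.diag(ov).mean()), float(off_diagonal.mean())

# Toy weights (NOT from the ODEFormer): built so that OV is roughly -0.5 * I.
d_model = 16
rng = np.random.default_rng(0)
w_v = np.eye(d_model) + 0.02 * rng.standard_normal((d_model, d_model))
w_o = -0.5 * np.eye(d_model) + 0.02 * rng.standard_normal((d_model, d_model))

ov = ov_matrix(w_o, w_v)
diag_mean, off_mean = diagonal_summary(ov)
print(f"diagonal mean: {diag_mean:.3f}, off-diagonal mean: {off_mean:.3f}")
```

Applied per encoder layer, tracking `diag_mean` against `off_mean` gives a simple numerical counterpart to the visual impression that the diagonal grows more negative with depth.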
