Crossposted from the AI Alignment Forum. May contain more technical jargon than usual.

As our project at the Melbourne hackathon on Singular Learning Theory and alignment (Oct. 7-8), we did some experiments to estimate the learning coefficient of the single-layer modular addition task at a basin, an invariant that measures the information complexity (read: program length) of a fully trained neural net. 

We used the recent paper of Lau, Murfet, and Wei as our starting point; this paper provides a stochastic estimate for the learning coefficient (which they denote λ̂) via Langevin dynamics. The thermodynamic quantity measured by λ̂ is proven to asymptotically converge to the learning coefficient for idealized singular systems in a beautiful paper by Watanabe.

All code for the experiments described can be found in this GitHub repository.

Brief results

In our tests, we were pleasantly surprised to find that, for the task of modular addition modulo a prime p, the outputs of our implementation of Lau et al.'s SGLD methods are (up to a roughly constant small multiplicative error, less than a factor of 2) in robust agreement with the theoretically predicted results for an idealized single-layer modular addition network[1].

Similar results have to date been obtained for a two-neuron network by Lau et al. and for a 12-neuron network by Chen et al. Our results are the first confirmation for medium-sized networks (between about 500 and 8,000 neurons) of the agreement between the λ̂ estimate and the theoretical results.

While our results are off by a small multiplicative factor from the theoretical value for a single modular addition circuit, we discover a phenomenon that matches the theoretical predictions remarkably well: the learning coefficient estimate λ̂ is linear in p for modular addition networks that generalize. This is the first precise scaling result of its kind.

Mean measurements of λ̂ on modular addition networks for 5 different values of p: the degree of linearity is remarkable.

In addition, using the modular addition task as a test case lets us closely investigate the ability of the λ̂ complexity estimate to differentiate generalization and memorization in neural networks: something that seems to be mostly new (though related to some of the phase transition phenomena in Chen et al.). We observe that while generalization has a learning coefficient linear in the prime p, memorization has (roughly) quadratic growth in p; this again exhibits remarkable agreement with theory.

Our λ̂ measurements for memorization-only networks. These are in remarkable agreement with the theoretically predicted values (here, theory predicts growth on the order of 0.8p², the number of memorized training samples).

The agreement with theory holds for multiple different values of the prime p and multiple architectures. The estimates also behave appropriately for networks that learn different numbers of circuits, a situation where other estimators of effective dimension, such as the Hessian-eigenvalue estimate, tend to overestimate complexity.

Additionally, we show that the dynamic λ̂ estimate[2], i.e., the estimate during training, seems to track the memorization vs. generalization stages of learning (this despite the fact that the λ̂ estimate depends only on the training data). To see this, we use a slight refinement of the dynamical estimator, in which we restrict sampling to lie within the hyperplane normal to the gradient vector at initialization; this seems to make the behavior more robust.

Chart of estimated λ̂ over training for an MLP trained on modular addition mod 53. Checkpoints were taken every 60 batches of batch size 64. Hyperparameters for SGLD are . The search was restricted to directions orthogonal to the gradient at the initialization point to correct for measurement at non-minima.

Our dynamic results parallel some of the SGLD findings in Chen et al., which show that dynamic SGLD computations can sometimes notice phase transitions. We were pleasantly surprised to see them hold in larger networks and in the context of memorization vs. generalization. 

Overall, our findings update us to put more credence in the real-world applicability of Singular Learning Theory techniques and ideas. More concretely, we now believe that techniques similar to Lau et al.'s SGLD sampling should be able to distinguish different generalization behaviors in industry-scale neural networks and can be a part of a somewhat robust toolbox of unsupervised interpretability and control techniques valuable for alignment. 

Background

Basics about the learning coefficient λ

For an alternative introduction, see Jesse and Stan's excellent post explaining the learning coefficient (published after we had written this section but following a similar approach).

The learning coefficient is a parameter associated with generalization. It controls the first-order asymptotic behavior of the question "How likely is it that a random choice of weights produces loss within ε of the optimum?" In other words, how easy is it to hit the optimal solution to within ε accuracy. As ε goes to zero, this probability goes to zero polynomially, as a power of ε, so

P(loss within ε of the optimum) ∼ ε^λ.

(The ∼ instead of = is there for technical reasons.)
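This volume-scaling behavior is easy to see numerically in a toy setting. The sketch below is our own illustration (not from any of the cited papers): it compares a singular loss L(w) = (w₁w₂)², whose learning coefficient is λ = 1/2, against a smooth two-dimensional quadratic loss, whose exponent is the usual d/2 = 1, by Monte Carlo estimation of the volume of near-optimal weights.

```python
import numpy as np

rng = np.random.default_rng(0)
w = rng.uniform(-1, 1, size=(1_000_000, 2))

# Singular loss: minima form the cross {w1 = 0} ∪ {w2 = 0}; its RLCT is 1/2.
L_sing = (w[:, 0] * w[:, 1]) ** 2
# Smooth comparison loss: isolated minimum at the origin; exponent is d/2 = 1.
L_smooth = (w ** 2).sum(axis=1)

eps = np.logspace(-3, -1, 10)
vol_sing = np.array([(L_sing < e).mean() for e in eps])
vol_smooth = np.array([(L_smooth < e).mean() for e in eps])

# Slope of log P(L < eps) against log eps estimates the exponent.  For the
# singular loss the fitted slope lands somewhat below 1/2 because of the
# logarithmic correction factor; the smooth loss gives a clean slope of 1.
slope_sing = np.polyfit(np.log(eps), np.log(vol_sing), 1)[0]
slope_smooth = np.polyfit(np.log(eps), np.log(vol_smooth), 1)[0]
print(slope_sing, slope_smooth)
```

The point of the comparison is that the singular loss has strictly smaller exponent than its parameter count would suggest, which is exactly the "effective dimension" phenomenon discussed below.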

Such a term (usually defined in terms of the free energy; here T = 1/β is the temperature) occurs more generally in statistical physics (and has close cousins in quantum field theory) as the leading exponent in the "perturbative expansion." In the context of neural nets, the exponent λ is called the learning coefficient or the RLCT ("real log canonical threshold," a term from algebraic geometry).

The learning coefficient contains "dimension-like" information about a learning problem and can be understood as a measure of the effective dimension or "true dimensionality," i.e., the true number of weight parameters that need to be "guessed correctly" for a neural net to solve the problem with minimal loss. In particular, if a neural net is expanded by including redundant parameters that don't affect the set of algorithms that can be learned (e.g., because of symmetries of the problem), it can be shown that the learning coefficient does not change. Note that if the solution set to a machine learning problem is sufficiently singular (something we will not encounter in this post), the learning coefficient can be larger than the actual dimension of the set of minima[3] and can indeed be a non-integer.

The Watanabe-Lau-Murfet-Wei estimate λ̂

In fact, the learning coefficient, defined as a true asymptote, only contains nontrivial information for singular networks: idealized systems that never appear in real life (just as it is not possible for two iterations of a noisy algorithm to give the exact same answer, so it is not possible for a network with any randomness to have a singular minimum or a positive-dimensional collection of minima). However, at finite but small values of temperature (i.e., of the loss "sensitivity" ε as above), the problem of computing the associated free energy (and hence getting a meaningful generalization-relevant parameter at a "finite level of granularity") is tractable.

The paper of Watanabe that Lau et al. follow gives a formula of this type. The result of that paper depends not only on the loss-sensitivity parameter (called β, after the inverse temperature in the statistical physics literature) but also on n, the number of samples. The formula gives an asymptotically precise estimate for the learning coefficient of the neural network on the "true" data distribution, corresponding to the limit as the number of samples n goes to infinity. As n goes to infinity, Watanabe takes the temperature parameter β to zero as β = 1/log n. Lau et al.'s paper sets out to perform this measurement at finite values of n.
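To make the shape of this estimator concrete, here is a minimal toy sketch (ours, not the authors' implementation) for a regular one-parameter model, where the true learning coefficient is d/2 = 0.5. The recipe: sample from the tempered posterior ∝ exp(−nβL_n(w)) with SGLD at Watanabe's choice β = 1/log n, then set λ̂ = nβ(E[L_n] − L_n(w*)). All names and hyperparameters here are illustrative choices.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 1000
x = rng.normal(size=n)
y = rng.normal(size=n)                 # y independent of x, so the true weight is ~0

def L(w):                              # empirical loss L_n(w) for the model y = w*x
    return np.mean((y - w * x) ** 2)

w_star = np.dot(x, y) / np.dot(x, x)   # minimizer of L_n
beta = 1 / np.log(n)                   # Watanabe's inverse-temperature choice

# SGLD chain targeting exp(-n * beta * L_n(w)), started at the minimizer.
grad = lambda w: 2 * (w * np.mean(x * x) - np.mean(x * y))
eps, w, losses = 1e-3, w_star, []
for t in range(5000):
    w += -(eps / 2) * n * beta * grad(w) + np.sqrt(eps) * rng.normal()
    if t >= 1000:                      # discard burn-in
        losses.append(L(w))

# Watanabe-style estimate: n * beta * (expected loss under sampling - minimal loss)
lam_hat = n * beta * (np.mean(losses) - L(w_star))
print(lam_hat)                         # lands near the theoretical lambda = 0.5
```

In a regular (non-singular) model like this one, λ̂ simply recovers d/2; the interesting behavior discussed in this post is what the same recipe returns for singular models, where λ̂ can be much smaller than d/2.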

Having a good estimator for the learning coefficient can be extremely valuable for interpretability: this would be a parameter that captures the information-theoretic complexity of an algorithm in a very principled way that avoids serious drawbacks of previously known approaches (such as estimates of Hessian degeneracy) and can be useful for out-of-distribution detection. More generally, the Singular Learning Theory program proposes certain powerful unsupervised interpretability tools that can give information about network internals, assuming the learning coefficient (and certain related quantities) can be computed efficiently.

Modular addition as a testbed for estimating λ̂

In Lau et al.'s paper, their SGLD-based learning coefficient estimate is applied to a tiny two-neuron network and also to an MNIST network, with promising results. We treat the modular addition network as an interesting intermediate case. Modular addition recommends itself for the following reasons:

  • It is a mechanistically interpreted network: we know its circuits, more or less how they are implemented by neurons, and how to isolate and measure them.
  • We can cleanly distinguish networks that learn to generalize vs. networks that only memorize by looking at their circuits; moreover, we can "spoof" generalization by creating a network for learning a random commutative operation; this is a network that has the same memorization behavior as modular addition, but no possibility of generalization.
  • We can count the number of generalization circuits a network learns and reason about how different circuits interact in the loss function and in somewhat idealized free energy computations. This allows us to compare the behavior of λ̂ with respect to the number of circuits against other notions of complexity, for example, Hessian rank.
Plot of loss on a held-out test set for a network trained on modular addition. Each step (differently colored line) corresponds to an evenly spaced checkpoint along the training. Each point corresponds to loss when all Fourier modes in the embedding weights matrix are ablated except this one. Keeping only a single important mode impacts loss much less than keeping only an unimportant mode, demonstrating the use of "grokked" Fourier modes in the embeddings matrix.

At the same time, being an algorithmically generated problem, modular addition has some important limitations from the point of view of SLT, which makes it unable to capture some of the complexity of a typical learning problem:

  • The total number of possible data points for modular addition is finite (namely, equal to p² for p the prime modulus), and the target distribution is deterministic. Thus, the learning coefficient only depends on a finite number of samples, which makes the asymptotic problem slightly (but not entirely) degenerate from the point of view of statistical learning theory.
  • Even within the class of simple deterministic machine learning problems, the modular addition problem is highly symmetric; thus, it is possible for our empirical results to fail to generalize for less symmetric networks.
  • The high number of possible output tokens compared to the maximal number of samples (p tokens compared to p² samples, for p the modulus) may cause unusual behavior (Watanabe's results assume that the number of logits is small and the number of samples is asymptotically infinite).

Despite these limitations, we observed that (for an appropriate choice of hyperparameters) the Watanabe-Lau-Murfet-Wei estimate λ̂ gives an estimate of the learning coefficient largely compatible with theoretical predictions. In addition, the estimates behave in a remarkably consistent and stable way, which we did not expect.

Findings 

We found that, for fully trained networks, SGLD estimates using Watanabe's formula give a good approximation (up to a small factor) of the theoretical estimate for the RLCT, both for modular addition (linear in p, reasonably independent of the total number of parameters) and for the random operation network (quadratic in p). Moreover, the estimate is independent of the number of atomic circuits, or "groks" (something we expect, in an appropriate limiting case, to be the case for the learning coefficient but not for other computations of effective dimension).

Diagram of the model we trained on the modular addition task. p-dimensional one-hot encoded numbers are embedded in an embed_dim-dimensional space. Two independent linear transformations are learned into a hidden_dim-dimensional space. The two vectors are then added elementwise, passed through a GELU activation function, and then transformed back into a p-dimensional vector of logits.
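The architecture in the diagram can be sketched in a few lines of numpy. This is our own minimal reconstruction of the description above; the dimensions (embed_dim = 32, hidden_dim = 64) and the `forward` name are illustrative assumptions, not our exact training hyperparameters.

```python
import numpy as np

def gelu(x):
    # tanh approximation of GELU
    return 0.5 * x * (1 + np.tanh(np.sqrt(2 / np.pi) * (x + 0.044715 * x ** 3)))

p, embed_dim, hidden_dim = 53, 32, 64
rng = np.random.default_rng(0)
E  = rng.normal(0, 0.1, (p, embed_dim))           # embedding of one-hot inputs
Wa = rng.normal(0, 0.1, (embed_dim, hidden_dim))  # independent linear map for input a
Wb = rng.normal(0, 0.1, (embed_dim, hidden_dim))  # independent linear map for input b
U  = rng.normal(0, 0.1, (hidden_dim, p))          # unembedding back to p logits

def forward(a, b):
    h = gelu(E[a] @ Wa + E[b] @ Wb)   # elementwise sum of the two branches, then GELU
    return h @ U                      # p-dimensional vector of logits

print(forward(3, 7).shape)            # (53,)
```

Training (not shown) minimizes cross-entropy between these logits and the label (a + b) mod p.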

We also ran some "dynamical" estimates of λ̂ at unstable points along the learning trajectory of our modular addition networks. Here we observed that the λ̂ estimates closely correlate with the validation (i.e., test) loss, despite the fact that they are computed using methods involving only the training data. In particular, these unstable measurements "notice" the grokking transition between memorization and generalization, when training loss stabilizes and test loss goes down.

Scaling behavior for generalizing networks

We ran the Watanabe-Lau-Murfet-Wei λ̂-estimator algorithm on the following networks and obtained the following results. We graph the λ̂ estimate against each prime, averaged over five experiments.

We found that estimates using Watanabe's formula gave a good approximation (up to a small factor) of the theoretical estimate for the RLCT, both for the modular addition and for the random network: 

Plot of λ̂ estimated using SGLD for MLPs trained on modular addition mod different primes p. Here, the λ̂ shown is averaged over five independent training and sampling runs. Hyperparameters for SGLD are . We can see that λ̂ scales linearly with p.
λ̂ vs. p estimated for single runs of differently sized MLP networks, demonstrating similarity in RLCT across scales.
Here the different runs correspond to separately trained series of networks, demonstrating that λ̂ is consistent across models with the same architecture trained to convergence on the same dataset and task.

We observe that at a given architecture, our λ̂ estimates are very close to linear in p, as theoretically predicted.

In principle, the minimal effective dimensionality of a model with this architecture that solves modular addition grows linearly in p (this will be elaborated on in a separate theory post deriving results about modular addition networks). However, we observe that the empirical scaling factor is very close to double the theoretical value for a single circuit. A possible explanation for this result could be that, in the regime our models inhabit, the effective space of solutions consists of weight parameters that execute at least two simple circuits (all models we trained learned at least 4 simple circuits).

The scaling factor of λ̂ with p is close to double the single-circuit value. This could indicate that the SGLD search procedure explores a manifold of near-minima corresponding to two grokked circuits rather than one. Note that all trained models grokked more than 2 circuits altogether and varied in their number of circuits, so we still find invariance to the total number of independent circuits learned.

When starting the experiment, we were expecting differences of more than an order of magnitude between the empirical and predicted values (because of the non-ideal nature of the real-life models and limiting points in our experiments). This degree of agreement between a relatively large and messy "real-world" measurement and an ideal measurement, as well as the near-linearity here, are by no means guaranteed, and they updated us significantly towards believing that the theoretical predictions of Singular Learning Theory match real-world measurements well.

We also repeated the experiment at various architectures, with numbers of parameters differing by a relatively large factor (our largest network is more than 3 times larger than our smallest network, and our intermediate network is roughly twice as big as the smallest one). Larger networks do have slightly higher λ̂, but the difference scales sub-linearly in network size, as we would expect from the true learning coefficient.

Note that the primes we include are relatively small. While our architectures are efficient and always generalize (with close to 100% accuracy) for much larger primes, we empirically observe that the estimates for λ̂ tend to be much better and less noisy when the fully trained network is very close to convergence (near-zero loss). Because of computational limitations, we use a relatively large learning rate (0.01) for a relatively small number of iterations. This results in worse loss at convergence for primes above 50; we conjecture that the near-linear behavior would continue to hold for much larger primes if we used more computationally intensive methods with a smaller learning rate and a larger number of SGD steps.

(In)dependence on the number of circuits

The networks we train sometimes learn different numbers of independent generalizing circuits embedded in different subspaces (the existence of such circuits was first proposed by Nanda et al.).

We can measure the number and types of circuits learned by a network, either by considering large outlier Fourier modes in the embedding space or (more robustly) by looking for near-perfect circles in "Fourier mode-aligned" two-dimensional projections of the embedding space[4], as in the picture below.
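The first diagnostic (outlier Fourier modes in the embedding) can be sketched as follows. This is our own illustration of the idea: we project the p × d embedding matrix onto the sin/cos pair for each frequency k and flag modes with outlier norm. Here `W_embed` is random noise rather than a trained embedding, so no mode should stand out; for a grokked network, the learned frequencies would dominate.

```python
import numpy as np

p, d = 53, 32
rng = np.random.default_rng(0)
W_embed = rng.normal(size=(p, d))    # stand-in for a learned p x d embedding matrix

ks = np.arange(1, (p - 1) // 2 + 1)                     # the (p-1)/2 Fourier frequencies
theta = 2 * np.pi * np.outer(ks, np.arange(p)) / p

# Squared norm of the embedding's projection onto each mode's 2-D (cos, sin) subspace
mode_power = np.array([
    np.linalg.norm(np.cos(t) @ W_embed) ** 2 + np.linalg.norm(np.sin(t) @ W_embed) ** 2
    for t in theta
])

# Frequencies whose power is a large outlier relative to the mean: candidate circuits
dominant = ks[mode_power > 2 * mode_power.mean()]
print(len(ks), dominant)
```

For a real trained embedding, the second, more robust diagnostic would additionally check that the projection into each dominant 2-D subspace traces out a near-perfect circle.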

(We plan to publish another post later, on mechanistic interpretability tools for modular addition, in particular on exactly distinguishing "pizza" from "clock" circuits, where these pictures will be explained in more detail.)

We can see the number of independent Fourier mode circuits by projecting the learned embedding weights matrix onto the (p−1)/2 two-dimensional subspaces of embedding space corresponding to the different discrete Fourier modes representable in the embedding space. For example, this model for p = 53 has learned 6 Fourier modes.

We observe in our experiments that the learning coefficient estimates do not seem to depend much on the number of circuits learned. For example, for the largest prime we considered, p = 53, the number of circuits learned in different runs varied between 4 and 7 circular circuits, whereas the learning coefficient estimates for all the networks were within about 10% of each other. This result is deceptively simple but quite interesting and somewhat surprising from a theoretical viewpoint.

For example, when measuring the effective dimension of a network via Hessian eigenvalues, a network with more than one circuit will have either effective dimension 0 (because going along a direction corresponding to any circuit counts as generalizing) or effective dimension that depends linearly on the number of circuits (because a direction counts as generalizing only if it independently generalizes each of the circuits). The fact that neither of these behaviors is observed in our context can be motivated by the Singular Learning Theory framework. Indeed, we can treat the subspace in weight space executing each circuit (or perhaps a suitable small subset of circuits) as a separate component of a singular manifold of "near-minima." As the vector spaces associated to the different circuits are in general position relative to each other, the resulting singularity is "minimally singular"[5]. This would mean that the RLCT at the singular point is equal to the RLCT along each of the individual components, which can be understood as an explanation for the observed independence result. However, we note that despite its explanatory robustness, this picture becomes more complicated when we zoom in since the loss for a multi-circuit network tends to be significantly better than the product of its parts. 

We plan to give an alternative explanation for the independence result involving a statistical model for cross-entropy loss that takes advantage of the ergodicity of multiplication modulo a prime. We flag here that we expect this independence to hold only in a "goldilocks" range of hyperparameter choices and, in particular, of the regularization constant (corresponding to the sizes of the circuits learned). A simplistic statistical model predicts at least three distinct phases here: one at very small circuit size (corresponding to large regularization), where we expect the number of circuits to multiplicatively impact the learning coefficient; one at large circuit sizes (small regularization), where the learning coefficient estimate becomes degenerate; and one in an intermediate region, where the independence result we see is in effect.

Random operations: scaling for memorization vs. generalization

To compare our generalizing networks to networks with the same architecture that only memorize, we ran the Watanabe-Lau-Murfet-Wei algorithm for a random commutative operation network.

In order to get good loss for a memorization network, we need it to be overparametrized, i.e., the number of parameters needs to be above some appropriate multiple of the total number of samples, in our case 0.8p². Because the number of parameters grows linearly in p, we get convergence to near-zero loss only for small values of p. We note that since number-theoretic tricks like the Chinese Remainder Theorem are irrelevant for random operation networks, the values of p for this experiment do not need to be prime. Thus we ran this experiment for multiples of 5 up to 40. Because of convergence issues, we most trust our results in the short range of values between 5 and 25.

Note that this range overlaps with our list of primes only between 23 and 25; we would need larger networks (and probably better learning convergence) to get reasonable values of λ̂ above this range. For the range of values we consider, we observe a larger learning coefficient with a quadratic scaling pattern in p, compared to the linear scaling for generalizing networks.

λ̂ vs. p for a random commutative operation.
λ̂ vs. p for a random commutative operation plotted alongside modular addition results.

Remarkably, the diagram up to p = 25 is almost exactly (up to a constant offset) equal to the number of memorizations, 0.8p²; here 0.8 is the fraction of the full dataset used for training. We also generated data for larger multiples of 5, up to 40. Here we see clearly that the memorizing network has a higher learning coefficient than the generalizing network at the same architecture, but the quadratic fit becomes worse for larger values of p. We believe that we would recover the quadratic fit for more values of p if we worked with a larger network.
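The linear-vs-quadratic distinction can be read off from a log-log fit of λ̂ against p. The sketch below uses synthetic stand-in numbers (not our measured data) purely to illustrate the fitting procedure: the slope of log λ̂ against log p recovers the growth exponent.

```python
import numpy as np

ps = np.array([5, 10, 15, 20, 25])
lam_gen = 3.0 * ps          # synthetic stand-in for linear (generalizing) scaling
lam_mem = 0.8 * ps ** 2     # synthetic stand-in for quadratic (memorizing) scaling

# Growth exponent = slope of log(lambda-hat) vs. log(p)
exponent = lambda lam: np.polyfit(np.log(ps), np.log(lam), 1)[0]
print(round(exponent(lam_gen)), round(exponent(lam_mem)))  # 1 2
```

Applied to real measurements, an exponent near 1 indicates generalization-like scaling and an exponent near 2 indicates memorization-like scaling.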

Dynamics and phase transition

Finally, we performed a dynamic estimate of the learning coefficient at various checkpoints during the learning process for generalizing networks. 

In this part of our results, we introduced some innovations to the methods of Lau et al. and Chen et al. (though we did not implement the "health-based" sampling trajectory sorting from the latter paper). Specifically, we got the best results with a temperature adjustment and with our implementation of unstable SGLD applied after restricting to the hyperplane normal to the loss gradient.
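The gradient-restriction refinement can be sketched as follows. This is our own minimal illustration (the quadratic toy loss, the tempering-free step, and all names are stand-ins): at a checkpoint that is not a local minimum, every SGLD update is projected onto the hyperplane orthogonal to the loss gradient at the start of sampling, so that the checkpoint behaves as a critical point within the sampled subspace.

```python
import numpy as np

rng = np.random.default_rng(0)
dim = 10
A = rng.normal(size=(dim, dim))
A = A @ A.T / dim                              # toy positive semi-definite Hessian
grad_L = lambda w: A @ w + 0.3 * np.ones(dim)  # nonzero gradient: w0 is not a minimum

w0 = np.zeros(dim)                  # checkpoint where sampling starts
g0 = grad_L(w0)
n_hat = g0 / np.linalg.norm(g0)     # fixed normal direction, set at initialization

eps, w = 1e-2, w0.copy()
for _ in range(1000):
    step = -(eps / 2) * grad_L(w) + np.sqrt(eps) * rng.normal(size=dim)
    step -= np.dot(step, n_hat) * n_hat        # remove the component along g0
    w += step

# The sampler never moves along n_hat, so it stays in the hyperplane through w0.
print(abs(np.dot(w - w0, n_hat)))              # ~0 up to floating-point error
```

The intent (see also the comment thread below) is to prevent the sampler from immediately sliding off the unstable checkpoint into a lower-loss basin along the gradient direction.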

A run of λ̂ estimation at 25 equally spaced checkpoints along model training. SGLD search was modified to restrict search directions to those orthogonal to the gradient at initialization.
A repeated run with another model and independent SGLD sampling, to check consistency of results.

Here we observed that the unstable λ̂ estimates closely correlate with the validation (i.e., test) loss, despite the fact that they are computed using methods involving only the training data. In particular, these unstable measurements "notice" the grokking transition between memorization and generalization, when training loss stabilizes and test loss goes down. (As our networks are quite efficient, this happens relatively early in training.)

  1. ^

    Note that Lau et al. also undertake an estimate of λ̂ for a large MNIST network with over a million parameters. Here they find that the resulting value of λ̂ is correlated with the optimization method used to train the network in a predictable direction and thus captures nontrivial information about the basin. However, the theoretical value of λ is not available here, and the SGLD algorithm fails to converge; thus, this estimate is not expected to give a faithful value of the learning coefficient in this case.

  2. ^

    Note that the dynamic λ̂ estimator attempts to apply a technique designed for stable points (i.e., local minima) to points that are not local minima and have some instability, sampling, and ergodicity issues, even with our normal-to-gradient restriction refinement. In particular, these estimates (much more than estimates at stable points) are sensitive to hyperparameters. Thus these unstable λ̂ measurements do not currently have an associated exact theoretical value and can be thought of as an ad hoc generalization of a complexity estimate to unstable points. Nevertheless, we find that at a fixed collection of hyperparameters, these estimates give consistent results and look similar across runs, and we see that they contain nontrivial information about the loss landscape dynamics during learning.

  3. ^

    An intuition for this is that very singular loss functions (i.e., functions that have many higher-order derivatives equal to zero) are associated with very large basins, which are large enough to "fit in extra dimensions worth of parameters."

  4. ^

    The two-dimensional subspace of the embedding space associated with the k-th Fourier mode is the space spanned by the sin and cos components of the frequency-k discrete Fourier transform. Note that these spaces are not necessarily linearly independent for different modes but are independent for modes that learn a circuit.

  5. ^

    This is meant in an RLCT sense. In algebraic geometry language, a function f on weight space W is minimally singular if there exists a smooth analytic blowup X → W such that, in local coordinates on X, f is a product of squares of coordinate functions. In this language, if we have c circuits associated to vector subspaces V_1, ..., V_c in weight space, an "idealized" function with minima on k-tuples of circuits is the function

    f(w) = ∏_{|S| = k} ∑_{i ∈ S} d_i(w)²,

    for S running over k-element subsets of {1, ..., c} and d_i the L2 distance from a weight to the corresponding subspace. It is easy to check that the resulting singularity is minimally singular.

4 comments

To see this, we use a slight refinement of the dynamical estimator, where we restrict sampling to lie within the normal hyperplane of the gradient vector at initialization, which seems to make this behavior more robust.

 

Could you explain the intuition behind using the gradient vector at initialization? Is this based on some understanding of the global training dynamics of this particular network on this dataset?

Oh, I can see how this could be confusing. We're sampling at every step in the orthogonal complement to the gradient at initialization ("initialization" here refers to the beginning of sampling, i.e., we don't update the normal vector during sampling). And the reason to do this is that we're hoping to prevent the sampler from quickly leaving the unstable point and jumping into a lower-loss basin (by restricting, we are guaranteeing that the unstable point is a critical point).

Oh that makes a lot of sense, yes.

I'm curious if you have guesses about how many singular dimensions were dead neurons (or neurons that are "mostly dead," only activating for a tiny fraction of the training set), versus how much the zero-gradient directions depended dynamically on training example.