Crossposted from the AI Alignment Forum. May contain more technical jargon than usual.

Co-authored by Neel Nanda and Jess Smith

Check out Concrete Steps for Getting Started in Mechanistic Interpretability for a better starting point

Why does this exist?

People often get intimidated when trying to get into AI or AI Alignment research. People often think that the gulf between where they are and where they need to be is huge. This presents practical concerns for people trying to change fields: we all have limited time and energy. And for the most part, people wildly overestimate the actual core skills required.

This guide is our take on the essential skills required to understand, write code and ideally contribute useful research to mechanistic interpretability. We hope that it’s useful and unintimidating. :) 

Core Skills:

  • Maths:
    • Linear Algebra: 3Blue1Brown or Linear Algebra Done Right
      • Core goals - to deeply & intuitively understand these concepts:
        • Basis
        • Change of basis
        • That a vector space is a geometric object that doesn’t necessarily have a canonical basis
        • That a matrix is a linear map between two vector spaces (or from a vector space to itself)
      • Bonus things that it’s useful to understand: 
        • What’s singular value decomposition? Why is it useful?
        • What are orthogonal/orthonormal matrices, and how is changing to an orthonormal basis importantly different from just any change of basis?
        • What are eigenvalues and eigenvectors, and what do these tell you about a linear map?
    • Probability basics
      • Basics of distributions: expected value, standard deviation, normal distributions
      • Log likelihood
      • Maximum value estimators
      • Random variables
      • Central limit theorem
    • Calculus basics
      • Gradients
      • The chain rule
      • The intuition for what backprop is - in particular, grokking the idea that backprop is just the chain rule on multivariate functions
  • Coding:
    • Python Basics
      • The “how to learn coding” market is pretty saturated - there’s a lot of good stuff out there! And not really a clear best one.
      • Zac Hatfield-Dodds recommends Al Sweigart's Automate the Boring Stuff and then Beyond the Basic Stuff (both readable for free on inventwithpython.com, or purchasable in books); he's also written some books of exercises. If you prefer a more traditional textbook, Think Python 2e is excellent and also available freely online.
    • NumPy Basics
  • ML:
    • Rough grounding in ML. 
    • PyTorch basics
      • Don’t go overboard here. You’ll pick up what you need over time - learning to google things when you get confused or stuck is most of the real skill in programming. 
      • One goal: build linear regression that runs in Google Colab on a GPU. 
      • The main way you will shoot yourself in the foot with PyTorch is when manipulating tensors, and especially multiplying them. I highly, highly recommend learning how to use einops (a library to nicely do any reasonable manipulation of a single tensor) and einsum (a built in torch function implementing Einstein Summation notation, to do arbitrary tensor multiplication)
        • If you try doing these things without einops and einsum you will hurt yourself. Do not recommend!
  • Transformers - probably the biggest way mechanistic interpretability differs from normal ML is that it’s really important to deeply understand the architectures of the models you use, all of the moving parts inside of them, and how they fit together. In this case, the main architecture that matters is a transformer! (This is useful in normal ML too, but you can often get away with treating the model as a black box)
    • My what is a transformer and implementing GPT-2 From Scratch video tutorials
    • A worthwhile exercise is to fill out the template notebook, accompanying the tutorial (no copying and pasting!) 
      • The notebook comes with tests so you know that your code is working, and by the end of this you'll have a working implementation of GPT-2!
      • If you can do this, I’d say that you basically fully understand the transformer architecture.
    • Alternate framings that may help give different intuitions: 
      • Nelson Elhage’s Transformers for Software Engineers (also useful to non software engineers!)
      • Check out the illustrated transformer 
        • Note that you can pretty much ignore the stuff on encoder vs decoder transformers - we mostly care about autoregressive decoder-only transformers like GPT-2, which means that each token can only see tokens before it, and they learn to predict the next token
  • Bonus: Jacob Hilton’s Deep learning for Alignment syllabus - this is a lot more content than you strictly need, but is well put together and likely a good use of time to go through at least some of!
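
To make the einsum recommendation in the list above concrete, here is a minimal sketch of an attention-score-style tensor multiplication. It uses NumPy's `np.einsum` as a stand-in (an assumption, chosen so the snippet runs without a GPU stack); `torch.einsum` accepts the same subscript strings. The axis names and sizes (`batch`, `pos`, `d_head`) are illustrative, not taken from any particular model.

```python
import numpy as np

rng = np.random.default_rng(0)
batch, pos, d_head = 2, 5, 4  # illustrative sizes (assumptions)

# Queries and keys: one d_head-dimensional vector per (batch, position).
q = rng.normal(size=(batch, pos, d_head))
k = rng.normal(size=(batch, pos, d_head))

# Every axis gets a letter; the letter h is missing from the output, so it is summed over:
# scores[b, i, j] = sum_h q[b, i, h] * k[b, j, h]
scores = np.einsum("bih,bjh->bij", q, k)

# The same computation with an explicit transpose, which is easier to get wrong.
scores_manual = q @ k.transpose(0, 2, 1)

print(scores.shape)
print(np.allclose(scores, scores_manual))
```

The point of the subscript string is that the axis bookkeeping lives in one readable place, rather than being spread across transposes and reshapes.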

Once you have the pre-reqs, my Getting Started in Mechanistic Interpretability guide goes into how to get further into mechanistic interpretability!

Note that there are a lot more skills in the “nice-to-haves”, but I think that generally the best way to improve at something is by getting your hands dirty and engaging with the research ideas directly, rather than making sure you learn every nice-to-have skill first. If you have the above, I think you should just jump in and start learning about the topic! Especially for the coding-related skills, your focus should not be on getting your head around concepts; it should be on doing, actually writing code, and playing around with things. The challenge of making something that actually works, and dealing with all of the unexpected practical problems that arise, is the best way of really getting this.


For Python basics, I have to anti-recommend Shaw's 'learn the hard way'; it's generally outdated and in some places actively misleading. And why would you want to learn the hard way instead of the best way in any case?

Instead, my standard recommendation is Al Sweigart's Automate the Boring Stuff and then Beyond the Basic Stuff (both readable for free on inventwithpython.com, or purchasable in books); he's also written some books of exercises. If you prefer a more traditional textbook, Think Python 2e is excellent and also available freely online.

Thanks! I learned Python ~10 years ago and have no idea what sources are any good lol. I've edited the post with your recs :)

Strong agree with not using "the hard way." Not a big fan.

Automate the Boring Stuff isn't bad. But I think the best, direct-path/deliberate practice resource for general Python capability is Pybites (codechalleng.es) + Anki. There are 300-400 exercises (with a browser interpreter) that require you to solve problems without them handholding or instructing you (so it's like LeetCode but for actual programming capability, not just algorithm puzzles). Everything requires you to read documentation or Google effectively, and they're all decent, realistic use cases. There are also specific paths (like testing, decorators & context management, etc.) that are covered in the exercises.

My procedure:

  • Finish at least the Easy and at least half the Mediums for each relevant path.
  • If you learned something new, put it into Anki (I put the whole Pybite into a card and make a cloze-completion for the relevant code--I then have to type it for reviews).
  • Finish the rest of the Pybites (which will be an unordered mix of topics that includes the remaining Mediums and Hards for each of the learning paths, plus miscellaneous).
  • IMO, you will now actually be a solid low-intermediate Python generalist programmer, though of course you will need to learn lots of library/specialty-specific stuff in your own area.
Yulu Pi

hey Neel,

Great post!

I am trying to look into the code here

But the links don't work anymore! It would be nice if you could help update them!

I don't know if this link works for the original content: https://colab.research.google.com/github/neelnanda-io/Easy-Transformer/blob/clean-transformer-demo/Clean_Transformer_Demo_Template.ipynb

 

Thanks a lot!

Ah, thanks! Haven't looked at this post in a while, updated it a bit. I've since made my own transformer tutorial which (in my extremely biased opinion) is better, esp for interpretability. It comes with a template notebook to fill out alongside part 2 (with tests!), and by the end you'll have implemented your own GPT-2.

More generally, my getting started in mech interp guide is a better place to start than this guide, and has more on transformers!

Thanks for putting this together Neel, I think you achieved your goal of making it fairly unintimidating. 

One quick note: all of the links in this section are outdated. Perhaps you can update them.

Good (but hard) exercise: Code your own tiny GPT-2 and train it. If you can do this, I’d say that you basically fully understand the transformer architecture.

mic

Thanks for writing this! Here is a quick explanation of all the math concepts – mostly written by ChatGPT with some manual edits.

A basis for a vector space is a set of linearly independent vectors that can be used to represent any vector in the space as a linear combination of those basis vectors. For example, in two-dimensional Euclidean space, the standard basis is the set of vectors (1, 0) and (0, 1), which are called the "basis vectors."

A change of basis is the process of expressing a vector in one basis in terms of another basis. For example, if we have a vector v in two-dimensional Euclidean space and we want to express it in terms of the standard basis, we can write v as a linear combination of (1, 0) and (0, 1). Alternatively, we could choose a different basis for the space, such as the basis formed by the vectors (4, 2) and (3, 5). In this case, we would express v in terms of this new basis by writing it as a linear combination of (4, 2) and (3, 5).
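
As a rough illustration of the change of basis described above, the sketch below finds the coordinates of a vector in the basis (4, 2), (3, 5). The vector `v` is a made-up example, and solving a linear system with `np.linalg.solve` is just one reasonable way to do this.

```python
import numpy as np

# Columns of B are the new basis vectors (4, 2) and (3, 5) from the example above.
B = np.array([[4.0, 3.0],
              [2.0, 5.0]])

# Hypothetical vector, written in the standard basis.
v = np.array([10.0, 12.0])

# The coordinates c in the new basis satisfy B @ c = v,
# i.e. v = c[0] * (4, 2) + c[1] * (3, 5).
c = np.linalg.solve(B, v)

print(c)      # coordinates of v in the new basis: (1, 2) up to floating-point rounding
print(B @ c)  # multiplying back recovers v in the standard basis
```

Here v = 1·(4, 2) + 2·(3, 5), so the same geometric vector has coordinates (10, 12) in the standard basis and (1, 2) in the new one.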

A vector space is a set of vectors that can be added together and multiplied ("scaled") by numbers, called scalars. Scalars are often taken to be real numbers, but there are also vector spaces with scalar multiplication by complex numbers, rational numbers, or generally any field. The operations of vector addition and scalar multiplication must satisfy certain requirements, called axioms. Examples of vector spaces include the set of all two-dimensional vectors (i.e., the set of all points in two-dimensional Euclidean space), the set of all polynomials with real coefficients, and the set of all continuous functions from a given set to the real numbers. A vector space can be thought of as a geometric object, but it does not necessarily have a canonical basis, meaning that there is not a preferred set of basis vectors that can be used to represent all the vectors in the space.

A matrix is a rectangular array of numbers, symbols, or expressions, arranged in rows and columns. A matrix represents a linear map between two vector spaces, or from a vector space to itself (once a basis is chosen for each space), because it can take any vector in the original vector space and transform it into a new vector in the target vector space using a set of linear equations. Each column of the matrix is the image of one of the basis vectors of the original space, and these columns define the transformation. In the expression $Mv$, we scale each column of the matrix by the corresponding element of the vector $v$, and then add these scaled columns together to create the new vector.
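
A small sketch of the "matrix as linear map" picture, using a made-up 3×2 matrix: applying the matrix to a standard basis vector picks out a column, and applying it to any vector gives a weighted sum of the columns. All the numbers here are illustrative.

```python
import numpy as np

# A hypothetical linear map from R^2 to R^3.
M = np.array([[2.0, 0.0],
              [1.0, 3.0],
              [0.0, 1.0]])

e1 = np.array([1.0, 0.0])
e2 = np.array([0.0, 1.0])
v = np.array([5.0, -2.0])

# Applying M to a standard basis vector returns the matching column.
image_e1 = M @ e1
image_e2 = M @ e2

# Applying M to v is the same as a weighted sum of the columns.
Mv = M @ v
column_combination = v[0] * M[:, 0] + v[1] * M[:, 1]

print(image_e1, image_e2)
print(Mv, column_combination)
```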

The singular value decomposition (SVD) is a factorization of a matrix M into the product of three matrices: $M = U S V^T$, where U and V are orthogonal matrices and S is a diagonal matrix with non-negative real numbers on the diagonal, called the "singular values" of M. The SVD is a useful tool for understanding the properties of a matrix and for solving certain types of linear systems. It can also be used for data compression, image processing, and other applications.
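
To see the factorization in action, here is a minimal sketch using `np.linalg.svd` on a random matrix. It uses the "thin" SVD (`full_matrices=False`), in which U has orthonormal columns rather than being square; the matrix itself is arbitrary.

```python
import numpy as np

rng = np.random.default_rng(0)
M = rng.normal(size=(4, 3))  # arbitrary example matrix

# Thin SVD: U is 4x3, S holds the 3 singular values, Vt is V transposed (3x3).
U, S, Vt = np.linalg.svd(M, full_matrices=False)

# Multiplying the three factors back together recovers M.
M_rebuilt = U @ np.diag(S) @ Vt

# Keeping only the largest singular value gives the best rank-1 approximation.
M_rank1 = S[0] * np.outer(U[:, 0], Vt[0])

# In the spectral norm, the error of that approximation is the next singular value.
rank1_error = np.linalg.norm(M - M_rank1, 2)

print(S)
print(np.allclose(M, M_rebuilt))
print(rank1_error, S[1])
```

The low-rank approximation step is what makes SVD useful for compression: a few large singular values often capture most of what the matrix does.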

An orthogonal matrix (or orthonormal matrix) is a square matrix whose columns and rows are mutually orthonormal (i.e., they are orthogonal and have unit length). Orthogonal matrices have the property that their inverse is equal to their transpose.

Changing to an orthonormal basis can be importantly different from just any change of basis because it has certain computational advantages. For example, when working with an orthonormal basis, the inner product of two vectors can be computed simply as the sum of the products of their corresponding components, without the need to use any weights or scaling factors. This can make certain calculations, such as finding the length of a vector or the angle between two vectors, simpler and more efficient.

Eigenvalues and eigenvectors are special types of scalars and vectors that are associated with a linear map or a matrix. If M is a linear map or matrix and v is a non-zero vector, then v is an eigenvector of M if there exists a scalar λ, called an eigenvalue, such that $Mv = \lambda v$. In other words, when a vector is multiplied by the matrix M, the resulting vector is a scalar multiple of the original vector. Eigenvalues and eigenvectors are important because they provide insight into the properties of the linear map or matrix. For example, the eigenvalues of a matrix can tell us whether it is singular (i.e., not invertible) or whether it is diagonalizable (i.e., can be expressed in the form $M = P D P^{-1}$, where P is an invertible matrix and D is a diagonal matrix). The eigenvectors of a matrix can also be used to determine its rank, nullity, and other characteristics.
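
A quick numerical check of the eigenvalue equation and of diagonalization, using a small symmetric matrix chosen for illustration. `np.linalg.eigh` is used because the example matrix is symmetric; a general square matrix would need `np.linalg.eig` instead.

```python
import numpy as np

# Illustrative symmetric matrix; its eigenvalues are 1 and 3.
M = np.array([[2.0, 1.0],
              [1.0, 2.0]])

# eigh returns eigenvalues in ascending order; columns of P are the eigenvectors.
eigvals, P = np.linalg.eigh(M)

# Check M v = lambda v for every eigenpair at once (column j is scaled by eigvals[j]).
lhs = M @ P
rhs = P * eigvals

# Diagonalization: M = P D P^{-1}.
D = np.diag(eigvals)
M_rebuilt = P @ D @ np.linalg.inv(P)

# No eigenvalue is zero, so M is invertible (non-singular).
is_singular = bool(np.any(np.isclose(eigvals, 0.0)))

print(eigvals)
print(np.allclose(lhs, rhs), np.allclose(M, M_rebuilt), is_singular)
```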

Probability basics: Probability is a measure of the likelihood of an event occurring. It is typically represented as a number between 0 and 1, where 0 indicates that the event is impossible and 1 indicates that the event is certain to occur. When all outcomes are equally likely, the probability of an event can be calculated by counting the number of ways in which the event can occur, divided by the total number of possible outcomes.

Basics of distributions: A distribution is a function that describes the probability of a random variable taking on different values. The expected value of a distribution is a measure of the center of the distribution, and it is calculated as the weighted average of the possible values of the random variable, where the weights are the probabilities of each value occurring. The standard deviation is a measure of the dispersion of the distribution, and it is calculated as the square root of the variance, which is the expected value of the squared deviation of a random variable from its mean. A normal distribution (or Gaussian distribution) is a continuous probability distribution with a bell-shaped curve, which is defined by its mean and standard deviation.

Log likelihood: The log likelihood of a statistical model is a measure of how well the model fits a given set of data. It is calculated as the logarithm of the probability of the data given the model, and it is often used to compare the relative fit of different models.
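
As a sketch of how log likelihood compares models, the snippet below scores some made-up observations under two Gaussian models that differ only in their mean. The data, the helper name `gaussian_log_likelihood`, and the choice of a Gaussian model are all assumptions for illustration.

```python
import math

def gaussian_log_likelihood(data, mean, std):
    """Sum of log N(x | mean, std^2) over data points assumed independent."""
    return sum(
        -0.5 * math.log(2 * math.pi * std**2) - (x - mean) ** 2 / (2 * std**2)
        for x in data
    )

# Made-up observations clustered around 5.
data = [4.8, 5.1, 5.3, 4.9, 5.0]

ll_good = gaussian_log_likelihood(data, mean=5.0, std=0.2)
ll_bad = gaussian_log_likelihood(data, mean=3.0, std=0.2)

# The model whose mean is near the data gets the higher (less negative) score.
print(ll_good, ll_bad)
```

Working with logs turns a product of many small probabilities into a sum, which is much better behaved numerically.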

Maximum value estimators: A maximum value estimator is a statistical method that is used to estimate the value of a parameter that maximizes a given objective function. Examples of maximum value estimators include the maximum likelihood estimator and the maximum a posteriori estimator.

  • The maximum likelihood estimator is a method for estimating the parameters of a statistical model based on the principle that the parameters that maximize the likelihood of the data are the most likely to have generated the data.
  • The maximum a posteriori (MAP) estimator is a method for estimating the parameters of a statistical model by choosing the parameters that maximize the posterior probability of the parameters given the data. The posterior probability is proportional to the likelihood of the data given the parameters multiplied by the prior probability of the parameters, so it combines the evidence from the data with prior knowledge about the parameters. The MAP estimator is often used in Bayesian inference, and it is a popular method for estimating the parameters of a model in the presence of prior knowledge.
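
To make the two estimators above concrete, here is a minimal sketch for a coin with unknown heads probability p. The counts (7 heads in 10 tosses) and the Beta(2, 2) prior are made-up assumptions, and the grid search is only there to show that the closed-form answers really are the maximizers.

```python
import math

def log_likelihood(p, heads, tosses):
    # Log probability of the observed tosses given heads probability p.
    return heads * math.log(p) + (tosses - heads) * math.log(1 - p)

def log_posterior(p, heads, tosses, a, b):
    # Unnormalised log posterior: log likelihood plus log of a Beta(a, b) prior.
    log_prior = (a - 1) * math.log(p) + (b - 1) * math.log(1 - p)
    return log_likelihood(p, heads, tosses) + log_prior

heads, tosses = 7, 10  # hypothetical data
a, b = 2.0, 2.0        # hypothetical prior, mildly favouring p near 0.5

# Brute-force search over candidate values of p.
grid = [i / 1000 for i in range(1, 1000)]
p_mle = max(grid, key=lambda p: log_likelihood(p, heads, tosses))
p_map = max(grid, key=lambda p: log_posterior(p, heads, tosses, a, b))

# Closed-form answers for this model.
p_mle_closed = heads / tosses
p_map_closed = (heads + a - 1) / (tosses + a + b - 2)

print(p_mle, p_mle_closed)
print(p_map, p_map_closed)
```

The prior pulls the MAP estimate from 0.7 towards 0.5; with more data the two estimates move closer together.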

Random variables: A random variable is a variable whose value is determined by the outcome of a random event. For example, the toss of a coin is a random event, and the number of heads that result from a series of coin tosses is a random variable.

Central limit theorem: The central limit theorem is a statistical theorem that states that, as the sample size increases, the distribution of the mean of independent, identically distributed samples (suitably centered and scaled) approaches a normal distribution, regardless of the shape of the underlying distribution, provided that distribution has finite variance.
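
A quick simulation can make the central limit theorem tangible. This sketch draws many samples from a skewed exponential distribution (an arbitrary choice) and checks that the sample means behave roughly like a normal distribution with standard deviation σ/√n. The sample size and number of trials are illustrative.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 50           # size of each sample
trials = 20_000  # number of sample means to collect

# Exponential(1) is heavily skewed, with mean 1 and standard deviation 1.
samples = rng.exponential(scale=1.0, size=(trials, n))
sample_means = samples.mean(axis=1)

# The theorem predicts the means cluster around 1 with std 1 / sqrt(n).
predicted_std = 1.0 / np.sqrt(n)

# For a normal distribution, about 68% of values lie within one std of the mean.
frac_within_1sd = np.mean(np.abs(sample_means - 1.0) < predicted_std)

print(sample_means.mean(), sample_means.std(), predicted_std)
print(frac_within_1sd)
```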

Calculus basics: Calculus is a branch of mathematics that deals with the study of rates of change and the accumulation of quantities. It is a fundamental tool in the study of functions and is used to model and solve problems in a variety of fields, including physics, engineering, and economics.

Gradients: In calculus, the gradient of a (scalar-valued multivariate differentiable) function is a vector that points in the direction in which the function is increasing most quickly, with magnitude equal to the rate of increase in that direction. It is calculated as the vector of partial derivatives of the function with respect to each variable.
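
As a rough check on the definition above, this sketch compares a hand-derived gradient with a finite-difference estimate. The function f(x, y) = x² + 3xy and the helper names are made up for illustration.

```python
import numpy as np

def f(x):
    # Example function f(x, y) = x^2 + 3xy.
    return x[0] ** 2 + 3 * x[0] * x[1]

def analytic_grad(x):
    # Partial derivatives worked out by hand: (2x + 3y, 3x).
    return np.array([2 * x[0] + 3 * x[1], 3 * x[0]])

def numeric_grad(func, x, eps=1e-6):
    # Central differences: nudge one coordinate at a time and measure the change.
    grad = np.zeros_like(x)
    for i in range(len(x)):
        step = np.zeros_like(x)
        step[i] = eps
        grad[i] = (func(x + step) - func(x - step)) / (2 * eps)
    return grad

x = np.array([1.0, 2.0])
print(analytic_grad(x))
print(numeric_grad(f, x))
```

Comparing against finite differences like this is a common sanity check when deriving gradients by hand.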

The chain rule: The chain rule is a fundamental rule of calculus that allows us to calculate the derivative of a composite function. It states that if f is a function of g, and g is a function of x, then the derivative of f with respect to x is equal to the derivative of f with respect to g times the derivative of g with respect to x. In other words, (df / dx) = (df / dg) * (dg / dx).

On backpropagation:

Backpropagation is an algorithm for training artificial neural networks, which are machine learning models inspired by the structure and function of the brain. It is used to adjust the weights and biases of the network in order to minimize the error between the predicted output and the desired output of the network.

The idea behind backpropagation is that, given a multivariate function that describes the relationships between the input variables and the output variables of a neural network, we can use the chain rule to calculate the gradient of the function with respect to the weights and biases of the network. The gradient tells us how the error changes as we adjust the weights and biases, and we can use this information to update the weights and biases in a way that reduces the error.

To understand why backpropagation is just the chain rule on multivariate functions, it's helpful to consider the structure of a neural network. A neural network consists of layers of interconnected nodes, each of which performs a calculation based on the inputs it receives from the previous layer. The output of the network is a function of the inputs, and the weights and biases of the network determine how the inputs are transformed as they pass through the layers of the network.

The process of backpropagation involves starting at the output layer of the network and working backwards through the layers, using the chain rule to calculate the gradients of the weights and biases at each layer. This is done by calculating the derivative of the error with respect to the output of each layer, and then using the chain rule to propagate these derivatives back through the layers of the network. This allows us to calculate the gradients of the weights and biases at each layer, which we can use to update the weights and biases in a way that minimizes the error.

Overall, backpropagation is an efficient and effective way to train neural networks because it allows us to calculate the gradients of the weights and biases efficiently, using the chain rule to propagate the derivatives through the layers of the network. This enables us to adjust the weights and biases in a way that minimizes the error, which is essential for the effective operation of the network.
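
The description above can be sketched as a hand-written backward pass for a tiny two-layer network, checked against a finite-difference estimate. The architecture (tanh hidden layer, squared error), the sizes, and the variable names are all assumptions chosen to keep the example small; in practice a framework like PyTorch does this bookkeeping automatically.

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=3)        # input
t = 0.5                       # target output
W1 = rng.normal(size=(4, 3))  # first-layer weights
b1 = rng.normal(size=4)       # first-layer biases
w2 = rng.normal(size=4)       # second-layer weights
b2 = 0.1                      # second-layer bias

def loss_fn(W1, b1, w2, b2):
    h = np.tanh(W1 @ x + b1)
    y = w2 @ h + b2
    return 0.5 * (y - t) ** 2

# Forward pass, keeping the intermediate values needed for the backward pass.
z = W1 @ x + b1
h = np.tanh(z)
y = w2 @ h + b2

# Backward pass: apply the chain rule one layer at a time, from output to input.
dL_dy = y - t                 # derivative of 0.5 * (y - t)^2
dL_dw2 = dL_dy * h            # y = w2 . h + b2
dL_db2 = dL_dy
dL_dh = dL_dy * w2
dL_dz = dL_dh * (1 - h ** 2)  # derivative of tanh(z) is 1 - tanh(z)^2
dL_dW1 = np.outer(dL_dz, x)   # z = W1 x + b1
dL_db1 = dL_dz

# Finite-difference check on a single first-layer weight.
eps = 1e-6
W1_plus = W1.copy()
W1_plus[2, 1] += eps
W1_minus = W1.copy()
W1_minus[2, 1] -= eps
numeric = (loss_fn(W1_plus, b1, w2, b2) - loss_fn(W1_minus, b1, w2, b2)) / (2 * eps)

print(dL_dW1[2, 1], numeric)
```

Each backward line reuses the derivative computed on the line before it, which is why one backward pass costs about as much as one forward pass.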

Thanks for writing this!

Have you considered crossposting to the EA Forum (although the post was mentioned here)?

how is changing to an orthonormal basis importantly different from just any change of basis?

What exactly do you have in mind here?

Oh, just that it preserves distance/L2 norm/angles/orthogonality. I find that this is often an important intuition, since a bunch of operations in transformers depend on orthogonality/norm. In particular, norm is useful for thinking about weight decay and as rough heuristics for eg info transferred.

Probably most important is that whenever a layer tries to read a feature from the residual stream, it's basically projecting the residual stream onto a single dimension, which hopefully corresponds to that feature, and, importantly, ignoring all orthogonal dimensions. The projection operation is invariant under rotations (orthonormal changes of basis) but NOT under arbitrary changes of basis.
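
A small numerical sketch of that last point, with random vectors standing in for a residual stream vector and a feature direction (both hypothetical): the dot-product "read" is unchanged when both vectors are rotated by the same orthogonal matrix, but changes under a non-orthogonal change of basis.

```python
import numpy as np

rng = np.random.default_rng(0)
d = 4
x = rng.normal(size=d)  # stand-in for a residual stream vector
w = rng.normal(size=d)  # stand-in for the direction a layer reads from

# A random orthogonal matrix, obtained from a QR decomposition.
Q, _ = np.linalg.qr(rng.normal(size=(d, d)))

# An invertible but non-orthogonal change of basis: stretch the first axis.
A = np.diag([3.0, 1.0, 1.0, 1.0])

read_original = w @ x
read_rotated = (Q @ w) @ (Q @ x)
read_stretched = (A @ w) @ (A @ x)

print(read_original, read_rotated, read_stretched)
print(np.linalg.norm(x), np.linalg.norm(Q @ x), np.linalg.norm(A @ x))
```

The rotated read matches the original and the norm of x is preserved, while the stretched versions differ.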