Ophiology (or, how the Mamba architecture works)


The following post was made as part of Danielle's MATS work on doing circuit-based mech interp on Mamba, mentored by Adrià Garriga-Alonso. It's the first in a sequence of posts about finding an IOI circuit in Mamba/applying ACDC to Mamba.

This introductory post was also made in collaboration with Gonçalo Paulo.

## A new challenger arrives!

Why Mamba?

## Promising Scaling

Mamba^{[1]} is a type of recurrent neural network based on state-space models, proposed as an alternative architecture to transformers. It is the result of years of capability research^{[2]}^{[3]}^{[4]}, and is likely not the final iteration of architectures based on state-space models. In its current form, Mamba has been scaled up to 2.8B parameters on The Pile and on SlimPajama, with scaling laws similar to those of Llama-like architectures.

Scaling curves from the Mamba paper: Mamba scaling compared to Llama (Transformer++), previous state space models (H3++), convolutions (Hyena), and a transformer-inspired RNN (RWKV)

More recently, AI21 Labs^{[5]} trained a 52B-parameter MoE Mamba-Transformer hybrid called Jamba. At inference, this model has 12B active parameters and has benchmark scores comparable to Llama-2 70B and Mixtral.

Jamba benchmark scores, from Jamba paper^{[5:1]}

## Efficient Inference

One advantage of RNNs, and of Mamba in particular, is that the memory needed to store the context is constant: you only need to keep the past state of the SSM and of the convolution layers, whereas a transformer's key/value cache grows linearly with context length. The same holds for generation time, where predicting each token costs O(1) instead of O(context length).
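To make the constant-memory claim concrete, here is a back-of-the-envelope sketch counting the floats cached per layer during generation. It assumes the example sizes used later in this post ($E = 2048$, $N = 16$, `D_conv=4`), assumes the conv only needs its last `D_conv - 1` inputs (the minimum; an implementation may keep one more), and uses a hypothetical transformer width `d_model = 1024` with a plain key/value cache. The function names are illustrative.

```python
def mamba_cache_floats(E: int, N: int, d_conv: int) -> int:
    """Floats cached per Mamba layer during generation; independent of context length."""
    ssm_state = E * N              # h_t has shape [E, N]
    conv_state = E * (d_conv - 1)  # the conv only needs its last d_conv - 1 inputs
    return ssm_state + conv_state


def transformer_kv_floats(d_model: int, context_len: int) -> int:
    """Floats cached per attention layer: one key and one value vector per past token."""
    return 2 * d_model * context_len


for context_len in (1_000, 10_000, 100_000):
    print(context_len,
          mamba_cache_floats(E=2048, N=16, d_conv=4),  # always 38912
          transformer_kv_floats(d_model=1024, context_len=context_len))
```

The Mamba column does not change with `context_len`, while the key/value cache grows tenfold with each row.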

Jamba throughput (tokens/second), from Jamba paper^{[5:2]}

## What are State-space models?

The inspiration for Mamba (and similar models) is an established technique from control theory called state space models (SSMs). SSMs are normally used to represent linear systems that have $p$ inputs, $q$ outputs and $n$ state variables. To keep the notation concise, we will consider an $E$-dimensional input $x(t) \in \mathbb{R}^E$, an $E$-dimensional output $y(t) \in \mathbb{R}^E$ and an $N$-dimensional latent state $h \in \mathbb{R}^N$. In the following, we annotate the dimensions of new variables with the notation $[X,Y]$. For reference, in Mamba 2.8b, $E = 5120$ and $N = 16$.

Specifically, we have the following:

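$$\underset{[N]}{\dot{h}(t)} = \underset{[N,N]}{A}\,\underset{[N]}{h(t)} + \underset{[N,E]}{B}\,\underset{[E]}{x(t)}$$

$$\underset{[E]}{y(t)} = \underset{[E,N]}{C}\,\underset{[N]}{h(t)} + \underset{[E,E]}{D}\,\underset{[E]}{x(t)}$$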

This is an ordinary differential equation (ODE), where $\dot{h}(t)$ is the derivative of $h(t)$ with respect to time $t$. This ODE can be solved in various ways, some of which are described below.

In state space models, $A$ is called the *state matrix*, $B$ the *input matrix*, $C$ the *output matrix*, and $D$ the *feedthrough matrix*.

## Solving the ODE

We can write the ODE from above as a recurrence, using discrete timesteps:

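$$\underset{[N]}{h_t} = \underset{[N,N]}{\bar{A}}\,\underset{[N]}{h_{t-1}} + \underset{[N,E]}{\bar{B}}\,\underset{[E]}{x_t}$$

$$\underset{[E]}{y_t} = \underset{[E,N]}{C}\,\underset{[N]}{h_t} + \underset{[E,E]}{D}\,\underset{[E]}{x_t}$$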

where $\bar{A}$ and $\bar{B}$ are our *discretized* matrices. Different ways of integrating the original ODE will give different $\bar{A}$ and $\bar{B}$, but will preserve this overall form.

In the above, $t$ corresponds to discrete time. In language modeling, $t$ refers to the token position.
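As a minimal sketch (pure Python, matrices as nested lists), the recurrence can be run token by token starting from $h_0 = 0$. `ssm_scan` and `matvec` are illustrative helper names, and $\bar{A}$, $\bar{B}$, $C$, $D$ are assumed to be already discretized and fixed across time.

```python
from typing import List

Matrix = List[List[float]]
Vector = List[float]


def matvec(M: Matrix, v: Vector) -> Vector:
    """Matrix-vector product for nested lists."""
    return [sum(m_ij * v_j for m_ij, v_j in zip(row, v)) for row in M]


def ssm_scan(A_bar: Matrix, B_bar: Matrix, C: Matrix, D: Matrix,
             xs: List[Vector]) -> List[Vector]:
    """Run h_t = A_bar h_{t-1} + B_bar x_t and y_t = C h_t + D x_t from h_0 = 0.

    Shapes: A_bar [N,N], B_bar [N,E], C [E,N], D [E,E], each x_t [E].
    """
    h = [0.0] * len(A_bar)
    ys = []
    for x in xs:  # one step per token position t
        h = [a + b for a, b in zip(matvec(A_bar, h), matvec(B_bar, x))]
        ys.append([c + d for c, d in zip(matvec(C, h), matvec(D, x))])
    return ys


# One state, one channel: an impulse at t = 1 decays by half each step.
print(ssm_scan([[0.5]], [[1.0]], [[1.0]], [[0.0]], [[1.0], [0.0], [0.0]]))
# -> [[1.0], [0.5], [0.25]]
```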

## Euler method

The simplest way to numerically integrate an ODE is the Euler method, which consists of approximating the derivative by the ratio between a small variation in $h$ and a small variation in time, $\dot{h} = \frac{dh}{dt} \approx \frac{\Delta h}{\Delta t}$. This allows us to write:

$$\frac{h_{t+1} - h_t}{\Delta t} = A h_t + B x_t$$

$$h_{t+1} = \Delta t\,(A h_t + B x_t) + h_t$$

where the index $t$ of $h_t$ represents the discretized time. This is similar to updating a character's position from its velocity in a video game, for instance. If a character has velocity $v$ and position $x_0$, then to find its position after a time $\Delta t$ we can compute $x_1 = \Delta t\, v + x_0$. In general:

$$x_t = \Delta t\, v_t + x_{t-1}$$

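Returning to the SSM, we can rearrange

$$h_{t+1} = \Delta t\,(A h_t + B x_t) + h_t$$

as

$$h_{t+1} = (\Delta A + I)\, h_t + \Delta B\, x_t$$

Shifting the index by one and labeling each input by the step it drives (the convention of the recurrence above, where $x_t$ is consumed when computing $h_t$), this becomes

$$h_t = (\Delta A + I)\, h_{t-1} + \Delta B\, x_t$$

which means that, for the Euler method, $\bar{A} = \Delta A + I$ and $\bar{B} = \Delta B$.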

Here, $\Delta$ is shorthand for $\Delta t$, the discretization step size in time.
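A quick numerical check of the Euler discretization, on an assumed scalar toy system $\dot{h} = -h + x$ (so $A = -1$, $B = 1$) with constant input $x = 1$ and $h(0) = 0$, whose exact solution is $h(T) = 1 - e^{-T}$. `euler_discretize` and `simulate` are illustrative names. Since Euler is a first-order method, the error should shrink roughly in proportion to $\Delta$.

```python
import math


def euler_discretize(A: float, B: float, delta: float):
    """Euler rule for a scalar SSM: A_bar = 1 + delta*A, B_bar = delta*B."""
    return 1.0 + delta * A, delta * B


def simulate(A_bar: float, B_bar: float, x: float, steps: int) -> float:
    """Iterate h_t = A_bar * h_{t-1} + B_bar * x from h_0 = 0 with constant input x."""
    h = 0.0
    for _ in range(steps):
        h = A_bar * h + B_bar * x
    return h


A, B, T = -1.0, 1.0, 1.0
exact = 1.0 - math.exp(-T)
errors = {}
for delta in (0.1, 0.01):
    steps = round(T / delta)
    h_T = simulate(*euler_discretize(A, B, delta), x=1.0, steps=steps)
    errors[delta] = abs(h_T - exact)
    print(f"delta={delta}: h(T)={h_T:.5f}, exact={exact:.5f}, error={errors[delta]:.5f}")
```

With ten times smaller steps, the error drops by roughly a factor of ten.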

## Zero-Order Hold (ZOH)

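Another way to integrate the ODE is to assume that the input $x(t)$ stays fixed during a time interval $\Delta$, and to integrate the differential equation exactly from time $t$ to $t+\Delta$. This gives us an expression for $h(t+\Delta)$:

$$h(t+\Delta) = e^{\Delta A}\, h(t) + \left(\int_t^{t+\Delta} e^{(t+\Delta-\tau)A}\, d\tau\right) B\, x(t)$$

Substituting $s = t + \Delta - \tau$, the integral becomes $\int_0^{\Delta} e^{sA}\, ds = A^{-1}\left(e^{\Delta A} - I\right)$ (for invertible $A$), and since $A^{-1} = (\Delta A)^{-1}\Delta$ we finally get:

$$\bar{A} = \exp(\Delta A)$$

$$\bar{B} = (\Delta A)^{-1}\left(\exp(\Delta A) - I\right) \Delta B$$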
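For a diagonal $A$ (the case that matters for Mamba), the ZOH formulas act element-wise, so a sketch only needs one scalar per state. On an assumed toy system $\dot{h} = -h + x$ with constant input $x = 1$ and $h(0) = 0$, ZOH reproduces the exact solution $1 - e^{-T}$ up to floating-point error for any step size, because the input really is constant over each step. `zoh_discretize` is an illustrative name, and the sketch assumes every $A_{i,i} \neq 0$.

```python
import math
from typing import List, Tuple


def zoh_discretize(A_diag: List[float], B: List[float],
                   delta: float) -> Tuple[List[float], List[float]]:
    """ZOH for diagonal A: A_bar = exp(dA), B_bar = (dA)^-1 (exp(dA) - 1) * delta * B.

    Works element-wise because A is diagonal; assumes every A_ii != 0.
    """
    A_bar = [math.exp(delta * a) for a in A_diag]
    # expm1(z) = exp(z) - 1, computed accurately even for small z
    B_bar = [math.expm1(delta * a) / (delta * a) * delta * b for a, b in zip(A_diag, B)]
    return A_bar, B_bar


A_bar, B_bar = zoh_discretize([-1.0], [1.0], delta=0.1)
h = 0.0
for _ in range(10):  # 10 steps of size 0.1 reach T = 1
    h = A_bar[0] * h + B_bar[0] * 1.0
print(h, 1.0 - math.exp(-1.0))  # agree to floating-point precision
```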

## Discretization rule used in Mamba

Mamba uses a mix of Zero-Order Hold and the Euler Method:

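$$\bar{A} = \exp(\Delta A), \qquad \bar{B} = \Delta B$$

Why is this justified? Consider the ZOH $\bar{B}$:

$$\bar{B} = (\Delta A)^{-1}\left(\exp(\Delta A) - I\right) \Delta B$$

In Mamba, $A$ is diagonal, as we will see later, so we can write

$$\left((\Delta A)^{-1}\left(\exp(\Delta A) - I\right)\right)_{i,i} = \frac{\exp(\Delta A_{i,i}) - 1}{\Delta A_{i,i}}$$

If we consider that $\Delta A_{i,i}$ is small and expand the exponential to first order^{[6]}, this expression reduces to 1, which means that

$$\bar{B} = (\Delta A)^{-1}\left(\exp(\Delta A) - I\right) \Delta B \approx \Delta B$$

for small enough $\Delta A_{i,i}$. Using the same approximation for $\bar{A}$ recovers the Euler method:

$$\bar{A} = \exp(\Delta A) \approx I + \Delta A$$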

In the original work, the authors argued that while ZOH was necessary for the modeling of $\bar{A}$, using the Euler method for $\bar{B}$ gave reasonable results, without having to compute $(\Delta A)^{-1}$.
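The approximation is easy to check numerically. The sketch below (illustrative names, one scalar diagonal entry of $A$ at a time) compares the ZOH factor $\frac{\exp(\Delta A_{i,i}) - 1}{\Delta A_{i,i}}$ with the value 1 implied by $\bar{B} = \Delta B$. The two agree closely when $\Delta A_{i,i}$ is small and drift apart when it is not, so the mixed rule is a modeling simplification rather than an exact identity.

```python
import math


def zoh_b_factor(dA: float) -> float:
    """(exp(dA) - 1) / dA: the diagonal factor ZOH puts in front of delta*B (dA != 0)."""
    return math.expm1(dA) / dA


def mamba_discretize(delta: float, A_ii: float, B_i: float):
    """Mixed rule for one diagonal entry: ZOH for A_bar, Euler for B_bar."""
    return math.exp(delta * A_ii), delta * B_i


for dA in (-0.01, -0.1, -1.0, -2.0):
    print(f"delta*A = {dA:5}: ZOH factor = {zoh_b_factor(dA):.4f} (Euler uses 1)")
```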

## Specific Quirks to Mamba

## The structured SSM

Mamba takes an interesting approach to the SSM equation. As previously mentioned, each timestep in Mamba represents a token position, and each token is represented (by the time it arrives at the SSM) by an $E$-dimensional vector. The authors chose to represent the SSM as:

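$$\underset{[E,N]}{h_t} = \underset{[E,N]}{\bar{A}}\,\underset{[E,N]}{h_{t-1}} + \underset{[E,N]}{\bar{B}}\,\underset{[E]}{x_t}$$

$$\underset{[E]}{y_t} = \underset{[N]}{C}\,\underset{[E,N]}{h_t} + \underset{[E]}{D}\,\underset{[E]}{x_t}$$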

## The case of a 1-Dimensional input

When trying to understand Mamba, I find it easiest to start with each $x_t$ being a single value, and then work up from there. The standard SSM equation is then:

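$$\underset{[N]}{h_t} = \underset{[N,N]}{\bar{A}}\,\underset{[N]}{h_{t-1}} + \underset{[N,1]}{\bar{B}}\,\underset{[1]}{x_t}$$

$$\underset{[1]}{y_t} = \underset{[1,N]}{C}\,\underset{[N]}{h_t} + \underset{[1,1]}{D}\,\underset{[1]}{x_t}$$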

The authors of the original Mamba paper were building on previous results on structured SSMs. Because of this, $A$ is a diagonal matrix in this work, which means it can be represented as a set of $N$ numbers instead of an $N \times N$ matrix. That gives us:

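$$\underset{[N]}{h_t} = \underset{[N]}{\bar{A}}\,\underset{[N]}{h_{t-1}} + \underset{[N,1]}{\bar{B}}\,\underset{[1]}{x_t}$$

$$\underset{[1]}{y_t} = \underset{[1,N]}{C}\,\underset{[N]}{h_t} + \underset{[1,1]}{D}\,\underset{[1]}{x_t}$$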

where $\bar{A}\,h_{t-1}$ (two $[N]$ vectors) is an element-wise product. In this example we are mapping a 1-dimensional input to an $N$-dimensional hidden state, then mapping the $N$-dimensional hidden state back to a 1-dimensional output.
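A minimal sketch of this single-channel case, assuming $\bar{A}$, $\bar{B}$, $C$ and $D$ are fixed (no selection yet) and using made-up numbers. It also checks that storing the diagonal $\bar{A}$ as $N$ numbers and multiplying element-wise gives the same result as the full $N \times N$ diagonal matrix. Function names are illustrative.

```python
from typing import List


def diag_ssm_1d(A_bar: List[float], B_bar: List[float], C: List[float],
                D: float, xs: List[float]) -> List[float]:
    """Scalar-in, scalar-out SSM with diagonal A_bar stored as its N diagonal entries."""
    h = [0.0] * len(A_bar)
    ys = []
    for x in xs:
        h = [a * h_n + b * x for a, h_n, b in zip(A_bar, h, B_bar)]  # element-wise
        ys.append(sum(c * h_n for c, h_n in zip(C, h)) + D * x)
    return ys


def full_matrix_ssm_1d(A_full: List[List[float]], B_bar: List[float],
                       C: List[float], D: float, xs: List[float]) -> List[float]:
    """Same recurrence using an explicit N x N matrix for A_bar (for comparison)."""
    h = [0.0] * len(A_full)
    ys = []
    for x in xs:
        h = [sum(A_full[i][j] * h[j] for j in range(len(h))) + B_bar[i] * x
             for i in range(len(h))]
        ys.append(sum(c * h_n for c, h_n in zip(C, h)) + D * x)
    return ys


xs = [1.0, 0.0, 2.0]
ys_diag = diag_ssm_1d([0.5, 0.25], [1.0, 2.0], [1.0, 1.0], 0.0, xs)
ys_full = full_matrix_ssm_1d([[0.5, 0.0], [0.0, 0.25]], [1.0, 2.0], [1.0, 1.0], 0.0, xs)
print(ys_diag, ys_full)  # [3.0, 1.0, 6.375] [3.0, 1.0, 6.375]
```

The diagonal form needs $N$ multiplications per step instead of $N^2$.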

## The Mamba implementation

In practice, $x_t$ and $y_t$ are not one-dimensional, but $E$-dimensional vectors. Mamba simply maps each of these elements separately to an $N$-dimensional hidden space. So we can write a set of $E$ equations:

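$$\underset{[N]}{h_{t,e}} = \underset{[N]}{\bar{A}}\,\underset{[N]}{h_{t-1,e}} + \underset{[N,1]}{\bar{B}}\,\underset{[1]}{x_{t,e}}$$

$$\underset{[1]}{y_{t,e}} = \underset{[1,N]}{C}\,\underset{[N]}{h_{t,e}} + \underset{[1,1]}{D}\,\underset{[1]}{x_{t,e}}$$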

where $e$ ranges over $1, \dots, E$. This means that each input dimension of the SSM block is modeled by its own independent SSM. We will see that, due to the selection mechanism (see below), $\Delta$, $\bar{A}$, $\bar{B}$ and $C$ are functions of all the dimensions of the input, not just dimension $e$.

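One thing to note: in practice, $A$ has a separate value for each $e$, and is encoded as an $[E,N]$ matrix. We denote by $\bar{A}_e = \exp(\Delta A_e)$ the $N$-sized discretized entry for stream $e$, giving us

$$\underset{[N]}{h_{t,e}} = \underset{[N]}{\bar{A}_e}\,\underset{[N]}{h_{t-1,e}} + \underset{[N,1]}{\bar{B}}\,\underset{[1]}{x_{t,e}}$$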
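To illustrate the "one independent SSM per stream" point, here is a sketch with $A$ stored as an $[E, N]$ array and discretized per stream as $\bar{A}_e = \exp(\Delta A_e)$. It assumes a single fixed $\Delta$ and a shared $\bar{B}$ and $C$ (i.e. before selection is introduced), with illustrative names and made-up numbers. Perturbing one stream's input leaves the other streams' outputs untouched, because the recurrence never mixes streams.

```python
import math
from typing import List


def multi_stream_ssm(A: List[List[float]], delta: float, B_bar: List[float],
                     C: List[float], D: List[float],
                     xs: List[List[float]]) -> List[List[float]]:
    """E independent diagonal SSMs. A is [E][N]; each x_t and y_t is [E]."""
    E, N = len(A), len(A[0])
    A_bar = [[math.exp(delta * a) for a in A[e]] for e in range(E)]  # [E][N]
    h = [[0.0] * N for _ in range(E)]                                # [E][N]
    ys = []
    for x in xs:
        y = []
        for e in range(E):  # stream e only ever reads x[e] and h[e]
            h[e] = [A_bar[e][n] * h[e][n] + B_bar[n] * x[e] for n in range(N)]
            y.append(sum(C[n] * h[e][n] for n in range(N)) + D[e] * x[e])
        ys.append(y)
    return ys


A = [[-1.0, -0.5], [-2.0, -0.1]]  # E = 2 streams, N = 2 states each
args = dict(A=A, delta=0.1, B_bar=[0.1, 0.2], C=[1.0, -1.0], D=[1.0, 1.0])
ys_a = multi_stream_ssm(xs=[[1.0, 0.0], [0.5, 0.0]], **args)
ys_b = multi_stream_ssm(xs=[[1.0, 9.0], [0.5, -3.0]], **args)  # only stream 1 changed
print([y[0] for y in ys_a] == [y[0] for y in ys_b])  # True: stream 0 is unaffected
```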

## Selection mechanism

Mamba deviates from the simplest SSM approaches, and from the authors' previous work, by making the matrices $B$ and $C$ depend on the input $x(t)$. Not only that, but the time discretization $\Delta$ is also input-dependent. This replaces the equations above with ones of the form:

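$$\underset{[N]}{h_{t,e}} = \underset{[N]}{\bar{A}_{t,e}}\,\underset{[N]}{h_{t-1,e}} + \underset{[N,1]}{\bar{B}_{t,e}}\,\underset{[1]}{x_{t,e}}$$

$$\underset{[1]}{y_{t,e}} = \underset{[1,N]}{C_t}\,\underset{[N]}{h_{t,e}} + \underset{[1,1]}{D}\,\underset{[1]}{x_{t,e}}$$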

Where the new matrices are given by:

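$$\underset{[N]}{\bar{A}_{t,e}} = \exp\left(\underset{[1]}{\Delta_{t,e}}\,\underset{[N]}{A_e}\right)$$

$$\underset{[N]}{\bar{B}_{t,e}} = \underset{[1]}{\Delta_{t,e}}\,\underset{[N]}{B_t}, \quad \text{with} \quad \underset{[N]}{B_t} = \underset{[N,E]}{W_B}\,\underset{[E]}{x_t}$$

$$\underset{[N]}{C_t} = \underset{[N,E]}{W_C}\,\underset{[E]}{x_t}$$

$$\underset{[1]}{\Delta_{t,e}} = \operatorname{softplus}\left(\underset{[E]}{x_t} \cdot \underset{[E]}{W_\Delta[:,e]} + \underset{[1]}{B_\Delta[e]}\right)$$

with $\underset{[E,E]}{W_\Delta}$, $\underset{[E]}{B_\Delta}$, $\underset{[N,E]}{W_B}$, $\underset{[N,E]}{W_C}$ being learned parameters, and $\operatorname{softplus}(x) = \log(1 + e^x)$.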

softplus

One final thing to note: $A$ is not a trainable parameter; what is actually trained is $\underset{[E,N]}{A_{\log}}$, and $A$ is computed as $A = -\exp(A_{\log})$ (using element-wise exp). This ensures every entry of $A$ is strictly negative. Because $\Delta$ is always positive (it is the output of a softplus), the first term of the SSM can be seen as how much of the previous state is kept at a given token position, while the second term controls how much is written to the state.
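Putting the selection mechanism together for a single token, here is a pure-Python sketch over all $E$ streams. The parameter names mirror the symbols above, but the nested-list layout, the function names and the random values are illustrative assumptions; it also uses a full $[E,E]$ $W_\Delta$ rather than a low-rank factorization. The softplus is written in a numerically stable form that equals $\log(1+e^z)$.

```python
import math
import random


def softplus(z: float) -> float:
    """log(1 + e^z), written to avoid overflow for large |z|."""
    return max(z, 0.0) + math.log1p(math.exp(-abs(z)))


def selective_step(x, h_prev, W_B, W_C, W_delta, b_delta, A_log, D):
    """One token of a selective SSM over all E streams (nested-list sketch).

    Shapes: x [E], h_prev [E][N], W_B and W_C [N][E], W_delta [E][E],
    b_delta [E], A_log [E][N], D [E]. Returns (h [E][N], y [E], A_bar [E][N]).
    """
    E, N = len(x), len(W_B)
    B_t = [sum(W_B[n][i] * x[i] for i in range(E)) for n in range(N)]  # [N]
    C_t = [sum(W_C[n][i] * x[i] for i in range(E)) for n in range(N)]  # [N]
    # Delta_{t,e} = softplus(x_t . W_delta[:, e] + b_delta[e]): one step size per stream
    delta = [softplus(sum(x[i] * W_delta[i][e] for i in range(E)) + b_delta[e])
             for e in range(E)]
    h, y, A_bar = [], [], []
    for e in range(E):
        A_e = [-math.exp(a) for a in A_log[e]]          # A = -exp(A_log), strictly negative
        Abar_e = [math.exp(delta[e] * a) for a in A_e]  # exp(Delta * A), in (0, 1)
        Bbar_e = [delta[e] * b for b in B_t]            # B_bar = Delta * B_t
        h_e = [Abar_e[n] * h_prev[e][n] + Bbar_e[n] * x[e] for n in range(N)]
        y.append(sum(C_t[n] * h_e[n] for n in range(N)) + D[e] * x[e])
        h.append(h_e)
        A_bar.append(Abar_e)
    return h, y, A_bar


def rand_matrix(rows, cols):
    return [[random.gauss(0, 1) for _ in range(cols)] for _ in range(rows)]


random.seed(0)
E, N = 3, 2
params = dict(W_B=rand_matrix(N, E), W_C=rand_matrix(N, E), W_delta=rand_matrix(E, E),
              b_delta=[random.gauss(0, 1) for _ in range(E)],
              A_log=rand_matrix(E, N), D=[1.0] * E)
x = [random.gauss(0, 1) for _ in range(E)]
h, y, A_bar = selective_step(x, rand_matrix(E, N), **params)
print(y)
print(A_bar)  # every entry strictly between 0 and 1
```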

Because $\Delta > 0$ and $A < 0$, every entry of $\exp(\Delta A)$ lies strictly between 0 and 1. This is important for stable training: it ensures that the elements of $h(t)$ do not grow exponentially with token position $t$, and that the gradients do not explode. It has long been known^{[7]} that exploding and vanishing gradients are obstacles to training RNNs, and successful architectures (LSTM, GRU) are designed to mitigate them.

## $W_\Delta$ is low rank

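Mamba doesn't store $\underset{[E,E]}{W_\Delta}$ as a full $[E,E]$ matrix. Instead, it is factored into two smaller matrices:

$$\underset{[E,E]}{W_\Delta} = \underset{[E,D_\Delta]}{W_{\Delta_1}}\,\underset{[D_\Delta,E]}{W_{\Delta_2}}$$

where, for example, $E = 2048$ and $D_\Delta = 64$.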

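This turns the term

$$\underset{[1]}{\Delta_{t,e}} = \operatorname{softplus}\left(\underset{[E]}{x_t} \cdot \underset{[E]}{W_\Delta[:,e]} + \underset{[1]}{B_\Delta[e]}\right)$$

into

$$\underset{[1]}{\Delta_{t,e}} = \operatorname{softplus}\left(\underset{[E]}{x_t} \cdot \left(\underset{[E,D_\Delta]}{W_{\Delta_1}}\,\underset{[D_\Delta]}{W_{\Delta_2}[:,e]}\right) + \underset{[1]}{B_\Delta[e]}\right)$$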
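A small sketch of why the factorization is cheap, assuming the example sizes above ($E = 2048$, $D_\Delta = 64$). Storing the factors takes far fewer parameters, and by associativity $x_t \cdot (W_{\Delta_1} W_{\Delta_2}[:,e]) = (x_t W_{\Delta_1}) \cdot W_{\Delta_2}[:,e]$, so the $[D_\Delta]$ vector $x_t W_{\Delta_1}$ can be computed once per token and reused for every $e$. The helper names and the tiny random matrices in the check are illustrative; the functions compute the softplus argument before $B_\Delta$ is added.

```python
import random


def full_vs_low_rank_params(E: int, d_delta: int):
    """Parameter counts: dense [E,E] W_delta vs. [E,d_delta] and [d_delta,E] factors."""
    return E * E, E * d_delta + d_delta * E


def delta_preact_low_rank(x, W1, W2):
    """x [E], W1 [E][R], W2 [R][E] -> x . (W1 W2[:, e]) for every e, via u = x W1 first."""
    R, E = len(W2), len(W2[0])
    u = [sum(x[i] * W1[i][r] for i in range(E)) for r in range(R)]  # [R], computed once
    return [sum(u[r] * W2[r][e] for r in range(R)) for e in range(E)]


def delta_preact_dense(x, W1, W2):
    """Same quantity, materializing the dense [E][E] matrix W1 W2 first."""
    E, R = len(x), len(W2)
    W = [[sum(W1[i][r] * W2[r][e] for r in range(R)) for e in range(E)] for i in range(E)]
    return [sum(x[i] * W[i][e] for i in range(E)) for e in range(E)]


print(full_vs_low_rank_params(2048, 64))  # (4194304, 262144)

random.seed(0)
E, R = 6, 2
x = [random.gauss(0, 1) for _ in range(E)]
W1 = [[random.gauss(0, 1) for _ in range(R)] for _ in range(E)]
W2 = [[random.gauss(0, 1) for _ in range(E)] for _ in range(R)]
low, dense = delta_preact_low_rank(x, W1, W2), delta_preact_dense(x, W1, W2)
print(max(abs(a - b) for a, b in zip(low, dense)))  # tiny: floating-point rounding only
```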

## RMSNorm

This normalization is not unique to Mamba. It's defined as

$$\operatorname{RMSNorm}\left(\underset{[B,L,D]}{x}\right) = \frac{x}{\sqrt{\operatorname{mean}(x^2, \text{dim}=-1)}}\,\text{weight}$$

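If mean were instead sum, this first term would normalize $x$ to unit length along the $D$ dimension. Because it's a mean, there's an extra factor of $D$, and we can rewrite this as:

$$\operatorname{RMSNorm}\left(\underset{[B,L,D]}{x}\right) = \frac{\sqrt{D}\,x}{\sqrt{\operatorname{sum}(x^2, \text{dim}=-1)}}\,\text{weight}$$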
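A minimal pure-Python sketch for a single activation vector (the $[B, L]$ dimensions just repeat this per position). It checks that the mean form and the $\sqrt{D}$/sum form agree, and that with `weight` set to ones the output's mean squared value is 1. Real implementations also add a small epsilon inside the square root for numerical safety; it is omitted here to match the formula, and the function names are illustrative.

```python
import math
from typing import List


def rms_norm(x: List[float], weight: List[float]) -> List[float]:
    """x / sqrt(mean(x^2)) * weight, over the last (D) dimension."""
    rms = math.sqrt(sum(v * v for v in x) / len(x))
    return [v / rms * w for v, w in zip(x, weight)]


def rms_norm_sum_form(x: List[float], weight: List[float]) -> List[float]:
    """Equivalent form: sqrt(D) * x / sqrt(sum(x^2)) * weight."""
    norm = math.sqrt(sum(v * v for v in x))
    return [math.sqrt(len(x)) * v / norm * w for v, w in zip(x, weight)]


x = [3.0, -1.0, 2.0, 0.5]
ones = [1.0] * len(x)
out = rms_norm(x, ones)
print(out)
print(sum(v * v for v in out) / len(out))  # mean square is 1 (up to rounding)
```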

We normalize this way so that each element has a root-mean-square value of 1, rather than the whole activation vector having norm 1. This matches the scale deep learning weights are designed for: since the introduction of He initialization^{[8]}, weights have been initialized so that the activation variance is 1 assuming the input variance is 1, which keeps gradients stable throughout training.

## Full Architecture

Now that we know how the SSM works, here is the full architecture.

## Dimensions

(Example values from state-spaces/mamba-370m)

## Notes on reading these graphs

## Overview

Mamba has:

High level overview of Mamba

## Layer contents

Each layer does:

Mamba layer overview

silu

## SSM

From above:

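$$\underset{[1]}{\Delta_{t,e}} = \operatorname{softplus}\left(\underset{[E]}{x_t} \cdot \left(\underset{[E,D_\Delta]}{W_{\Delta_1}}\,\underset{[D_\Delta]}{W_{\Delta_2}[:,e]}\right) + \underset{[1]}{B_\Delta[e]}\right)$$

$$\underset{[E,N]}{A} = -\exp\left(\underset{[E,N]}{A_{\log}}\right)$$

$$\underset{[N]}{\bar{A}_{t,e}} = \exp\left(\underset{[1]}{\Delta_{t,e}}\,\underset{[N]}{A_e}\right)$$

$$\underset{[N]}{B_t} = \underset{[N,E]}{W_B}\,\underset{[E]}{x_t}$$

$$\underset{[N]}{\bar{B}_{t,e}} = \underset{[1]}{\Delta_{t,e}}\,\underset{[N]}{B_t}$$

$$\underset{[N]}{C_t} = \underset{[N,E]}{W_C}\,\underset{[E]}{x_t}$$

$$\underset{[N]}{h_{t,e}} = \underset{[N]}{\bar{A}_{t,e}}\,\underset{[N]}{h_{t-1,e}} + \underset{[N,1]}{\bar{B}_{t,e}}\,\underset{[1]}{x_{t,e}}$$

$$\underset{[1]}{y_{t,e}} = \underset{[1,N]}{C_t}\,\underset{[N]}{h_{t,e}} + \underset{[1,1]}{D}\,\underset{[1]}{x_{t,e}}$$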

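where $\underset{[E,E]}{W_\Delta}$ (stored as $W_{\Delta_1} W_{\Delta_2}$), $\underset{[E]}{B_\Delta}$, $\underset{[E,N]}{A_{\log}}$, $\underset{[N,E]}{W_B}$, $\underset{[N,E]}{W_C}$ and $D$ are learned parameters, and $\operatorname{softplus}(x) = \log(1 + e^x)$.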

Or, vectorized, computing the non-$h$ terms ahead of time (since they don't depend on the recurrence):

Selective SSM
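Here is a pure-Python sketch of that vectorized form: the input-dependent terms ($\Delta$, $B_t$, $C_t$) are computed for every position first, and only the recurrence over $h$ is sequential. The nested-list layout, function names and random values are illustrative assumptions, not the official implementation (which, as we understand it, fuses these steps in a GPU kernel). The usage checks causality: changing the last token's input does not change any earlier output.

```python
import math
import random


def softplus(z):
    return max(z, 0.0) + math.log1p(math.exp(-abs(z)))  # stable log(1 + e^z)


def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


def selective_scan(xs, W_delta1, W_delta2, b_delta, A_log, W_B, W_C, D):
    """xs [L][E], W_delta1 [E][R], W_delta2 [R][E], b_delta [E], A_log [E][N],
    W_B and W_C [N][E], D [E]. Returns ys [L][E]."""
    E, N, R = len(xs[0]), len(W_B), len(W_delta2)
    A = [[-math.exp(a) for a in row] for row in A_log]   # [E][N], all entries < 0
    # Non-h terms depend only on x_t, so compute them for every t up front.
    B = [[dot(W_B[n], x) for n in range(N)] for x in xs]  # [L][N]
    C = [[dot(W_C[n], x) for n in range(N)] for x in xs]  # [L][N]
    delta = []                                            # [L][E]
    for x in xs:
        u = [dot(x, [W_delta1[i][r] for i in range(E)]) for r in range(R)]  # x W_delta1
        delta.append([softplus(dot(u, [row[e] for row in W_delta2]) + b_delta[e])
                      for e in range(E)])
    # Only the recurrence itself is sequential over t.
    h = [[0.0] * N for _ in range(E)]
    ys = []
    for t, x in enumerate(xs):
        y = []
        for e in range(E):
            h[e] = [math.exp(delta[t][e] * A[e][n]) * h[e][n] + delta[t][e] * B[t][n] * x[e]
                    for n in range(N)]
            y.append(dot(C[t], h[e]) + D[e] * x[e])
        ys.append(y)
    return ys


def randn(*shape):
    if len(shape) == 1:
        return [random.gauss(0, 1) for _ in range(shape[0])]
    return [randn(*shape[1:]) for _ in range(shape[0])]


random.seed(0)
L, E, N, R = 5, 4, 3, 2
params = dict(W_delta1=randn(E, R), W_delta2=randn(R, E), b_delta=randn(E),
              A_log=randn(E, N), W_B=randn(N, E), W_C=randn(N, E), D=randn(E))
xs = randn(L, E)
ys = selective_scan(xs, **params)
xs_changed = [row[:] for row in xs[:-1]] + [randn(E)]  # replace only the last token
ys_changed = selective_scan(xs_changed, **params)
print(ys[:-1] == ys_changed[:-1])  # True: earlier outputs cannot see later tokens
```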

Also keep in mind: in the official implementation, $W_{\Delta_2}$ is called `dt_proj`, and some matrices are concatenated together (this is numerically equivalent, but helps performance as it's a fused operation):

## Further reading

## Appendix

Here's some further info on how Mamba's 1D conv works, for those unfamiliar. This is not unique to Mamba; convolution is a standard operation, most commonly associated with image processing.

## Conv1D Explanation

The basic unit of a Conv1D is applying a kernel to a sequence.

For example, say my kernel is `[-1,2,3]` and my sequence is `[4,5,6,7,8,9]`. Then to apply that kernel, I move it across my sequence like this:

So our resulting vector would be

`[24, 28, 32, 36]`

It's annoying that our output is smaller than our input, so we can pad our input first:

`[0,0,4,5,6,7,8,9,0,0]`

Now we get

So our result is

`[12, 23, 24, 28, 32, 36, 10, -9]`

Now this is longer than we need, so we'll cut off the last two, giving us

`[12, 23, 24, 28, 32, 36]`
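The sliding-and-padding procedure above can be written as a small pure-Python sketch. As in the example (and in deep-learning libraries generally), the kernel is not flipped, so strictly speaking this is a cross-correlation. `conv1d` is an illustrative name; padding `len(kernel) - 1` zeros on both sides and keeping the first `len(seq)` outputs makes it causal, since output `i` then only sees inputs up to position `i`.

```python
from typing import List


def conv1d(kernel: List[float], seq: List[float], pad: int = 0) -> List[float]:
    """Slide `kernel` over `seq` (zero-padded by `pad` on both sides), without flipping."""
    padded = [0] * pad + list(seq) + [0] * pad
    k = len(kernel)
    return [sum(w * v for w, v in zip(kernel, padded[i:i + k]))
            for i in range(len(padded) - k + 1)]


kernel, seq = [-1, 2, 3], [4, 5, 6, 7, 8, 9]
print(conv1d(kernel, seq))                     # [24, 28, 32, 36]
full = conv1d(kernel, seq, pad=len(kernel) - 1)
print(full)                                    # [12, 23, 24, 28, 32, 36, 10, -9]
print(full[:len(seq)])                         # [12, 23, 24, 28, 32, 36]
```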

## Worked Conv Example

Mamba conv is defined as

In this example, I will set:

In practice, `D_conv=4` and `E` is around `2048-5120`. Our input to Mamba's conv1d is of size [B, E, L]. I'll do a single batch.

Because `groups = E = 5`, we have `5` filters:

Let our context be:

Represented as embedding vectors

First we pad

Now to apply our first filter, we grab the first element of every vector

Giving us

Now we apply `filter 0 [ 0.4, 0.7, -2.1, 1.1]` with bias `[0.2]`

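So our output of `filter 0` is

Now we cut off the extra outputs at the end (with `D_conv=4` and 3 zeros of padding on each side there are L+3 outputs, so we drop the last three) to get an output of the same length L, giving us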

For `filter 1`, we grab the second element of every vector

Giving us

Now we apply `filter 1 [ 0.1, -0.7, -0.3, 0.0]` with bias `[0.2]`

etc.

## Conv1D in code

Here's what that means in code:
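A pure-Python sketch of the same depthwise (`groups = E`) causal convolution on a single batch element of shape [E, L]: each channel `e` gets its own length-`D_conv` filter and bias, is zero-padded by `D_conv - 1` on both sides, and only the first L outputs are kept. To our understanding this matches what the official implementation does with PyTorch's `nn.Conv1d` (with `groups` equal to E and `padding=D_conv - 1`) followed by truncating to length L, but the function below is an illustrative stand-in, not that code. Filters 0 and 1 and their biases are the values from the worked example; the other three filters, their biases and all input values are made up.

```python
from typing import List


def depthwise_causal_conv1d(x: List[List[float]], filters: List[List[float]],
                            biases: List[float]) -> List[List[float]]:
    """x: [E][L] (one batch element), filters: [E][D_conv], biases: [E].

    Channel e is convolved only with filters[e] (groups = E), zero-padded by
    D_conv - 1 on both sides; keeping the first L outputs means output t only
    depends on inputs at positions <= t.
    """
    out = []
    for channel, kernel, bias in zip(x, filters, biases):
        k = len(kernel)
        padded = [0.0] * (k - 1) + list(channel) + [0.0] * (k - 1)
        full = [sum(w * v for w, v in zip(kernel, padded[i:i + k])) + bias
                for i in range(len(padded) - k + 1)]  # length L + k - 1
        out.append(full[:len(channel)])                 # drop the last k - 1
    return out


filters = [
    [0.4, 0.7, -2.1, 1.1],   # filter 0 (from the worked example)
    [0.1, -0.7, -0.3, 0.0],  # filter 1 (from the worked example)
    [0.5, 0.5, 0.5, 0.5],    # filters 2-4: made-up values
    [0.0, 0.0, 0.0, 1.0],
    [0.0, 0.0, 1.0, 0.0],
]
biases = [0.2, 0.2, 0.0, 0.0, 0.0]  # biases for filters 2-4 are made up
x = [[1.0, 2.0, 3.0], [0.0, 1.0, 0.0], [1.0, 1.0, 1.0],
     [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]  # [E=5][L=3], made-up inputs
y = depthwise_causal_conv1d(x, filters, biases)
print(y[3])  # [4.0, 5.0, 6.0]: the last tap weights the current input
print(y[4])  # [0.0, 7.0, 8.0]: the second-to-last tap weights the previous input
```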

[1] Albert Gu and Tri Dao. Mamba: Linear-time sequence modeling with selective state spaces, 2023. https://arxiv.org/abs/2312.00752

[2] Albert Gu, Tri Dao, Stefano Ermon, Atri Rudra, and Christopher Ré. HiPPO: Recurrent memory with optimal polynomial projections, 2020. https://arxiv.org/abs/2008.07669

[3] Albert Gu, Karan Goel, and Christopher Ré. Efficiently modeling long sequences with structured state spaces, 2022. https://arxiv.org/abs/2111.00396

[4] Daniel Y. Fu, Tri Dao, Khaled K. Saab, Armin W. Thomas, Atri Rudra, and Christopher Ré. Hungry hungry hippos: Towards language modeling with state space models, 2023. https://arxiv.org/abs/2212.14052

[5] Opher Lieber, Barak Lenz, Hofit Bata, Gal Cohen, Jhonathan Osin, Itay Dalmedigos, Erez Safahi, Shaked Meirom, Yonatan Belinkov, Shai Shalev-Shwartz, Omri Abend, Raz Alon, Tomer Asida, Amir Bergman, Roman Glozman, Michael Gokhman, Avashalom Manevich, Nir Ratner, Noam Rozen, Erez Shwartz, Mor Zusman, and Yoav Shoham. Jamba: A hybrid transformer-mamba language model, 2024. https://arxiv.org/abs/2403.19887

[6] The Taylor series expansion of $\exp(x)$ at $x = 0$ is $\exp(x) = 1 + x + \frac{x^2}{2} + \frac{x^3}{6} + \dots$; keeping only the first-order terms gives $\exp(x) \approx 1 + x$.

[7] Razvan Pascanu, Tomas Mikolov, and Yoshua Bengio. On the difficulty of training recurrent neural networks. International Conference on Machine Learning, 2013. https://arxiv.org/abs/1211.5063

[8] Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun. Delving deep into rectifiers: Surpassing human-level performance on ImageNet classification. In Proceedings of the IEEE International Conference on Computer Vision, pp. 1026-1034, 2015. https://arxiv.org/abs/1502.01852