Motivation
Sparse Autoencoders (SAEs) have become a central tool in mechanistic interpretability research, providing a way to decompose a model's internal activations into sparse, interpretable features. However, extracting these features often requires running the SAE over large volumes of activations across many layers and tokens. This makes SAE inference efficiency a practical bottleneck for interpretability research at scale.
This post focuses on improving the inference efficiency of JumpReLU Sparse Autoencoders, which were introduced by DeepMind in Jumping Ahead: Improving Reconstruction Fidelity with JumpReLU Sparse Autoencoders (Rajamanoharan et al.). Instead of using a traditional ReLU activation function, these SAEs use JumpReLU, which zeros out activations that fall below a learned per-feature threshold $\theta$. This gives JumpReLU SAEs a variable number of active features per token (commonly written as $L_0$, the count of nonzero activations). TopK SAEs, by contrast, fire exactly $k$ features per token.
I use the terms "fire" and "fired" to describe features with non-zero activations.
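Here is a minimal plain-Python sketch of that definition. The pre-activations and thresholds are made up. JumpReLU keeps each pre-activation only if it exceeds its feature's learned threshold, and $L_0$ is the number of entries that survive:

```python
def jumprelu(pre_acts, thresholds):
    """JumpReLU: keep z when z > theta (that feature's threshold), otherwise output 0."""
    return [z if z > theta else 0.0 for z, theta in zip(pre_acts, thresholds)]

def l0(feature_acts):
    """L0 = number of features that fired (nonzero activations)."""
    return sum(1 for a in feature_acts if a != 0.0)

# Illustrative values only, not taken from a real SAE
pre_acts = [0.9, 0.2, -1.3, 0.5, 2.0, 0.05, 0.3, 0.0]
thresholds = [0.5, 0.5, 0.1, 0.6, 1.0, 0.01, 0.25, 0.2]

acts = jumprelu(pre_acts, thresholds)
print(acts)      # [0.9, 0.0, 0.0, 0.0, 2.0, 0.05, 0.3, 0.0]
print(l0(acts))  # 4
```

A plain ReLU would have kept six of these entries (every positive one). JumpReLU also zeros positive values that fall below their threshold, such as the fourth entry (0.5 < 0.6).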
Traditional JumpReLU SAE implementations compute the decoder step as a dense matrix multiplication (feature_acts @ W_dec), but this is wasteful because of the sparsity of feature_acts. Instead, you can exploit this sparsity property and skip the zero entries entirely during matrix multiplication with a custom Triton kernel.
Intuition: Sparsity Should Be Free
When a single token passes through a JumpReLU SAE with 65,536 features, the encoder produces a feature activation vector of length 65,536, but only a small fraction of its entries are nonzero.
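To be more concrete, consider a toy SAE with 8 features and feature activations $\mathbf{f} = (f_1, \dots, f_8)$. Suppose that only 2 features are active, $f_i$ and $f_j$. Let $W_{\text{dec}}$ be the weight matrix of the decoder layer, whose $r$-th row $W_{\text{dec}}[r]$ is the decoder vector for feature $r$.
We then compute the output with:
$$\hat{\mathbf{x}} = \mathbf{f}\, W_{\text{dec}} = \sum_{r=1}^{8} f_r\, W_{\text{dec}}[r]$$
Notice that only two rows of the decoder matrix were actually used in the computation. The rest were multiplied by 0 and contributed nothing to the output. We could instead just compute:
$$\hat{\mathbf{x}} = f_i\, W_{\text{dec}}[i] + f_j\, W_{\text{dec}}[j]$$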
Now imagine the same example with a much larger hidden dimension: for instance, 65,536 features with 72 active. That would mean multiplying ~99.89% of the rows by zero ($1 - 72/65{,}536 \approx 0.9989$).
If we knew in advance which features are nonzero and their corresponding values, we could skip these zero multiplications and simplify the computation.
For a single token, this can be divided into two parts:
1. Find the nonzero entries of the token's hidden (sparse) representation and the indices of those entries.
2. Use those indices and values to look up and scale the corresponding rows of $W_{\text{dec}}$, then sum the results.
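The two steps can be sketched in plain Python for one token. Lists stand in for GPU tensors, and the toy W_dec values are made up for illustration. The sparse version reads only the two decoder rows belonging to the active features, yet it matches the dense reference:

```python
def dense_decode(f, W_dec):
    """Dense reference: out[c] = sum over ALL rows r of f[r] * W_dec[r][c]."""
    d_model = len(W_dec[0])
    return [sum(f[r] * W_dec[r][c] for r in range(len(f))) for c in range(d_model)]

def sparse_decode_token(f, W_dec):
    """Same output, touching only the decoder rows whose activation is nonzero."""
    # Step 1: find the nonzero entries and their indices
    idx = [r for r, v in enumerate(f) if v != 0.0]
    vals = [f[r] for r in idx]
    # Step 2: scale the matching decoder rows and sum them
    out = [0.0] * len(W_dec[0])
    for r, v in zip(idx, vals):
        for c, w in enumerate(W_dec[r]):
            out[c] += v * w
    return out

# Toy SAE: 8 features, d_model = 3, only features 2 and 5 fire (illustrative values)
W_dec = [[float(r * 3 + c) for c in range(3)] for r in range(8)]
f = [0.0, 0.0, 1.5, 0.0, 0.0, 2.0, 0.0, 0.0]
print(sparse_decode_token(f, W_dec))  # [39.0, 42.5, 46.0]
```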
Implementation Overview
When implementing this kernel, my first thought was to begin with a preliminary step that figures out exactly how many features fired for each token so the system could then allocate exactly the memory needed to store the CSR representation (more on that later). However, this process involves a GPU->CPU sync, which causes some slowdown.
Alternatively, you can allocate a predetermined, fixed amount of memory for each token using a configurable max_l0 parameter. This speeds up computation but overallocates memory. It also introduces an important caveat: max_l0 must be at least as large as the largest per-token $L_0$ in the batch. For example, if you set max_l0=10 but one token in the batch has more than 10 nonzero features, the extra features are dropped, resulting in information loss.
Both approaches are covered below. For convenience, let's call the kernel that allocates exactly the memory needed for the CSR representation the Exact Allocation kernel, and the kernel that allocates a predetermined amount of memory per token the Fixed Allocation kernel. The Fixed Allocation kernel can be configured with either validate=True or validate=False. The validate=True version is slower than validate=False because it reintroduces a GPU->CPU sync, but it raises an error if any token fires more features than max_l0. This is clearly safer. If you are certain that no token will exceed max_l0, you can use validate=False for extra speed.
Exact Allocation Kernel
To skip zero entries during matrix multiplication, we first need to represent the feature activations in Compressed Sparse Row (CSR) format. CSR is a standard way of representing sparse matrices that stores only the nonzero values and their indices. For the example above, instead of storing all 8 entries of $\mathbf{f}$, CSR stores just the two nonzero values ($f_i$ and $f_j$) and their column indices ($i$ and $j$).
To allocate enough memory for building a CSR representation, we need to know how much memory each token requires (how many features fired per token). A count_nonzero kernel handles this:
```python
import triton
import triton.language as tl

@triton.jit
def count_nonzero(feature_acts_ptr, counts_ptr, n_features, BLOCK_F: tl.constexpr):
    pid_token = tl.program_id(0)  # Which token am I working on? (row index)
    pid_d = tl.program_id(1)      # Which chunk of features am I working on? (column index)

    # Compute the feature indices this block is responsible for
    feat_offsets = pid_d * BLOCK_F + tl.arange(0, BLOCK_F)
    mask = feat_offsets < n_features  # Guard against reading past the end of the feature dimension

    # Navigate to this token's features in memory, then to this block's chunk
    feat_ptrs = feature_acts_ptr + pid_token * n_features + feat_offsets
    vals = tl.load(feat_ptrs, mask=mask, other=0.0)  # Load the feature values

    fired = vals != 0.0  # Which features in this chunk are active (nonzero)?
    fired_count = tl.sum(fired.to(tl.int32))  # How many active (nonzero) features in this chunk?

    # Accumulate into this token's count (atomic since multiple blocks write to the same token)
    tl.atomic_add(counts_ptr + pid_token, fired_count)
```
If you're unfamiliar with Triton, the key mental model is that rather than writing a loop that runs sequentially, you write a kernel that describes what one block does and Triton launches many of these blocks in parallel across the GPU. In this kernel, each block is responsible for a chunk of one token's features. The two program_id calls tell each block where it is: pid_token identifies which token (which row of the input matrix), and pid_d identifies which chunk of that token's features to process.
Also note that pointers in GPU kernels point to the start of a flat block of memory. To reach a particular token's features, we offset into that memory by pid_token * n_features. Within that token, we offset further by pid_d * BLOCK_F to reach the right chunk. The mask guards against reading past the end when n_features isn't a clean multiple of BLOCK_F.
Finally, since multiple blocks may be counting features for the same token simultaneously, tl.atomic_add ensures their partial counts are combined safely.
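The emulation below makes the indexing concrete. It is a plain-Python stand-in for the launch, and the helper name count_nonzero_emulated is hypothetical. The nested loops over pid_token and pid_d play the role of blocks that Triton would run in parallel. BLOCK_F=2 is chosen so that n_features=5 is not a clean multiple and the mask matters. This models only the arithmetic, not Triton's execution:

```python
import math

def count_nonzero_emulated(feature_acts, BLOCK_F):
    """Sequential stand-in for count_nonzero launched on grid (B, cdiv(n_features, BLOCK_F))."""
    B, n_features = len(feature_acts), len(feature_acts[0])
    flat = [v for row in feature_acts for v in row]  # row-major "flat block of memory"
    counts = [0] * B
    for pid_token in range(B):                                # tl.program_id(0)
        for pid_d in range(math.ceil(n_features / BLOCK_F)):  # tl.program_id(1)
            fired_count = 0
            for lane in range(BLOCK_F):                       # tl.arange(0, BLOCK_F)
                feat = pid_d * BLOCK_F + lane
                # The mask check comes first, so we never index past this token's row
                if feat < n_features and flat[pid_token * n_features + feat] != 0.0:
                    fired_count += 1
            counts[pid_token] += fired_count                  # tl.atomic_add
    return counts

# 3 tokens, n_features = 5, BLOCK_F = 2 -> 3 chunks per token, the last one half-masked
feature_acts = [
    [0.0, 1.0, 0.0, 0.0, 2.0],
    [0.0, 0.0, 0.0, 0.0, 0.0],
    [3.0, 0.0, 0.0, 4.0, 5.0],
]
print(count_nonzero_emulated(feature_acts, BLOCK_F=2))  # [2, 0, 3]
```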
This count_nonzero kernel produces an array counts of length B, where B is the number of tokens in the batch. The number of active (nonzero) features for the i-th token is stored in counts[i].
We can then use this information to allocate two flat arrays, flat_idx and flat_val. These hold the active feature indices and their values for the entire batch, with each token's entries stored one after another.
You may have noticed that it's not clear which entries belong to which token. For example, flat_idx[2] tells us that feature number flat_idx[2] fired. It doesn't tell us whether that feature fired for the first, second, third, or some other token in the batch.
We can solve this problem by introducing a new array row_offsets of length B + 1, where row_offsets[b] stores the index in flat_idx/flat_val at which token b's entries begin. It's computed by taking a cumulative sum of counts (with a leading 0), so each token's region starts exactly where the previous one ends. For example, if three tokens have 2, 5, and 3 active features, then row_offsets = [0, 2, 7, 10].
Token 0's entries live at indices 0–1, token 1's at 2–6, and token 2's at 7–9. The final entry (10) is the total number of nonzero features across all tokens in the batch.
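The layout can also be built on the CPU with a prefix sum. The reference below is plain Python with no Triton, and both the name build_csr_reference and the activation values are made up for illustration. With per-token counts of 2, 5, and 3, it reproduces row_offsets = [0, 2, 7, 10]:

```python
from itertools import accumulate

def build_csr_reference(feature_acts):
    """CPU reference for the Exact Allocation layout: (flat_idx, flat_val, row_offsets)."""
    counts = [sum(1 for v in row if v != 0.0) for row in feature_acts]
    # Cumulative sum with a leading 0: row_offsets[b] is where token b's entries begin
    row_offsets = [0] + list(accumulate(counts))
    flat_idx, flat_val = [], []
    for row in feature_acts:
        for j, v in enumerate(row):
            if v != 0.0:
                flat_idx.append(j)
                flat_val.append(v)
    return flat_idx, flat_val, row_offsets

# Three tokens with 2, 5, and 3 active features (n_features = 8, made-up values)
feature_acts = [
    [0.0, 0.5, 0.0, 0.0, 0.0, 0.0, 1.2, 0.0],
    [0.3, 0.0, 0.7, 0.1, 0.0, 0.9, 0.0, 0.4],
    [0.0, 0.0, 0.0, 2.0, 0.0, 0.0, 0.6, 0.8],
]
flat_idx, flat_val, row_offsets = build_csr_reference(feature_acts)
print(row_offsets)                              # [0, 2, 7, 10]
print(flat_idx[row_offsets[1]:row_offsets[2]])  # [0, 2, 3, 5, 7]
```

The slice flat_idx[row_offsets[b]:row_offsets[b + 1]] recovers token b's entries. A decoder can use the same slicing to find each token's features.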
We can construct row_offsets inside a wrapper function, build_csr, that also handles memory allocation and orchestration. It calls compute_csr_kernel, the kernel responsible for actually filling flat_idx and flat_val with the correct values. Note that flat_idx and flat_val are allocated with torch.empty: they are uninitialized, pre-allocated storage that compute_csr_kernel writes into.
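```python
import torch
import triton

def build_csr(feature_acts: torch.Tensor, BLOCK_F: int = 1024):
    feature_acts = feature_acts.contiguous()  # The kernels assume a row-major contiguous layout
    B, n_features = feature_acts.shape
    device = feature_acts.device

    # Count how many features fired per token
    counts = torch.zeros(B, dtype=torch.int32, device=device)
    grid = (B, triton.cdiv(n_features, BLOCK_F))
    count_nonzero[grid](feature_acts, counts, n_features, BLOCK_F=BLOCK_F)

    # Cumsum over counts gives each token a contiguous region in the flat arrays
    # row_offsets[b] = start index of token b's entries in flat_idx/flat_val
    row_offsets = torch.zeros(B + 1, dtype=torch.int32, device=device)
    row_offsets[1:] = counts.cumsum(0).to(torch.int32)

    # The last entry is the total number of nonzeros. This is used to size the flat arrays
    total_nnz = int(row_offsets[-1].item())  # GPU->CPU sync point
    flat_idx = torch.empty(total_nnz, dtype=torch.int32, device=device)
    flat_val = torch.empty(total_nnz, dtype=feature_acts.dtype, device=device)

    # write_pos is a per-token cursor that coordinates concurrent writes within
    # a token's region. Each block atomically claims the next available slots by
    # bumping write_pos by its count, getting back its starting offset (base).
    write_pos = torch.zeros(B, dtype=torch.int32, device=device)
    compute_csr_kernel[grid](
        feature_acts, row_offsets, write_pos, flat_idx, flat_val, n_features, BLOCK_F=BLOCK_F,
    )
    return flat_idx, flat_val, row_offsets, B
```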
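Inside compute_csr_kernel, each block loads its chunk of features the same way count_nonzero does. It then writes the active features into its token's region of flat_idx/flat_val:

```python
@triton.jit
def compute_csr_kernel(feature_acts_ptr, row_offsets_ptr, write_pos_ptr, flat_idx_ptr,
                       flat_val_ptr, n_features, BLOCK_F: tl.constexpr):
    pid_token = tl.program_id(0)
    pid_d = tl.program_id(1)
    # Load this block's chunk of features (same as count_nonzero)
    feat_offsets = pid_d * BLOCK_F + tl.arange(0, BLOCK_F)
    mask = feat_offsets < n_features
    vals = tl.load(feature_acts_ptr + pid_token * n_features + feat_offsets, mask=mask, other=0.0)
    fired = vals != 0.0
    fired_int = fired.to(tl.int32)

    # Where does this token's region start in flat_idx/flat_val?
    region_start = tl.load(row_offsets_ptr + pid_token)

    # Atomically claim the next block_count slots within this token's region
    block_count = tl.sum(fired_int)
    base = tl.atomic_add(write_pos_ptr + pid_token, block_count)

    # Assign each active feature a unique slot within the claimed range
    local_rank = tl.cumsum(fired_int) - fired_int
    slots = region_start + base + local_rank

    # Write the feature index and value into the claimed slots
    tl.store(flat_idx_ptr + slots, feat_offsets.to(tl.int32), mask=fired & mask)
    tl.store(flat_val_ptr + slots, vals, mask=fired & mask)
```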
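The slot-claiming logic can be checked with a sequential emulation. It is plain Python: compute_csr_emulated is a hypothetical helper, and a random block order stands in for the GPU scheduler. Each block reads the cursor and bumps it, which tl.atomic_add does as one indivisible step. As a result, no two blocks ever claim overlapping slots:

```python
import random

def compute_csr_emulated(feature_acts, row_offsets, BLOCK_F, seed=0):
    """Sequential stand-in for compute_csr_kernel, running its blocks in a random order."""
    B, n_features = len(feature_acts), len(feature_acts[0])
    flat_idx = [None] * row_offsets[-1]  # None marks a slot nobody has written yet
    flat_val = [None] * row_offsets[-1]
    write_pos = [0] * B
    n_chunks = -(-n_features // BLOCK_F)  # ceil division, like triton.cdiv
    blocks = [(t, d) for t in range(B) for d in range(n_chunks)]
    random.Random(seed).shuffle(blocks)   # the GPU makes no promise about block order
    for pid_token, pid_d in blocks:
        feats = range(pid_d * BLOCK_F, min((pid_d + 1) * BLOCK_F, n_features))
        fired = [j for j in feats if feature_acts[pid_token][j] != 0.0]
        # tl.atomic_add does these two steps indivisibly and returns the old value (base)
        base = write_pos[pid_token]
        write_pos[pid_token] += len(fired)
        # local_rank = exclusive cumsum of the fired flags = position within `fired`
        for local_rank, j in enumerate(fired):
            slot = row_offsets[pid_token] + base + local_rank
            flat_idx[slot] = j
            flat_val[slot] = feature_acts[pid_token][j]
    return flat_idx, flat_val

feature_acts = [
    [0.0, 0.5, 0.0, 0.0, 0.0, 0.0, 1.2, 0.0],
    [0.3, 0.0, 0.7, 0.1, 0.0, 0.9, 0.0, 0.4],
    [0.0, 0.0, 0.0, 2.0, 0.0, 0.0, 0.6, 0.8],
]
row_offsets = [0, 2, 7, 10]  # from the counting pass + cumulative sum
for seed in range(3):
    flat_idx, _ = compute_csr_emulated(feature_acts, row_offsets, BLOCK_F=3, seed=seed)
    print(flat_idx)
```

Within a token, the order of entries can differ between runs, because whichever block claims slots first writes first. The set of (index, value) pairs per token is always the same, though. Since the decoder only sums over them, ordering affects the result only through floating-point summation order.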
Next, sparse_decode_kernel uses this CSR structure to carry out the matrix multiplication step. For each token, it looks up where that token's active features live in flat_idx/flat_val using row_offsets, then loops over them, accumulating the weighted sum of the corresponding decoder rows into a tile of the output.
```python
@triton.jit
def sparse_decode_kernel(
    flat_idx_ptr, flat_val_ptr, row_offsets_ptr, W_dec_ptr, out_ptr,
    d_model,
    BLOCK_D: tl.constexpr,
):
    pid_token = tl.program_id(0)
    pid_d = tl.program_id(1)

    # Find the slice of flat_idx/flat_val belonging to this token
    start = tl.load(row_offsets_ptr + pid_token)
    end = tl.load(row_offsets_ptr + pid_token + 1)
    n = end - start  # Number of active features for this token

    # This block owns a BLOCK_D-wide slice of the output row
    offsets = pid_d * BLOCK_D + tl.arange(0, BLOCK_D)
    mask = offsets < d_model
    acc = tl.zeros([BLOCK_D], dtype=tl.float32)

    # Loop over this token's active features, accumulating their contribution
    for i in range(n):
        j = start + i
        feat_idx = tl.load(flat_idx_ptr + j)  # Which decoder row?
        feat_val = tl.load(flat_val_ptr + j)  # Scale factor
        # Load the corresponding decoder row (just this block's slice)
        row_ptrs = W_dec_ptr + (feat_idx * d_model) + offsets
        row = tl.load(row_ptrs, mask=mask, other=0.0)
        acc += feat_val.to(tl.float32) * row.to(tl.float32)

    # Write this block's output slice
    tl.store(out_ptr + pid_token * d_model + offsets, acc, mask=mask)
```
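The emulation below shows how the launch grid splits the work. It is plain Python with a hypothetical helper name, not the kernel itself. Each (token, tile) pair owns BLOCK_D output columns, and every tile of the same token walks the same slice flat_idx[start:end]:

```python
def sparse_decode_emulated(flat_idx, flat_val, row_offsets, W_dec, BLOCK_D):
    """Sequential stand-in for sparse_decode_kernel on grid (B, cdiv(d_model, BLOCK_D))."""
    B, d_model = len(row_offsets) - 1, len(W_dec[0])
    out = [[0.0] * d_model for _ in range(B)]
    for pid_token in range(B):
        start, end = row_offsets[pid_token], row_offsets[pid_token + 1]
        for pid_d in range(-(-d_model // BLOCK_D)):  # ceil division, like triton.cdiv
            # This "block" owns columns [pid_d*BLOCK_D, (pid_d+1)*BLOCK_D), masked at d_model
            cols = range(pid_d * BLOCK_D, min((pid_d + 1) * BLOCK_D, d_model))
            acc = [0.0] * len(cols)
            for j in range(start, end):  # every tile walks the same slice of the flat arrays
                feat_idx, feat_val = flat_idx[j], flat_val[j]
                for k, c in enumerate(cols):
                    acc[k] += feat_val * W_dec[feat_idx][c]
            for k, c in enumerate(cols):
                out[pid_token][c] = acc[k]
    return out

# CSR for three tokens (made-up values), decoder with 8 rows and d_model = 5
flat_idx = [1, 6, 0, 2, 3, 5, 7, 3, 6, 7]
flat_val = [1.0, 2.0, 1.0, 3.0, 1.0, 2.0, 1.0, 2.0, 1.0, 1.0]
row_offsets = [0, 2, 7, 10]
W_dec = [[float(r + c) for c in range(5)] for r in range(8)]
out = sparse_decode_emulated(flat_idx, flat_val, row_offsets, W_dec, BLOCK_D=2)
print(out[0])  # 1.0 * W_dec[1] + 2.0 * W_dec[6] = [13.0, 16.0, 19.0, 22.0, 25.0]
```

Each tile re-reads the token's few flat_idx/flat_val entries but only its own slice of each decoder row. As a result, all d_model tiles of a token can run in parallel without any atomics on the output.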
Finally, we put all of these kernels together by wrapping them in a single sparse_decode() function that acts as a drop-in replacement for feature_acts @ W_dec (except that it always returns float32):
```python
import torch
import triton

def _sparse_decode(flat_idx, flat_val, row_offsets, W_dec, B, BLOCK_D=256):
    d_model = W_dec.shape[1]
    out = torch.zeros((B, d_model), device=W_dec.device, dtype=torch.float32)
    # Parallelize over batch rows and d_model tiles
    grid = (B, triton.cdiv(d_model, BLOCK_D))
    sparse_decode_kernel[grid](
        flat_idx, flat_val, row_offsets, W_dec, out, d_model, BLOCK_D=BLOCK_D
    )
    return out

def sparse_decode(feature_acts, W_dec):
    # The kernels compute addresses as row * d_model + col, so W_dec must be row-major contiguous
    W_dec = W_dec.contiguous()
    flat_idx, flat_val, row_offsets, B = build_csr(feature_acts)
    return _sparse_decode(flat_idx, flat_val, row_offsets, W_dec, B)
```
Fixed Allocation Kernel
Recall that in the Exact Allocation kernel, build_csr gets the total number of nonzero entries across all tokens by reading the last entry of row_offsets:

```python
total_nnz = int(row_offsets[-1].item())
```
When we call .item(), we are forcing the CPU to wait for the GPU to finish the counting pass before it can read total_nnz and allocate flat_idx/flat_val.
Normally the CPU queues up GPU work asynchronously and moves on without waiting, but .item() breaks that pipeline by requiring the CPU to stall until the GPU result is ready.
This turns out to be a significant source of slowdown.
The Fixed Allocation kernel avoids this by not allocating exactly the memory needed in the first place, so it never needs total_nnz. Instead, we allocate max_l0 slots per token, where max_l0 is a user-specified upper bound on how many features can fire for any single token. This also means we no longer need a separate pass to count the active features per token before computing the CSR structure.
With these changes, the new build_csr wrapper function looks like:
```python
import torch
import triton

def build_csr(feature_acts: torch.Tensor, BLOCK_F: int = 1024, max_l0: int = 512, validate: bool = True):
    feature_acts = feature_acts.contiguous()  # The kernels assume a row-major contiguous layout
    B, n_features = feature_acts.shape
    device = feature_acts.device

    # Fixed memory allocation: max_l0 slots per token, no counting pass needed
    capacity = B * max_l0
    flat_idx = torch.empty(capacity, dtype=torch.int32, device=device)
    flat_val = torch.empty(capacity, dtype=feature_acts.dtype, device=device)

    # write_pos serves as both the per-token write cursor during the kernel
    # and the per-token count afterward
    write_pos = torch.zeros(B, dtype=torch.int32, device=device)

    grid = (B, triton.cdiv(n_features, BLOCK_F))
    compute_csr_kernel[grid](
        feature_acts, write_pos, flat_idx, flat_val,
        n_features, max_l0,
        BLOCK_F=BLOCK_F,
    )
    # Final cursor value = number of features that fired per token (can exceed max_l0)
    counts = write_pos

    # Optional safety check. This reintroduces a GPU→CPU sync but catches silent truncation
    if validate:
        max_count = int(counts.max().item())
        if max_count > max_l0:
            raise ValueError(
                f"A token fired more than max_l0={max_l0} features "
                f"(max was {max_count}). Increase max_l0."
            )
    return flat_idx, flat_val, counts, B, max_l0
```
As mentioned briefly earlier, if a token fires more features than max_l0, those extra features are silently dropped by the overflow guard in the kernel. This can be dangerous because the result is wrong but there's no crash. The validate=True default catches this by checking counts.max() after the kernel, at the cost of reintroducing a GPU→CPU sync. (However this is still faster than Exact Allocation in practice.) If you're very confident that your max_l0 is a safe upper bound for your SAE then you can pass validate=False to skip the check, but this is not recommended.
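The failure mode is easier to see in a sequential Python sketch of the Fixed Allocation layout. The helper name is hypothetical, and Python lists stand in for the GPU buffers:

```python
def build_csr_fixed_emulated(feature_acts, max_l0, validate=True):
    """Sketch of the Fixed Allocation layout: token b owns slots [b*max_l0, (b+1)*max_l0)."""
    B = len(feature_acts)
    flat_idx = [None] * (B * max_l0)  # stands in for torch.empty: unused slots hold garbage
    flat_val = [None] * (B * max_l0)
    counts = [0] * B                  # plays the role of write_pos
    for b, row in enumerate(feature_acts):
        for j, v in enumerate(row):
            if v != 0.0:
                slot = counts[b]
                counts[b] += 1        # the cursor keeps counting past max_l0 ...
                if slot < max_l0:     # ... but the overflow guard skips the write
                    flat_idx[b * max_l0 + slot] = j
                    flat_val[b * max_l0 + slot] = v
    if validate and max(counts) > max_l0:
        raise ValueError(f"A token fired more than max_l0={max_l0} features (max was {max(counts)}).")
    return flat_idx, flat_val, counts

acts = [[0.0, 1.0, 0.0, 2.0], [3.0, 4.0, 5.0, 0.0]]
try:
    build_csr_fixed_emulated(acts, max_l0=2)
except ValueError as err:
    print(err)  # A token fired more than max_l0=2 features (max was 3).

flat_idx, flat_val, counts = build_csr_fixed_emulated(acts, max_l0=2, validate=False)
print(counts)    # [2, 3]
print(flat_idx)  # [1, 3, 0, 1]  (feature 2 of token 1 was silently dropped)
```

Note that counts[1] is 3 even though only 2 slots were written. Anything that reads counts therefore has to clamp it to max_l0 to stay inside the token's region. On the GPU, which features survive an overflow also depends on the order in which blocks are scheduled.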
The kernel to compute CSR changes minimally. We no longer need row_offsets since we know that each token takes up max_l0 entries in memory, so the lookup for the start of a token's region is replaced by region_start = pid_token * max_l0.
```python
@triton.jit
def compute_csr_kernel(
    feature_acts_ptr, write_pos_ptr, flat_idx_ptr, flat_val_ptr,
    n_features, max_l0,
    BLOCK_F: tl.constexpr,
):
    pid_token = tl.program_id(0)
    pid_d = tl.program_id(1)

    # Navigate to this block's chunk
    feat_offsets = pid_d * BLOCK_F + tl.arange(0, BLOCK_F)
    mask = feat_offsets < n_features
    feat_ptrs = feature_acts_ptr + pid_token * n_features + feat_offsets
    vals = tl.load(feat_ptrs, mask=mask, other=0.0)
    fired = vals != 0.0
    fired_int = fired.to(tl.int32)

    # Each token owns a fixed region of max_l0 slots
    region_start = pid_token * max_l0

    # Atomically claim the next available slots within this token's region
    block_count = tl.sum(fired_int)
    base = tl.atomic_add(write_pos_ptr + pid_token, block_count)

    # Assign each active feature a unique slot within the claimed range
    local_rank = tl.cumsum(fired_int) - fired_int
    local_slot = base + local_rank

    # Guard against writing past this token's region if L0 exceeds max_l0
    in_region = local_slot < max_l0
    write_mask = fired & mask & in_region
    slots = region_start + local_slot
    tl.store(flat_idx_ptr + slots, feat_offsets.to(tl.int32), mask=write_mask)
    tl.store(flat_val_ptr + slots, vals, mask=write_mask)
```
The decoder kernel changes in the same way: row_offsets is no longer needed, and the per-token count from counts replaces the start/end pair:
```python
@triton.jit
def sparse_decode_kernel(
    flat_idx_ptr, flat_val_ptr, counts_ptr, W_dec_ptr, out_ptr,
    d_model, max_l0,
    BLOCK_D: tl.constexpr,
):
    pid_token = tl.program_id(0)
    pid_d = tl.program_id(1)

    start = pid_token * max_l0
    n = tl.load(counts_ptr + pid_token)  # Number of features that fired for this token
    # With validate=False a token may have overflowed (count > max_l0). Clamp so the loop
    # never reads into the next token's region or past the end of the buffers
    n = tl.minimum(n, max_l0)

    offsets = pid_d * BLOCK_D + tl.arange(0, BLOCK_D)
    mask = offsets < d_model
    acc = tl.zeros([BLOCK_D], dtype=tl.float32)

    # Same loop as before
    for i in range(n):
        j = start + i
        feat_idx = tl.load(flat_idx_ptr + j)
        feat_val = tl.load(flat_val_ptr + j)
        row_ptrs = W_dec_ptr + feat_idx * d_model + offsets
        row = tl.load(row_ptrs, mask=mask, other=0.0)
        acc += feat_val.to(tl.float32) * row.to(tl.float32)

    tl.store(out_ptr + pid_token * d_model + offsets, acc, mask=mask)
```
Benchmarks
Writing custom GPU kernels is great, but it's important to make sure that they actually make the computation faster. I used triton.testing.do_bench (warmup=25, rep=100, reporting the median) to time these kernels and compared them against dense matrix multiplication (feature_acts @ W_dec). All tests were run on an NVIDIA GeForce RTX 4090 GPU.
As a quick summary, the table below shows the relative speedups for an example input configuration (B = 32, n_features = 65536, d_model = 768, L0 = 64):
| Method | Full matmul pipeline (ms) | Speedup vs dense |
|---|---|---|
| Dense cuBLAS | 0.288 | 1.0× |
| torch.compile | 0.288 | 1.0× |
| torch.sparse.mm + .to_sparse_csr() | 0.210 | 1.4× |
| Custom: exact allocation | 0.151 | 1.9× |
| Custom: fixed allocation (validate=False) | 0.041 | 7.0× |
| Custom: fixed allocation (validate=True) | 0.115 | 2.5× |
Correctness
First, I verified that the custom kernels actually compute the matrix multiplication correctly (a custom kernel that is faster but gives the wrong answer doesn't help anyone). Concretely, I checked that sparse_decode(feature_acts, W_dec) matches feature_acts @ W_dec within floating-point tolerance across 486 different inputs, built from combinations of the parameters below. Here sparse_decode() is just the wrapper function that uses the custom Triton kernels under the hood.
| Axis | Meaning | Values tested | Count |
|---|---|---|---|
| version | kernel implementation | exact, fixed | 2 |
| dtype | input dtype of feature_acts / W_dec | float32, float16, bfloat16 | 3 |
| B | batch size (tokens) | 1, 4, 32 | 3 |
| n_features | SAE dictionary width | 256, 1024, 16384 | 3 |
| d_model | output width | 128, 512, 768 | 3 |
| L0 | features fired per token | 1, 8, 100 | 3 |
Total: 2 × 3 × 3 × 3 × 3 × 3 = 486 configurations. Each asserts output is fp32 and matches the dense fp32 reference within atol=1e-4, rtol=1e-3.
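The sweep is the Cartesian product of the six axes. The sketch below shows one way to enumerate it. The author's harness isn't shown, so the variable names are hypothetical, and dtypes are plain strings so the snippet runs without PyTorch. Each configuration would then be compared with something like torch.testing.assert_close(out, ref, atol=1e-4, rtol=1e-3):

```python
from itertools import product

# Values from the table above
versions = ["exact", "fixed"]
dtypes = ["float32", "float16", "bfloat16"]
batch_sizes = [1, 4, 32]
n_features_values = [256, 1024, 16384]
d_model_values = [128, 512, 768]
l0_values = [1, 8, 100]

configs = list(product(versions, dtypes, batch_sizes, n_features_values, d_model_values, l0_values))
print(len(configs))  # 486
```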
Decoder Kernel Speed (CSR Excluded)
The preprocessing step of computing a CSR representation adds some computational overhead. It would be interesting to see a direct comparison between sparse_decode_kernel and dense matrix multiplication if you didn't have to pay for that overhead (assume that you somehow already have access to a CSR representation).
If you hold some parameters of the input constant (B=32, n_features=65536, d_model=768) while varying L0 (the number of fired features) as shown in the table below, then how much faster is sparse_decode_kernel?
Note that this is EXCLUDING the overhead of the CSR preprocessing step (i.e., compute_csr_kernel). Also note that sparse_decode_kernel is essentially the same between Exact Allocation and Fixed Allocation so there is no need to differentiate, but for completeness the graph below plots both (they overlap).
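| L0 | Fraction of features active | Kernel speedup vs dense |
|---|---|---|
| 16 | 0.02% | 25.5× |
| 32 | 0.05% | 18.7× |
| 64 | 0.10% | 12.8× |
| 128 | 0.20% | 8.0× |
| 256 | 0.39% | 5.0× |
| 512 | 0.78% | 3.0× |
| 1024 | 1.56% | 1.7× |
| 4096 | 6.25% | 0.6× |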
We can also vary n_features while holding B=32, L0=64, and d_model=768 constant:
| n_features | Kernel speedup vs dense |
|---|---|
| 4,096 | 1.5× |
| 16,384 | 4.1× |
| 32,768 | 7.3× |
| 65,536 | 12.8× |
| 131,072 | 22.5× |
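A rough way to see why the speedup grows with n_features and shrinks with L0 is to count bytes. The model below is an assumption, a back-of-the-envelope estimate rather than a measurement. A dense GEMM has to stream all of W_dec, while the sparse kernel touches about B × L0 decoder rows:

```python
def decoder_bytes_estimate(B, n_features, d_model, l0, bytes_per_elem=4):
    """Rough bytes read by each approach (ignores caching, output writes and the CSR build)."""
    # Dense GEMM: stream the whole decoder matrix plus the dense activations
    dense = (n_features * d_model + B * n_features) * bytes_per_elem
    # Sparse decode: one d_model-wide decoder row per active feature per token,
    # plus the int32 indices and the values in flat_idx/flat_val
    sparse = B * l0 * d_model * bytes_per_elem + B * l0 * (4 + bytes_per_elem)
    return dense, sparse

for l0 in [64, 1024, 4096]:
    dense, sparse = decoder_bytes_estimate(B=32, n_features=65536, d_model=768, l0=l0)
    print(f"L0={l0}: dense ~{dense / 1e6:.1f} MB, sparse ~{sparse / 1e6:.1f} MB")
```

For B=32, n_features=65536, d_model=768, this gives about 210 MB for dense versus 6.3 MB at L0=64. The sparse estimate overtakes the dense one between L0=1024 and L0=4096, consistent with the crossover in the L0 table. The measured speedups are well below these byte ratios, so this explains only the trend, not the magnitude.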
Full Pipeline Speed
So clearly sparse_decode_kernel alone is faster than dense matrix multiplication at high sparsity. But of course in practice we probably need to compute CSR as well, which will slow things down somewhat.
The table below shows speedups relative to dense matmul for three input configurations (F = n_features, D = d_model). "Kernel only" refers to sparse_decode_kernel alone (CSR precomputed). "Full" refers to the whole pipeline: build_csr followed by sparse_decode_kernel.
| Configuration | Kernel only | Full (exact) | Full (fixed, no val.) | Full (fixed, val.) |
|---|---|---|---|---|
| B=32, F=65536, D=768, L0=64 | 12.8× | 1.9× | 7.0× | 2.5× |
| B=256, F=65536, D=768, L0=64 | 7.7× | 1.7× | 3.1× | 2.2× |
| B=32, F=131072, D=512, L0=128 | 22.5× | 2.2× | 6.1× | 2.3× |
The graph below shows the speed of the full pipeline (Exact Allocation) and decode-only as you vary sparsity. Here, L0 sweeps over [16, 32, 64, 128, 256, 512, 1024, 4096, 16384] while holding B=32, n_features=65536, and d_model=768 constant.
Additional Baselines
To be comprehensive, we can also compare our custom kernels to torch.sparse.mm (using PyTorch's to_sparse_csr()), which uses cuSPARSE internally, and torch.compile. This focuses on the same three input configurations as above.
Note: I found it a little suspicious that this custom kernel would "beat" torch.sparse.mm. It turns out the advantage comes mostly from building the CSR faster than to_sparse_csr() does. On the matrix multiplication step alone, there doesn't seem to be much difference in speed between the custom kernel and cuSPARSE.
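As expected, torch.compile doesn't provide a noticeable speedup. The dense path is presumably a single matmul already dispatched to cuBLAS, which leaves the compiler little to fuse. I included it anyway for completeness.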
End-to-End on Real SAEs
Up until now we have focused entirely on the speed of the matrix multiplication itself, but at the end of the day we care about SAE inference as a whole. To measure this, I replaced only the decoder matmul step in a SAELens JumpReLU SAE forward pass. The table below covers five SAEs across two model families and three dictionary sizes (F = n_features, D = d_model; the last three columns are speedups over dense).
| SAE | F | D | L0 | Max diff | Exact | Fixed (val.) | Fixed (no val.) |
|---|---|---|---|---|---|---|---|
| Gemma Scope 2B, L20, 65k | 65,536 | 2,304 | 72 | 3.8e-6 | 4.27× | 5.57× | 11.41× |
| Gemma Scope 9B, L20, 65k | 65,536 | 3,584 | 72 | 3.8e-6 | 5.66× | 7.34× | 13.27× |
| Gemma Scope 2B, L12, 65k | 65,536 | 2,304 | 72 | 9.5e-7 | 3.91× | 5.48× | 11.33× |
| Gemma Scope 2B, L12, 262k | 262,144 | 2,304 | 100 | 1.9e-6 | 12.08× | 14.49× | 22.59× |
| Qwen Scope 3.5 2B, L12 | 32,768 | 2,048 | 100 | 4.8e-7 | 1.98× | 2.54× | 5.74× |
Memory Overhead
The Fixed Allocation kernel trades extra memory for speed, so it's worth measuring exactly how much more memory it uses than the Exact Allocation kernel. Perhaps surprisingly, the overhead turns out to be small in practice:
| B | max_l0 | Dense (MB) | Exact (MB) | Fixed (MB) | Overhead vs exact |
|---|---|---|---|---|---|
| 32 | 512 | 218.3 | 218.4 | 218.5 | +0.1 MB |
| 256 | 512 | 277.7 | 277.9 | 278.8 | +0.9 MB |
| 1024 | 512 | 482.3 | 482.9 | 485.6 | +2.7 MB |
| 1024 | 1024 | 482.3 | 482.9 | 490.7 | +7.8 MB |
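A quick estimate suggests why: the Fixed Allocation buffers scale with B × max_l0, not with the dictionary size. The sketch below assumes int32 indices and fp32 values. It ignores allocator rounding and everything else counted in the measurements above, so it won't reproduce the table exactly:

```python
MiB = 2 ** 20

def fixed_alloc_bytes(B, max_l0, value_bytes=4):
    """flat_idx (int32) + flat_val buffers for the Fixed Allocation layout."""
    return B * max_l0 * (4 + value_bytes)

def w_dec_bytes(n_features, d_model, value_bytes=4):
    """Decoder weight matrix, for scale."""
    return n_features * d_model * value_bytes

print(fixed_alloc_bytes(1024, 1024) / MiB)  # 8.0
print(w_dec_bytes(65536, 768) / MiB)        # 192.0
```

Even at B=1024 and max_l0=1024, the flat buffers come to about 8 MiB. A 65,536 × 768 fp32 W_dec alone is 192 MiB (about 201 MB), which keeps the relative overhead small.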
Limitations
While these results are encouraging, there are a few important limitations to be aware of and gaps that I plan to address as I continue working on this project.
First, the above benchmark numbers are not absolute, as these tests were run in a specific environment (WSL2 with GPU clocks not pinned). The primary goal of these benchmarks was to gauge the relative performance of the custom kernels compared to baseline implementations. The actual absolute speed likely differs depending on the hardware and benchmarking setup.
A second limitation, which was discussed earlier but is worth reiterating, is that although the Fixed Allocation kernel with validate=False achieves the highest performance, it can silently produce incorrect results if the max_l0 parameter is set too low. For this reason using either the Exact Allocation kernel or Fixed Allocation with validate=True is likely better for most cases.
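Third, these kernels are designed for highly sparse inputs, so once enough features are active, dense matrix multiplication is faster. In the kernel-only benchmark above, sparse_decode_kernel already falls behind dense at L0 = 4096 (6.25% of features active), where it runs at 0.6× the speed of dense.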
Fourth, this implementation focuses exclusively on the decoder inference step of JumpReLU Sparse Autoencoders, but there are likely other sources of inefficiency that could be addressed. For example, future projects could target the encoder pass or support training through custom backward kernels. Additionally, the current implementation only supports float32 outputs.
Finally, all experiments were run on an RTX 4090, and performance may differ on other GPU architectures such as the A100 or H100.
Conclusion + Link to Code
In conclusion, this project implements custom Triton kernels for the decoder inference step of JumpReLU SAEs by exploiting the inherent sparsity of the hidden representation. On a sample of real SAEs, this achieves 2.5–14× speedup with the Fixed Allocation (validate=True) kernel, with larger gains at higher dictionary sizes.
I welcome feedback! If you have thoughts, questions, or find any issues, feel free to leave a comment or reach out directly. This is also my first GPU kernel project, so if you're experienced with Triton or GPU kernel optimization and see things I could have done better, I would appreciate any suggestions.
Motivation
Sparse Autoencoders (SAEs) have become a central tool in mechanistic interpretability research, providing a way to decompose a model's internal activations into sparse, interpretable features. However, extracting these features often requires running the SAE over large volumes of activations across many layers and tokens. This makes SAE inference efficiency a practical bottleneck for interpretability research at scale.
This post focuses on improving the inference efficiency of JumpReLU Sparse Autoencoders, which were introduced by DeepMind in Jumping Ahead: Improving Reconstruction Fidelity with JumpReLU Sparse Autoencoders (Rajamanoharan et al). Instead of using a traditional ReLU activation function, these SAEs use JumpReLU, which zeros out activations that fall below a learned per-feature threshold . This gives JumpReLU SAEs a variable number of active features per token (commonly written as , the count of nonzero activations), unlike TopK SAEs which fire exactly features per token.
I use the terms "fire" and "fired" to describe features with non-zero activations.
Traditional JumpReLU SAE implementations compute the decoder step as a dense matrix multiplication (
feature_acts @ W_dec), but this is wasteful because of the sparsity offeature_acts. Instead, you can exploit this sparsity property and skip the zero entries entirely during matrix multiplication with a custom Triton kernel.Intuition: Sparsity Should Be Free
When a single token passes through a JumpReLU SAE with 65,536 features, the encoder produces a feature activation vector of length 65,536, but only some entries are nonzero.
To be more concrete, consider a toy SAE with feature activations . Now suppose that we only have 2 active features, where represents the weight matrix of the decoder layer:
We then compute the output with:
Notice how only two of the rows of the decoder matrix were actually used in the computation. The rest were multiplied by 0 and contributed nothing to the output. We could instead just compute:
Now imagine this same example but increase the hidden dimension from 8 to something much greater. For instance with 72 active features. That would mean you're multiplying ~99.89% of rows by zero.
If we knew in advance which features are nonzero and their corresponding values, we could skip these zero multiplications and simplify the computation.
For a single token, this can be divided into two parts:
Implementation Overview
When implementing this kernel, my first thought was to begin with a preliminary step that figures out exactly how many features fired for each token so the system could then allocate exactly the memory needed to store the CSR representation (more on that later). However, this process involves a GPU->CPU sync, which causes some slowdown.
As an alternative, you can instead allocate some predetermined/fixed amount of memory for each token using a configurable
max_l0parameter. This speeds up computation but overallocates memory and introduces an important caveat thatmax_l0must be large enough to avoid errors. For example if you setmax_l0=10, but one of the tokens in the batch has >10 nonzero features, those extra features will be dropped, resulting in information loss.Both approaches are covered below. For convenience, let's refer to the kernel that allocates exactly the memory needed for the CSR representation as the Exact Allocation kernel and the kernel that allocates a predetermined amount of memory per token as the Fixed Allocation kernel. The Fixed Allocation kernel can also be configured with either
validate=Trueorvalidate=False. Thevalidate=Trueversion is slightly slower thanvalidate=False, but it raises an error if any token fires more features thanmax_l0. This is clearly safer, but if you are 100% sure that no token will exceedmax_l0, then you can usevalidate=Falsefor some speedup.Exact Allocation Kernel
To skip zero entries during matrix multiplication, we need to first represent the feature activations in Compressed Sparse Row (CSR) format, which is a standard way of representing sparse matrices that stores only the nonzero values and their indices. For the example above, instead of storing all 8 entries of , CSR stores just:
To allocate enough memory for building a CSR representation, we need to know how much memory each token requires (how many features fired per token). A
count_nonzerokernel handles this:import triton
import triton.language as tl
@triton.jit
def count_nonzero(feature_acts_ptr, counts_ptr, n_features, BLOCK_F: tl.constexpr):
pid_token = tl.program_id(0) # Which token am I working on? (row index)
pid_d = tl.program_id(1) # Which chunk of features am I working on? (column index)
# Compute the feature indices this block is responsible for
feat_offsets = pid_d * BLOCK_F + tl.arange(0, BLOCK_F)
mask = feat_offsets < n_features # Guard against reading past the end of the feature dimension
# Navigate to this token's features in memory, then to this block's chunk
feat_ptrs = feature_acts_ptr + pid_token * n_features + feat_offsets
vals = tl.load(feat_ptrs, mask=mask, other=0.0) # Load the feature values
fired = vals != 0.0 # Which features in this chunk are active (nonzero)?
fired_count = tl.sum(fired.to(tl.int32)) # How many active (nonzero) features in this chunk?
# Accumulate into this token's count (atomic since multiple blocks write to the same token)
tl.atomic_add(counts_ptr + pid_token, fired_count)
If you're unfamiliar with Triton, the key mental model is this: rather than writing a loop that runs sequentially, you write a kernel that describes what one block does, and Triton launches many of these blocks in parallel across the GPU. In this kernel, each block is responsible for a chunk of one token's features. The two program_id calls tell each block where it is. pid_token identifies which token (which row of the input matrix), and pid_d identifies which chunk of that token's features to process.

Also note that pointers in GPU kernels point to the start of a flat block of memory. To reach a particular token's features, we offset into that memory by pid_token * n_features. Within that token, we offset further by pid_d * BLOCK_F to reach the right chunk. The mask guards against reading past the end when n_features isn't a clean multiple of BLOCK_F.

Finally, since multiple blocks may be counting features for the same token simultaneously, tl.atomic_add ensures their partial counts are combined safely.

This count_nonzero kernel produces an array counts of length B, where B is the number of tokens in the batch. The number of active (nonzero) features for token i is stored in counts[i].

We can then use this information to allocate two flat arrays, flat_idx and flat_val. These hold the active feature indices and their values for every token in the batch, laid end to end. On their own, however, the flat arrays don't say which entries belong to which token. For example, flat_idx[2] tells us which feature fired, but not whether it fired for the first token in the batch, the second, the third, and so on.

We can solve this problem by introducing a new array row_offsets of length B + 1, where row_offsets[b] stores the starting index in flat_idx/flat_val where token b's entries begin. It's computed by taking a cumulative sum of counts, so each token's region starts exactly where the previous one ends. For example, if three tokens have 2, 5, and 3 active features, then row_offsets = [0, 2, 7, 10]. Token 0's entries live at indices 0–1, token 1's at 2–6, and token 2's at 7–9. The final entry (10) tells us the total number of nonzero features across all tokens in the batch.
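To make the layout concrete, here is a minimal pure-Python sketch of the same three steps (count, cumulative sum, fill) on a toy batch whose tokens fire 2, 5 and 3 features. It runs on the CPU with plain lists and only illustrates the data structure. The helper build_csr_reference and the activation values are made up for this example and are not part of the Triton implementation.

```python
from itertools import accumulate


def build_csr_reference(feature_acts):
    """Pure-Python sketch of the CSR layout: counts -> row_offsets -> flat arrays."""
    # Step 1 (like count_nonzero): how many features fired for each token
    counts = [sum(1 for v in row if v != 0.0) for row in feature_acts]
    # Step 2 (cumsum): token b owns flat slots row_offsets[b] .. row_offsets[b + 1] - 1
    row_offsets = [0] + list(accumulate(counts))
    # Step 3 (fill): write each token's fired feature indices and values, token by token
    flat_idx, flat_val = [], []
    for row in feature_acts:
        for j, v in enumerate(row):
            if v != 0.0:
                flat_idx.append(j)
                flat_val.append(v)
    return counts, row_offsets, flat_idx, flat_val


# Toy batch of 3 tokens x 12 features that fire 2, 5 and 3 features respectively
acts = [
    [0.0, 1.5, 0.0, 0.0, 0.7, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
    [0.3, 0.0, 2.0, 0.0, 0.0, 0.9, 0.0, 1.1, 0.0, 0.0, 0.4, 0.0],
    [0.0, 0.0, 0.0, 0.6, 0.0, 0.0, 0.0, 0.0, 1.2, 0.0, 0.0, 0.5],
]
counts, row_offsets, flat_idx, flat_val = build_csr_reference(acts)
print(counts)       # [2, 5, 3]
print(row_offsets)  # [0, 2, 7, 10]
# Token 1's entries are the slice row_offsets[1]:row_offsets[2] of the flat arrays
print(flat_idx[row_offsets[1]:row_offsets[2]])  # [0, 2, 5, 7, 10]
```

This sequential sketch always produces each token's entries sorted by feature index. The GPU version gives no such guarantee, because blocks claim their slots in whatever order they happen to run. The decode result is the same either way, up to floating-point summation order.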
We can construct row_offsets inside a wrapper function build_csr that also handles memory and orchestration. It calls compute_csr_kernel, which is the kernel responsible for actually filling flat_idx and flat_val with the correct values. Note that flat_idx and flat_val are created with torch.empty, i.e. as pre-allocated (uninitialized) storage that compute_csr_kernel will write into.

def build_csr(feature_acts: torch.Tensor, BLOCK_F: int = 1024):
    # The kernels index token b's row at b * n_features, so the input must be contiguous
    feature_acts = feature_acts.contiguous()
    B, n_features = feature_acts.shape
    device = feature_acts.device
    # Count how many features fired per token
    counts = torch.zeros(B, dtype=torch.int32, device=device)
    grid = (B, triton.cdiv(n_features, BLOCK_F))
    count_nonzero[grid](feature_acts, counts, n_features, BLOCK_F=BLOCK_F)
    # Cumsum over counts gives each token a contiguous region in the flat arrays
    # row_offsets[b] = start index of token b's entries in flat_idx/flat_val
    row_offsets = torch.zeros(B + 1, dtype=torch.int32, device=device)
    row_offsets[1:] = counts.cumsum(0).to(torch.int32)
    # The last entry is the total number of nonzeros. This is used to size the flat arrays
    total_nnz = int(row_offsets[-1].item())  # GPU->CPU sync point
    flat_idx = torch.empty(total_nnz, dtype=torch.int32, device=device)
    flat_val = torch.empty(total_nnz, dtype=feature_acts.dtype, device=device)
    # write_pos is a per-token cursor that coordinates concurrent writes within
    # a token's region. Each block atomically claims the next available slots by
    # bumping write_pos by its count, getting back its starting offset (base).
    write_pos = torch.zeros(B, dtype=torch.int32, device=device)
    compute_csr_kernel[grid](
        feature_acts,
        row_offsets,
        write_pos,
        flat_idx,
        flat_val,
        n_features,
        BLOCK_F=BLOCK_F,
    )
    return flat_idx, flat_val, row_offsets, B
@triton.jit
def compute_csr_kernel(
    feature_acts_ptr,
    row_offsets_ptr,
    write_pos_ptr,
    flat_idx_ptr,
    flat_val_ptr,
    n_features,
    BLOCK_F: tl.constexpr,
):
    pid_token = tl.program_id(0)
    pid_d = tl.program_id(1)
    # Same pointer arithmetic as count_nonzero, navigate to this block's chunk
    feat_offsets = pid_d * BLOCK_F + tl.arange(0, BLOCK_F)
    mask = feat_offsets < n_features
    feat_ptrs = feature_acts_ptr + pid_token * n_features + feat_offsets
    vals = tl.load(feat_ptrs, mask=mask, other=0.0)
    fired = vals != 0.0
    fired_int = fired.to(tl.int32)
    # Where does this token's region start in flat_idx/flat_val?
    region_start = tl.load(row_offsets_ptr + pid_token)
    # Atomically claim the next block_count slots within this token's region
    block_count = tl.sum(fired_int)
    base = tl.atomic_add(write_pos_ptr + pid_token, block_count)
    # Assign each active feature a unique slot within the claimed range
    # (exclusive prefix sum = number of fired lanes before this one in the chunk)
    local_rank = tl.cumsum(fired_int) - fired_int
    slots = region_start + base + local_rank
    # Write the feature index and value into the claimed slots
    tl.store(flat_idx_ptr + slots, feat_offsets.to(tl.int32), mask=fired & mask)
    tl.store(flat_val_ptr + slots, vals, mask=fired & mask)
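The slot-claiming logic is the least obvious part of compute_csr_kernel, so here is a sequential pure-Python simulation of it. The simulation runs each "program" as a whole. On the GPU only the atomic add itself is atomic, but that is all the scheme needs. Each (token, chunk) program bumps a per-token cursor and gets back the old value as base, standing in for tl.atomic_add. It then ranks its fired lanes with an exclusive prefix sum, standing in for tl.cumsum(fired_int) - fired_int. Programs run in a shuffled order to mimic arbitrary GPU scheduling. The function name, chunk size and seed are illustrative assumptions, and this is not the Triton code.

```python
import random
from itertools import accumulate


def fill_csr_simulated(feature_acts, row_offsets, block_f, seed=0):
    """Sequential simulation of compute_csr_kernel's slot claiming (not the Triton code)."""
    B, n_features = len(feature_acts), len(feature_acts[0])
    flat_idx = [None] * row_offsets[-1]
    flat_val = [None] * row_offsets[-1]
    write_pos = [0] * B  # per-token cursor, playing the role of the write_pos tensor
    n_chunks = -(-n_features // block_f)  # ceiling division, like triton.cdiv
    # One "program" per (token, chunk); shuffle to mimic arbitrary scheduling order
    programs = [(b, d) for b in range(B) for d in range(n_chunks)]
    random.Random(seed).shuffle(programs)
    for b, d in programs:
        offsets = range(d * block_f, min((d + 1) * block_f, n_features))
        vals = [feature_acts[b][j] for j in offsets]
        fired = [1 if v != 0.0 else 0 for v in vals]
        # Stand-in for tl.atomic_add: read the old cursor, then bump it by this chunk's count
        base = write_pos[b]
        write_pos[b] += sum(fired)
        # Exclusive prefix sum (cumsum - fired): rank of each fired lane within the chunk
        local_rank = [c - f for c, f in zip(accumulate(fired), fired)]
        for j, v, f, r in zip(offsets, vals, fired, local_rank):
            if f:
                slot = row_offsets[b] + base + r
                flat_idx[slot], flat_val[slot] = j, v
    return flat_idx, flat_val


acts = [
    [0.0, 1.0, 0.0, 2.0, 0.0, 0.0, 3.0, 0.0, 0.0, 4.0],  # fires features 1, 3, 6, 9
    [5.0, 0.0, 0.0, 0.0, 6.0, 7.0, 0.0, 0.0, 0.0, 0.0],  # fires features 0, 4, 5
]
row_offsets = [0, 4, 7]
flat_idx, flat_val = fill_csr_simulated(acts, row_offsets, block_f=4, seed=123)
# The order inside a token's region depends on scheduling, but the set of entries does not
print(sorted(flat_idx[0:4]), sorted(flat_idx[4:7]))  # [1, 3, 6, 9] [0, 4, 5]
```

Whatever the order, each program writes only to its own range [base, base + block_count) of the token's region. These ranges never overlap, so every slot is filled exactly once.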
Next, sparse_decode_kernel uses this CSR structure to carry out the matrix multiplication step. For each token, it uses row_offsets to find where that token's active features live in flat_idx/flat_val. It then loops over them, accumulating the weighted sum of the corresponding decoder rows into a tile of the output.

@triton.jit
def sparse_decode_kernel(
    flat_idx_ptr, flat_val_ptr, row_offsets_ptr,
    W_dec_ptr, out_ptr, d_model,
    BLOCK_D: tl.constexpr,
):
    pid_token = tl.program_id(0)
    pid_d = tl.program_id(1)
    # Find the slice of flat_idx/flat_val belonging to this token
    start = tl.load(row_offsets_ptr + pid_token)
    end = tl.load(row_offsets_ptr + pid_token + 1)
    n = end - start  # Number of active features for this token
    # This block owns a BLOCK_D-wide slice of the output row
    offsets = pid_d * BLOCK_D + tl.arange(0, BLOCK_D)
    mask = offsets < d_model
    acc = tl.zeros([BLOCK_D], dtype=tl.float32)
    # Loop over this token's active features, accumulating their contribution
    for i in range(n):
        j = start + i
        feat_idx = tl.load(flat_idx_ptr + j)  # Which decoder row?
        feat_val = tl.load(flat_val_ptr + j)  # Scale factor
        # Load the corresponding decoder row (just this block's slice)
        row_ptrs = W_dec_ptr + (feat_idx * d_model) + offsets
        row = tl.load(row_ptrs, mask=mask, other=0.0)
        acc += feat_val.to(tl.float32) * row.to(tl.float32)
    # Write this block's output slice
    tl.store(out_ptr + pid_token * d_model + offsets, acc, mask=mask)
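As a sanity check on the indexing, the sketch below mirrors sparse_decode_kernel on the CPU with plain lists. For each token it walks flat_idx/flat_val between row_offsets[b] and row_offsets[b + 1], scales the matching W_dec rows by their activations, and accumulates them into a BLOCK_D-wide tile. The result is then checked against a naive dense product. The function names, block_d=2 and the toy matrices are assumptions for illustration; this is not the Triton code.

```python
def sparse_decode_reference(flat_idx, flat_val, row_offsets, W_dec, block_d=2):
    """CPU sketch of sparse_decode_kernel, one (token, d_model tile) program at a time."""
    d_model = len(W_dec[0])
    B = len(row_offsets) - 1
    out = [[0.0] * d_model for _ in range(B)]
    for b in range(B):  # grid axis 0: one token per program
        start, end = row_offsets[b], row_offsets[b + 1]
        for d0 in range(0, d_model, block_d):  # grid axis 1: BLOCK_D-wide output tiles
            cols = range(d0, min(d0 + block_d, d_model))  # the bounds check plays the role of mask
            acc = [0.0] * len(cols)
            for j in range(start, end):  # loop over this token's active features only
                row = W_dec[flat_idx[j]]  # the decoder row for this feature
                for k, c in enumerate(cols):
                    acc[k] += flat_val[j] * row[c]
            for k, c in enumerate(cols):
                out[b][c] = acc[k]
    return out


def dense_decode(feature_acts, W_dec):
    """Naive dense reference for feature_acts @ W_dec, which touches every decoder row."""
    d_model = len(W_dec[0])
    return [[sum(a * W_dec[f][c] for f, a in enumerate(row)) for c in range(d_model)]
            for row in feature_acts]


# Toy SAE with 4 features and d_model = 3
W_dec = [[1.0, 0.0, 2.0], [0.5, 1.0, 0.0], [3.0, 3.0, 3.0], [0.0, 2.0, 1.0]]
feature_acts = [[0.0, 2.0, 0.0, 1.0], [4.0, 0.0, 0.0, 0.0]]
# CSR form of feature_acts: token 0 fires features 1 and 3, token 1 fires feature 0
flat_idx, flat_val, row_offsets = [1, 3, 0], [2.0, 1.0, 4.0], [0, 2, 3]
sparse_out = sparse_decode_reference(flat_idx, flat_val, row_offsets, W_dec)
print(sparse_out)  # [[1.0, 4.0, 1.0], [4.0, 0.0, 8.0]]
print(sparse_out == dense_decode(feature_acts, W_dec))  # True
```

The dense loop multiplies all four decoder rows for both tokens (eight row reads). The sparse loop only reads the three rows that actually fired.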
Finally, we put all of these kernels together by wrapping them in a single sparse_decode() function that acts as a drop-in replacement for feature_acts @ W_dec:

def _sparse_decode(flat_idx, flat_val, row_offsets, W_dec, B, BLOCK_D=256):
    d_model = W_dec.shape[1]
    out = torch.zeros((B, d_model), device=W_dec.device, dtype=torch.float32)
    # Parallelize over batch rows and d_model tiles
    grid = (B, triton.cdiv(d_model, BLOCK_D))
    sparse_decode_kernel[grid](
        flat_idx, flat_val, row_offsets, W_dec, out, d_model, BLOCK_D=BLOCK_D
    )
    return out

def sparse_decode(feature_acts, W_dec):
    # Triton requires contiguous memory for correct stride arithmetic
    W_dec = W_dec.contiguous()
    flat_idx, flat_val, row_offsets, B = build_csr(feature_acts)
    return _sparse_decode(flat_idx, flat_val, row_offsets, W_dec, B)
Fixed Allocation Kernel
Recall how in the Exact Allocation Kernel, inside build_csr, we extracted the total number of nonzero entries across all tokens by retrieving the last entry of row_offsets:

total_nnz = int(row_offsets[-1].item())

When we call .item(), we force the CPU to wait for the GPU to finish the counting pass before it can read total_nnz and allocate flat_idx/flat_val. Normally the CPU queues up GPU work asynchronously and moves on without waiting, but .item() breaks that pipeline by requiring the CPU to stall until the GPU result is ready. This turns out to be a significant source of slowdown.
The Fixed Allocation kernel works around this by not allocating exactly the memory needed in the first place, so we don't need total_nnz at all. Instead, we allocate max_l0 slots per token, where max_l0 is a user-specified upper bound on how many features can fire for any single token. This also means we no longer need a separate pass to count the nonzero features before computing the CSR structure. With these changes, the new build_csr wrapper function looks like:

def build_csr(feature_acts: torch.Tensor, BLOCK_F: int = 1024, max_l0: int = 512, validate: bool = True):
    # The kernel indexes token b's row at b * n_features, so the input must be contiguous
    feature_acts = feature_acts.contiguous()
    B, n_features = feature_acts.shape
    device = feature_acts.device
    # Fixed memory allocation
    capacity = B * max_l0
    flat_idx = torch.empty(capacity, dtype=torch.int32, device=device)
    flat_val = torch.empty(capacity, dtype=feature_acts.dtype, device=device)
    # write_pos serves as both the per-token write cursor during the kernel
    # and the per-token count afterward
    write_pos = torch.zeros(B, dtype=torch.int32, device=device)
    grid = (B, triton.cdiv(n_features, BLOCK_F))
    compute_csr_kernel[grid](
        feature_acts, write_pos, flat_idx, flat_val,
        n_features, max_l0, BLOCK_F=BLOCK_F,
    )
    # Final cursor value = number of features that fired per token
    # (this exceeds max_l0 if the token overflowed its region)
    counts = write_pos
    # Optional safety check. This reintroduces a GPU→CPU sync but catches silent truncation
    if validate:
        max_count = int(counts.max().item())
        if max_count > max_l0:
            raise ValueError(
                f"A token fired more than max_l0={max_l0} features "
                f"(max was {max_count}). Increase max_l0."
            )
    return flat_idx, flat_val, counts, B, max_l0
As mentioned briefly earlier, if a token fires more features than max_l0, the overflow guard in the kernel silently drops the extra features. This can be dangerous because the result is wrong but there's no crash. The validate=True default catches this by checking counts.max() after the kernel, at the cost of reintroducing a GPU→CPU sync. (However, this is still faster than Exact Allocation in practice.) If you're very confident that your max_l0 is a safe upper bound for your SAE, you can pass validate=False to skip the check, but this is not recommended.

The kernel to compute CSR changes minimally. We no longer need row_offsets, since we know that each token takes up max_l0 entries in memory. The lookup for the start of a token's region is therefore replaced by region_start = pid_token * max_l0.

@triton.jit
def compute_csr_kernel(
    feature_acts_ptr,
    write_pos_ptr,
    flat_idx_ptr,
    flat_val_ptr,
    n_features,
    max_l0,
    BLOCK_F: tl.constexpr,
):
    pid_token = tl.program_id(0)
    pid_d = tl.program_id(1)
    # Navigate to this block's chunk
    feat_offsets = pid_d * BLOCK_F + tl.arange(0, BLOCK_F)
    mask = feat_offsets < n_features
    feat_ptrs = feature_acts_ptr + pid_token * n_features + feat_offsets
    vals = tl.load(feat_ptrs, mask=mask, other=0.0)
    fired = vals != 0.0
    fired_int = fired.to(tl.int32)
    # Each token owns a fixed region of max_l0 slots
    region_start = pid_token * max_l0
    # Atomically claim the next available slots within this token's region
    block_count = tl.sum(fired_int)
    base = tl.atomic_add(write_pos_ptr + pid_token, block_count)
    # Assign each active feature a unique slot within the claimed range
    local_rank = tl.cumsum(fired_int) - fired_int
    local_slot = base + local_rank
    # Guard against writing past this token's region if L0 exceeds max_l0
    in_region = local_slot < max_l0
    write_mask = fired & mask & in_region
    slots = region_start + local_slot
    tl.store(flat_idx_ptr + slots, feat_offsets.to(tl.int32), mask=write_mask)
    tl.store(flat_val_ptr + slots, vals, mask=write_mask)
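The fixed-capacity layout and its overflow guard can be sketched the same way. The illustration below is sequential pure Python with hypothetical names. Each token owns max_l0 slots starting at b * max_l0, and a write_pos cursor counts every fired feature. Writes that would land outside the region are dropped, the equivalent of in_region = local_slot < max_l0. The validate flag reproduces the counts.max() check from build_csr.

```python
def build_fixed_csr_reference(feature_acts, max_l0, validate=True):
    """Pure-Python sketch of Fixed Allocation: max_l0 slots per token plus an overflow guard."""
    B = len(feature_acts)
    flat_idx = [-1] * (B * max_l0)  # -1 marks padding slots that were never written
    flat_val = [0.0] * (B * max_l0)
    write_pos = [0] * B
    for b, row in enumerate(feature_acts):
        region_start = b * max_l0  # fixed region per token, so no row_offsets are needed
        for j, v in enumerate(row):
            if v != 0.0:
                local_slot = write_pos[b]
                write_pos[b] += 1  # the cursor keeps counting even after the region is full
                if local_slot < max_l0:  # overflow guard: drop writes past the region
                    flat_idx[region_start + local_slot] = j
                    flat_val[region_start + local_slot] = v
    counts = write_pos  # number of features that fired, which may exceed max_l0
    if validate and max(counts) > max_l0:
        raise ValueError(f"A token fired more than max_l0={max_l0} features (max was {max(counts)}).")
    return flat_idx, flat_val, counts


acts = [[0.0, 1.0, 2.0, 3.0], [4.0, 0.0, 0.0, 0.0]]  # token 0 fires 3 features, token 1 fires 1
flat_idx, flat_val, counts = build_fixed_csr_reference(acts, max_l0=2, validate=False)
print(flat_idx, counts)  # [1, 2, 0, -1] [3, 1]  (feature 3 of token 0 was silently dropped)
try:
    build_fixed_csr_reference(acts, max_l0=2)  # validate=True by default
    overflow_detected = False
except ValueError:
    overflow_detected = True
print(overflow_detected)  # True
```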
The decoder kernel then changes in the same way. row_offsets is no longer needed, and counts replaces the start/end bracket. Because write_pos keeps counting past max_l0 when a token overflows, the loaded count is clamped to max_l0. This keeps the loop from reading into the next token's region, or past the end of the buffer for the last token:

@triton.jit
def sparse_decode_kernel(
    flat_idx_ptr,
    flat_val_ptr,
    counts_ptr,
    W_dec_ptr,
    out_ptr,
    d_model,
    max_l0,
    BLOCK_D: tl.constexpr,
):
    pid_token = tl.program_id(0)
    pid_d = tl.program_id(1)
    start = pid_token * max_l0
    # Actual number of active features for this token, clamped to the region size
    # (counts can exceed max_l0 when validate=False and a token overflowed)
    n = tl.minimum(tl.load(counts_ptr + pid_token), max_l0)
    offsets = pid_d * BLOCK_D + tl.arange(0, BLOCK_D)
    mask = offsets < d_model
    acc = tl.zeros([BLOCK_D], dtype=tl.float32)
    # Same loop as before
    for i in range(n):
        j = start + i
        feat_idx = tl.load(flat_idx_ptr + j)
        feat_val = tl.load(flat_val_ptr + j)
        row_ptrs = W_dec_ptr + feat_idx * d_model + offsets
        row = tl.load(row_ptrs, mask=mask, other=0.0)
        acc += feat_val.to(tl.float32) * row.to(tl.float32)
    tl.store(out_ptr + pid_token * d_model + offsets, acc, mask=mask)
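A matching CPU sketch of decoding from the fixed layout is below. It makes one detail explicit. write_pos keeps counting after a token's region is full, so a count can be larger than max_l0. The sketch therefore clamps the loop bound with min(count, max_l0). Without that clamp, an overflowed token would read into the next token's region. The function name and toy values are hypothetical, and this is not the Triton kernel.

```python
def sparse_decode_fixed_reference(flat_idx, flat_val, counts, W_dec, max_l0):
    """CPU sketch of decoding from the fixed layout: token b's entries start at b * max_l0."""
    d_model = len(W_dec[0])
    out = []
    for b, count in enumerate(counts):
        start = b * max_l0
        n = min(count, max_l0)  # clamp: an overflowed count must not spill into token b + 1
        acc = [0.0] * d_model
        for j in range(start, start + n):
            row = W_dec[flat_idx[j]]
            for c in range(d_model):
                acc[c] += flat_val[j] * row[c]
        out.append(acc)
    return out


W_dec = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0]]
# Fixed layout with max_l0 = 2: token 0 owns slots 0-1, token 1 owns slots 2-3 (slot 3 is padding)
flat_idx, flat_val = [0, 2, 1, -1], [3.0, 0.5, 4.0, 0.0]
print(sparse_decode_fixed_reference(flat_idx, flat_val, [2, 1], W_dec, max_l0=2))
# [[4.0, 1.0], [0.0, 4.0]]
# Same result if token 0 had overflowed (count 3): the extra feature was dropped, not read
print(sparse_decode_fixed_reference(flat_idx, flat_val, [3, 1], W_dec, max_l0=2))
```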
Benchmarks
Writing custom GPU kernels is great, but it's important to make sure that they actually make the computation faster. I used triton.testing.do_bench (warmup=25 and rep=100, both in milliseconds, reporting the median) to time these kernels and compared them against dense matrix multiplication (feature_acts @ W_dec). All tests were run on an NVIDIA GeForce RTX 4090 GPU.

As a quick summary, the table below shows the relative speedups for an example input configuration (B = 32, n_features = 65536, d_model = 768, L0 = 64):

| Method | Full matmul pipeline (ms) | Speedup vs dense |
|---|---|---|
| Dense cuBLAS | 0.288 | 1.0× |
| torch.compile | 0.288 | 1.0× |
| torch.sparse.mm + .to_sparse_csr() | 0.210 | 1.4× |
| Custom, exact allocation | 0.151 | 1.9× |
| Custom, fixed allocation (validate=False) | 0.041 | 7.0× |
| Custom, fixed allocation (validate=True) | 0.115 | 2.5× |
Correctness
First, I verified that the custom kernels actually perform matrix multiplication correctly (a custom kernel that is faster but gives the wrong answer doesn't help anyone). In other words, we verify that sparse_decode(feature_acts, W_dec) matches feature_acts @ W_dec across 486 different inputs, covering every combination of the parameters below. Note that sparse_decode() here is just a wrapper matmul function that uses our custom Triton kernels under the hood.

| Axis | Meaning | Values tested | Count |
|---|---|---|---|
| version | kernel implementation | exact, fixed | 2 |
| dtype | input dtype of feature_acts/W_dec | float32, float16, bfloat16 | 3 |
| B | batch size (tokens) | 1, 4, 32 | 3 |
| n_features | SAE dictionary width | 256, 1024, 16384 | 3 |
| d_model | output width | 128, 512, 768 | 3 |
| L0 | features fired per token | 1, 8, 100 | 3 |

Total: 2 × 3 × 3 × 3 × 3 × 3 = 486 configurations. Each test asserts that the output is fp32 and matches the dense fp32 reference within atol=1e-4, rtol=1e-3.

Decoder Kernel Speed (CSR Excluded)
The preprocessing step of computing a CSR representation adds some computational overhead. It would be interesting to compare sparse_decode_kernel directly against dense matrix multiplication without paying for that overhead, i.e. assuming you somehow already have access to a CSR representation.

Suppose you hold some parameters of the input constant (B=32, n_features=65536, d_model=768) while varying L0 (the number of fired features), as shown in the table below. How much faster is sparse_decode_kernel? Note that this EXCLUDES the overhead of the CSR preprocessing step, i.e. everything done in build_csr. Also note that sparse_decode_kernel is essentially the same between Exact Allocation and Fixed Allocation, so there is no need to differentiate. For completeness, though, the graph below plots both (they overlap).

| L0 | Fraction of features active (L0 / n_features) | Kernel speedup vs dense |
|---|---|---|
| 16 | 0.02% | 25.5× |
| 32 | 0.05% | 18.7× |
| 64 | 0.10% | 12.8× |
| 128 | 0.20% | 8.0× |
| 256 | 0.39% | 5.0× |
| 512 | 0.78% | 3.0× |
| 1024 | 1.56% | 1.7× |
| 4096 | 6.25% | 0.6× |
We can also vary n_features while holding B=32, L0=64, and d_model=768 constant:

| n_features | Kernel speedup vs dense |
|---|---|
| 4,096 | 1.5× |
| 16,384 | 4.1× |
| 32,768 | 7.3× |
| 65,536 | 12.8× |
| 131,072 | 22.5× |
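One rough way to read these two tables is as a count of decoder rows read. To first order, the dense product streams all of W_dec (n_features rows), while the sparse decoder only gathers the rows that fired, about B × L0 of them for the whole batch. The back-of-envelope sketch below computes that ratio. It is a simplified model assumed here for intuition, not something measured in this post. It ignores caching, launch overhead, reading the activation matrix, and the CSR step.

```python
def rows_read_ratio(n_features, B, L0):
    """Idealised ratio of decoder rows read: dense matmul vs. sparse gather."""
    dense_rows = n_features  # dense: every row of W_dec is streamed once per batch
    sparse_rows = B * L0  # sparse: one row per fired feature per token, with no reuse
    return dense_rows / sparse_rows


for L0 in [16, 64, 256, 1024, 4096]:
    print(L0, rows_read_ratio(65536, B=32, L0=L0))
```

The model's ratio drops below 1 between L0 = 1024 and L0 = 4096, and the measured kernel speedup also falls below 1× in that range. At small L0, however, the model overestimates the measured speedup by a wide margin (128 vs 25.5× at L0 = 16). Treat it as intuition for the trend only: the speedup grows with n_features and shrinks with L0.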
Full Pipeline Speed
So clearly sparse_decode_kernel alone is faster than dense matrix multiplication at high sparsity. But of course in practice we probably need to compute the CSR representation as well, which will slow things down somewhat.

The table below shows the speedups relative to dense matmul for three different input configurations (F is n_features, D is d_model). Here "Kernel only" refers to sparse_decode_kernel alone, with the CSR representation precomputed. "Full" refers to the whole pipeline: build_csr followed by sparse_decode_kernel.

| Configuration | Kernel only | Full (exact) | Full (fixed, no val.) | Full (fixed, val.) |
|---|---|---|---|---|
| B=32, F=65536, D=768, L0=64 | 12.8× | 1.9× | 7.0× | 2.5× |
| B=256, F=65536, D=768, L0=64 | 7.7× | 1.7× | 3.1× | 2.2× |
| B=32, F=131072, D=512, L0=128 | 22.5× | 2.2× | 6.1× | 2.3× |
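To see where the time goes in the full pipeline, the arithmetic below backs out rough per-stage costs for the first configuration. It uses the absolute timings from the summary table at the start of the benchmarks (0.288 ms dense, 0.151 ms exact, 0.041 ms and 0.115 ms for fixed without and with validation) together with the 12.8× kernel-only speedup. It assumes the pipeline time is simply CSR construction plus decode. That ignores overlap and run-to-run noise, so the results are estimates only.

```python
# Numbers reported in this post for B=32, F=65536, D=768, L0=64
dense_ms = 0.288
kernel_only_speedup = 12.8
full_ms = {"exact": 0.151, "fixed, validate=False": 0.041, "fixed, validate=True": 0.115}

# Assumption: full pipeline time = (everything before decoding) + decode kernel time
decode_ms = dense_ms / kernel_only_speedup
for name, total_ms in full_ms.items():
    print(f"{name}: decode ~{decode_ms:.4f} ms, everything else ~{total_ms - decode_ms:.4f} ms")

# Cost attributable to the validate=True check in the fixed allocation path
validate_cost_ms = full_ms["fixed, validate=True"] - full_ms["fixed, validate=False"]
print(f"validate=True adds ~{validate_cost_ms:.3f} ms")
```

Under that assumption, the decode kernel itself takes roughly 0.02 ms. Most of the exact-allocation time (about 0.13 ms) is therefore spent outside it, and the validate=True check accounts for roughly 0.07 ms of the fixed-allocation pipeline.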
The graph below shows the speed of the full pipeline (Exact Allocation) and of the decode kernel alone as you vary sparsity. Here, L0 sweeps over [16, 32, 64, 128, 256, 512, 1024, 4096, 16384] while holding B=32, n_features=65536, and d_model=768 constant.

Additional Baselines
To be comprehensive, we can also compare our custom kernels to torch.sparse.mm (using PyTorch's to_sparse_csr()), which uses cuSPARSE internally, and to torch.compile. This comparison uses the same three input configurations as above.

Note: I found it a little suspicious that this custom kernel would "beat" torch.sparse.mm. It turns out this is mostly because it beats to_sparse_csr() when building the CSR. There doesn't seem to be much of a difference in speed between the custom kernel and cuSPARSE on the matrix multiplication step alone.

As expected, torch.compile doesn't provide a noticeable speedup (presumably because the dense baseline is a single matmul that already dispatches to cuBLAS), but I wanted to include it anyway for completeness.

End-to-End on Real SAEs
Up until now we have focused entirely on the speed of the matrix multiplication operation, but at the end of the day we care about SAE inference speed as a whole. This is benchmarked by replacing only the decoder matmul step in a SAELens JumpReLU SAE forward pass. The table below covers five SAEs across two model families and three dictionary sizes. F is n_features, D is d_model, and the last three columns are speedups relative to the dense baseline.

| SAE | F | D | L0 | Max diff | Exact | Fixed (val.) | Fixed (no val.) |
|---|---|---|---|---|---|---|---|
| Gemma Scope 2B, L20, 65k | 65,536 | 2,304 | 72 | 3.8e-6 | 4.27× | 5.57× | 11.41× |
| Gemma Scope 9B, L20, 65k | 65,536 | 3,584 | 72 | 3.8e-6 | 5.66× | 7.34× | 13.27× |
| Gemma Scope 2B, L12, 65k | 65,536 | 2,304 | 72 | 9.5e-7 | 3.91× | 5.48× | 11.33× |
| Gemma Scope 2B, L12, 262k | 262,144 | 2,304 | 100 | 1.9e-6 | 12.08× | 14.49× | 22.59× |
| Qwen Scope 3.5 2B, L12 | 32,768 | 2,048 | 100 | 4.8e-7 | 1.98× | 2.54× | 5.74× |
Memory Overhead
The purpose of the Fixed Allocation kernel was to overallocate memory in exchange for speed, so it would be helpful to see exactly how much more memory it uses compared to the Exact Allocation kernel. Surprisingly, it turns out that in practice this overhead is small:

| B | max_l0 | Dense (MB) | Exact (MB) | Fixed (MB) | Overhead vs exact |
|---|---|---|---|---|---|
| 32 | 512 | 218.3 | 218.4 | 218.5 | +0.1 MB |
| 256 | 512 | 277.7 | 277.9 | 278.8 | +0.9 MB |
| 1024 | 512 | 482.3 | 482.9 | 485.6 | +2.7 MB |
| 1024 | 1024 | 482.3 | 482.9 | 490.7 | +7.8 MB |
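For intuition about why the overhead stays small, note that the flat arrays are tiny next to the decoder weights. The sketch below estimates their size, assuming int32 indices and fp32 values (4 bytes each). The average L0 of 72 used for the exact case is a hypothetical input. The W_dec figure assumes the fp32 65,536 × 768 shape used in the earlier benchmarks. The measured numbers in the table are real allocations that also depend on the allocator and on other tensors, so they need not match this estimate.

```python
def csr_buffer_mb(n_slots, idx_bytes=4, val_bytes=4):
    """Size of flat_idx + flat_val in MB (1 MB = 10**6 bytes) for n_slots entries."""
    return n_slots * (idx_bytes + val_bytes) / 1e6


B, max_l0 = 1024, 1024
fixed_mb = csr_buffer_mb(B * max_l0)  # Fixed Allocation always reserves B * max_l0 slots
exact_mb = csr_buffer_mb(B * 72)  # Exact Allocation, assuming a hypothetical average L0 of 72
w_dec_mb = 65536 * 768 * 4 / 1e6  # an fp32 W_dec of shape (65536, 768), for scale
print(f"fixed {fixed_mb:.2f} MB, exact {exact_mb:.2f} MB, W_dec {w_dec_mb:.1f} MB")
```

Even in the largest configuration, the fixed buffers come to a few megabytes. That is small next to a decoder matrix of a couple of hundred megabytes.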
Limitations
While these results are encouraging, there are a few important limitations to be aware of and gaps that I plan to address as I continue working on this project.
First, the above benchmark numbers are not absolute, as these tests were run in a specific environment (WSL2 with GPU clocks not pinned). The primary goal of these benchmarks was to gauge the relative performance of the custom kernels compared to baseline implementations. The actual absolute speed likely differs depending on the hardware and benchmarking setup.
A second limitation, discussed earlier but worth reiterating, concerns the Fixed Allocation kernel with validate=False. Although it achieves the highest performance, it can silently produce incorrect results if the max_l0 parameter is set too low. For this reason, either the Exact Allocation kernel or Fixed Allocation with validate=True is likely the better choice in most cases.

Third, these kernels were designed specifically for sparse matrix multiplication, so once enough features fire, dense matrix multiplication is actually faster. In the decoder-only sweep above (B=32, n_features=65536, d_model=768), the crossover lies between L0 = 1024 (1.7×) and L0 = 4096 (0.6×, i.e. slower than dense).

Fourth, this implementation focuses exclusively on the decoder inference step of JumpReLU Sparse Autoencoders, but there are likely other sources of inefficiency that could be addressed. For example, future projects could focus on the encoder pass or on support for training through custom backward kernels. Additionally, the current implementation only supports float32 outputs.

Finally, all experiments were run on an RTX 4090, and performance may differ on other GPU architectures such as the A100 or H100.
Conclusion + Link to Code
In conclusion, this project implements custom Triton kernels for the decoder inference step of JumpReLU SAEs by exploiting the inherent sparsity of the hidden representation. On a sample of real SAEs, this achieves a 2.5–14× speedup with the Fixed Allocation (validate=True) kernel, with larger gains for larger dictionary sizes.
The full implementation is available on GitHub.
I welcome feedback! If you have thoughts, questions, or find any issues, feel free to leave a comment or reach out directly. This is also my first GPU kernel project, so if you're experienced with Triton or GPU kernel optimization and see things I could have done better, I would appreciate any suggestions.