A self-contained walkthrough of the key idea from "Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention" (Katharopoulos et al., ICML 2020, arXiv:2006.16236)
This article is based on a 2020 paper by Katharopoulos et al. which shows that by replacing the softmax in standard attention with a kernel expressed as a dot product of feature maps, you can reduce the complexity of self-attention from $O(N^2)$ to $O(N)$ in the sequence length, and that this automatically makes the transformer equivalent to a recurrent neural network at inference time.
The paper is clean, but it mixes a lot of background and experimental detail with the core idea. This article cuts the fluff, gets straight to the core concept and math, and explains every variable as it appears. Nothing is left undefined.
1. What Standard Attention Does
Let $x \in \mathbb{R}^{N \times F}$ be a sequence of $N$ feature vectors, each of dimension $F$. A transformer layer projects this input into three matrices:

$$Q = xW_Q, \qquad K = xW_K, \qquad V = xW_V$$

where $W_Q, W_K \in \mathbb{R}^{F \times D}$ and $W_V \in \mathbb{R}^{F \times M}$ are learned projection matrices. $D$ is the key/query dimension and $M$ is the value dimension.
For a general non-negative similarity function $\mathrm{sim}(\cdot, \cdot)$, the attention output at position $i$ is:

$$V'_i = \frac{\sum_{j=1}^{N} \mathrm{sim}(Q_i, K_j)\, V_j}{\sum_{j=1}^{N} \mathrm{sim}(Q_i, K_j)} \tag{1}$$

$Q_i$ denotes the $i$-th row of $Q$ (a $D$-dimensional vector), and likewise for $K_j$ and $V_j$. The output is a weighted average of all value vectors.
For softmax attention, the similarity is:

$$\mathrm{sim}(q, k) = \exp\!\left(\frac{q^\top k}{\sqrt{D}}\right)$$

The full attention matrix $QK^\top$ has shape $N \times N$. Computing and storing it costs $O(N^2)$ in both time and memory. This is the bottleneck.
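To make the shapes concrete, here is a minimal NumPy sketch of equation (1) with the softmax similarity. The function and variable names are mine, not from the paper's code:

```python
import numpy as np

def softmax_attention(x, W_Q, W_K, W_V):
    # Project the input into queries, keys, and values
    Q, K, V = x @ W_Q, x @ W_K, x @ W_V           # (N, D), (N, D), (N, M)
    D = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(D)                 # (N, N) -- quadratic in N
    scores -= scores.max(axis=-1, keepdims=True)  # numerical stability
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ V                            # (N, M)
```

The `(N, N)` score matrix is exactly the object that linear attention avoids materializing.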
2. The Key Idea: Factorize sim with a Feature Map
Instead of computing the similarity directly, express it as an inner product of explicit feature representations:

$$\mathrm{sim}(q, k) = \phi(q)^\top \phi(k) \tag{2}$$

where $\phi: \mathbb{R}^D \to \mathbb{R}^C$ is a feature map that produces non-negative outputs, and $C$ is the feature map output dimension.
Substituting into equation (1):

$$V'_i = \frac{\sum_{j=1}^{N} \phi(Q_i)^\top \phi(K_j)\, V_j}{\sum_{j=1}^{N} \phi(Q_i)^\top \phi(K_j)} = \frac{\phi(Q_i)^\top \left[\sum_{j=1}^{N} \phi(K_j)\, V_j^\top\right]}{\phi(Q_i)^\top \left[\sum_{j=1}^{N} \phi(K_j)\right]} \tag{3}$$
Why does this matter? The two bracketed terms in equation (3) are the same for every query $Q_i$:
- $\sum_{j=1}^{N} \phi(K_j)\, V_j^\top \in \mathbb{R}^{C \times M}$ — compute once
- $\sum_{j=1}^{N} \phi(K_j) \in \mathbb{R}^{C}$ — compute once

Then for each query, apply $\phi(Q_i)$ via a dot product. Total cost: $O(NCM)$ instead of $O(N^2 \max(D, M))$.
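A minimal non-causal NumPy sketch of equation (3), assuming `phi` maps each row to a non-negative $C$-dimensional vector (names are illustrative, not taken from the paper's codebase):

```python
import numpy as np

def linear_attention(x, W_Q, W_K, W_V, phi):
    # Non-causal linear attention via the factorization sim(q, k) = phi(q)^T phi(k)
    Q, K, V = x @ W_Q, x @ W_K, x @ W_V    # (N, D), (N, D), (N, M)
    Qf, Kf = phi(Q), phi(K)                # (N, C), (N, C); phi must be non-negative
    KV = Kf.T @ V                          # (C, M): sum_j phi(K_j) V_j^T, computed once
    Z = Kf.sum(axis=0)                     # (C,):   sum_j phi(K_j),       computed once
    return (Qf @ KV) / (Qf @ Z)[:, None]   # (N, M): one dot product per query
```

No $N \times N$ matrix ever appears; the sequence length only enters through the two sums and the final per-query products.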
3. Adding Causal Masking (For Autoregressive Use)
For autoregressive generation, position $i$ must only attend to positions $j \le i$. Equation (1) becomes:

$$V'_i = \frac{\sum_{j=1}^{i} \mathrm{sim}(Q_i, K_j)\, V_j}{\sum_{j=1}^{i} \mathrm{sim}(Q_i, K_j)} \tag{4}$$

Define two cumulative state quantities:

$$S_i = \sum_{j=1}^{i} \phi(K_j)\, V_j^\top, \qquad Z_i = \sum_{j=1}^{i} \phi(K_j)$$

Equation (4) then simplifies to:

$$V'_i = \frac{\phi(Q_i)^\top S_i}{\phi(Q_i)^\top Z_i}$$

and $S_i$, $Z_i$ update in $O(1)$ from the previous step:

$$S_i = S_{i-1} + \phi(K_i)\, V_i^\top, \qquad Z_i = Z_{i-1} + \phi(K_i)$$

The full causal attention pass therefore costs $O(NCM)$ time, linear in the sequence length $N$.
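Here is a NumPy sketch of this recurrence. It loops over positions for clarity; a real implementation would compute the cumulative sums in parallel during training:

```python
import numpy as np

def causal_linear_attention(x, W_Q, W_K, W_V, phi):
    # Running sums S_i and Z_i replace the masked attention matrix
    Q, K, V = x @ W_Q, x @ W_K, x @ W_V
    Qf, Kf = phi(Q), phi(K)
    N, C = Qf.shape
    M = V.shape[-1]
    S = np.zeros((C, M))                       # attention memory
    Z = np.zeros(C)                            # normalizer memory
    out = np.empty((N, M))
    for i in range(N):
        S += np.outer(Kf[i], V[i])             # S_i = S_{i-1} + phi(K_i) V_i^T
        Z += Kf[i]                             # Z_i = Z_{i-1} + phi(K_i)
        out[i] = (Qf[i] @ S) / (Qf[i] @ Z)     # equation (4) via the cumulative state
    return out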
4. This Is an RNN
The recurrence above is, by definition, a recurrent neural network. The full transformer layer (including the per-position feedforward sublayer $f_\ell$) can be written as:

$$s_0 = 0, \qquad z_0 = 0$$
$$s_i = s_{i-1} + \phi(x_i W_K)\,(x_i W_V)^\top$$
$$z_i = z_{i-1} + \phi(x_i W_K)$$
$$y_i = f_\ell\!\left(\frac{\phi(x_i W_Q)^\top s_i}{\phi(x_i W_Q)^\top z_i} + x_i\right)$$
Variable reference:
- $x_i$: input at timestep $i$, shape $F$
- $W_Q, W_K, W_V$: learned projection matrices (same as before)
- $s_i$: attention memory state, shape $C \times M$
- $z_i$: normalizer memory state, shape $C$
- $f_\ell$: per-position feedforward sublayer (e.g. a two-layer MLP)
- $y_i$: output at timestep $i$
- $\phi$: feature map applied to the key and query projections
At training time, the cumulative sums can be computed in parallel (like a prefix scan), so GPU efficiency is preserved. At inference time, you maintain $s_i$ and $z_i$ as a fixed-size state and update one token at a time in $O(1)$ per step. No growing KV cache.
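A sketch of the inference-time RNN view, assuming a single head and that the value dimension equals the model dimension so the residual connection lines up. The helpers `init_state`, `rnn_step`, and `ff` are my names, not the paper's API:

```python
import numpy as np

def init_state(C, M):
    # Fixed-size recurrent state: (attention memory s, normalizer memory z)
    return np.zeros((C, M)), np.zeros(C)

def rnn_step(x_i, state, W_Q, W_K, W_V, phi, ff):
    # One autoregressive step; cost is independent of how many tokens came before.
    # Assumes the value dimension M equals the model dimension F, so attn + x_i
    # is well-defined (otherwise an output projection would be needed).
    s, z = state
    q = phi(x_i @ W_Q)            # (C,)
    k = phi(x_i @ W_K)            # (C,)
    v = x_i @ W_V                 # (M,)
    s = s + np.outer(k, v)        # s_i = s_{i-1} + phi(x_i W_K)(x_i W_V)^T
    z = z + k                     # z_i = z_{i-1} + phi(x_i W_K)
    attn = (q @ s) / (q @ z)      # (M,)
    y_i = ff(attn + x_i)          # per-position feedforward sublayer f_l
    return y_i, (s, z)
```

Generation is then just repeated calls to `rnn_step`, feeding each output back in as the next input.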
5. The Feature Map
For the kernel decomposition in equation (2) to work, $\phi$ must produce non-negative outputs. The paper uses:

$$\phi(x) = \mathrm{elu}(x) + 1$$

where $\mathrm{elu}$ is the exponential linear unit:

$$\mathrm{elu}(x) = \begin{cases} x & \text{if } x > 0 \\ \alpha\,(e^x - 1) & \text{if } x \le 0 \end{cases}$$
Adding 1 shifts the output so it is always positive. ELU is preferred over ReLU because ReLU's gradient is zero for negative inputs, which can stall training.
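In NumPy this feature map is essentially one line (the function name is mine; `expm1` on the clamped negative branch avoids overflow warnings):

```python
import numpy as np

def elu_feature_map(x, alpha=1.0):
    # elu(x) + 1: x + 1 for positive inputs, alpha*(exp(x) - 1) + 1 otherwise
    return np.where(x > 0, x, alpha * np.expm1(np.minimum(x, 0.0))) + 1.0
```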
This feature map cannot exactly recover softmax attention (the exact feature map for the exponential kernel is infinite-dimensional). But empirically, it converges to similar performance.
6. Complexity Summary
| Method | Time | Memory | Inference per step |
|---|---|---|---|
| Softmax attention | $O(N^2 \max(D, M))$ | $O(N^2)$ | $O(i)$ at step $i$ |
| Linear attention | $O(NCM)$ | $O(N \max(C, M))$ | $O(1)$ |

$C$ = feature map dimension, $D$ = key/query dimension, $M$ = value dimension, $N$ = sequence length.
Closing
The core insight here is not a new architecture so much as a change of perspective. Self-attention was always computing a weighted combination of values, and that weighting was always a similarity function. By choosing a similarity function that factors as a dot product of feature maps, matrix associativity does the rest, turning a quadratic computation into a linear one, and making the causal version a literal RNN.
The tradeoff is that you lose the ability to exactly replicate softmax attention, and the feature map choice matters for performance. But the paper shows this is a small price to pay for orders-of-magnitude faster inference.
Full paper: arXiv:2006.16236 | Code: linear-transformers.com