In the paper "Self-attention Does Not Need O(n^2) Memory," a team at Google introduces simple algorithms for single-query attention and self-attention that require only constant and logarithmic memory, respectively. At a sequence length of 16384, the approach reduces the self-attention memory overhead by 59x for inference and by 32x for differentiation.
Intro
Attention Algorithm:
- The attention operation on a single query produces a weighted sum of value vectors.
- The weights are determined by the softmax of the dot products between the query and the keys.
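Spelled out in the paper's notation for a single query q, keys k_1, ..., k_n, and values v_1, ..., v_n:
- s_i = dot(q, k_i)
- attention(q, k, v) = sum_i ( exp(s_i) / sum_j exp(s_j) ) * v_i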
In the conventional implementation of this, we have to:
- First compute and remember the score s_i = dot(q, k_i) for all i -> O(n) time and memory per query.
- Transformers use self-attention, so every element of the sequence issues its own query -> time and memory complexity of O(n^2) (as sketched below).
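A minimal JAX sketch of the conventional approach (my own shapes and function names, not the paper's code) makes the memory cost concrete: an (n,)-sized score vector per query, and a full (n, n) score matrix for self-attention.

```python
import jax.numpy as jnp

def attention_single_query(q, K, V):
    """Conventional attention for one query: all n scores are materialized at once."""
    s = K @ q                   # (n,) scores -- O(n) memory per query
    e = jnp.exp(s)              # exponentiated scores (numerical stability is discussed later)
    return (e @ V) / e.sum()    # softmax-weighted sum of the value vectors

def self_attention_naive(Q, K, V):
    """Conventional self-attention: the full (n, n) score matrix is materialized."""
    S = Q @ K.T                 # (n, n) scores -- O(n^2) memory
    E = jnp.exp(S)
    return (E @ V) / E.sum(axis=-1, keepdims=True)
```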
Main Algorithm
Single-Query Attention Case
First Step: Observe that the division by the softmax denominator does not have to happen right away; it can be deferred to the very end of the computation (the "lazy softmax" formulation).
Second Step: Process the keys and values incrementally, keeping only two running accumulators:
- Initialize a vector v* and a scalar s* to 0.
- Loop over the keys k_1, k_2, ..., k_n and values v_1, v_2, ..., v_n. For each pair (k_i, v_i), compute the score s_i = dot(q, k_i).
- For each s_i, update the accumulators: v* += exp(s_i) * v_i and s* += exp(s_i).
- After the last pair, divide v*/s* -> final result (see the sketch after this list).
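A literal JAX sketch of this loop (my own naming; numerical stability is ignored here and addressed in the last section). Only v_star, s_star, and the current key/value pair live in memory at any time:

```python
import jax.numpy as jnp

def attention_single_query_incremental(q, K, V):
    """Constant-memory attention for one query: only v_star and s_star persist."""
    v_star = jnp.zeros(V.shape[-1])      # running sum of exp(s_i) * v_i
    s_star = 0.0                         # running sum of exp(s_i)
    for k_i, v_i in zip(K, V):           # one key/value pair at a time
        s_i = jnp.dot(q, k_i)            # score for this key
        v_star = v_star + jnp.exp(s_i) * v_i
        s_star = s_star + jnp.exp(s_i)
    return v_star / s_star               # divide once, at the end
```

The result is mathematically identical to the softmax-weighted sum above; only the order of operations changes.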
Self-Attention Case
To extend this algorithm to self-attention, simply compute the results for all queries sequentially; the O(log n) memory comes from the index into the sequence that this loop has to maintain.
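A sketch of this extension, reusing attention_single_query_incremental from the sketch above (again my own naming, not the paper's code):

```python
import jax.numpy as jnp

def self_attention_sequential(Q, K, V):
    """Self-attention one query at a time; the (n, n) score matrix never exists."""
    rows = [attention_single_query_incremental(q, K, V) for q in Q]  # sequential over queries
    return jnp.stack(rows)
```

The paper's actual implementation processes queries and keys in chunks and uses gradient checkpointing to keep differentiation memory-efficient; this literal loop only mirrors the "sequentially over all queries" description.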
Numerical Stability Problem
Written naively, both the default attention and the new incremental algorithm are not numerically stable in floating-point arithmetic: for scores ≥ 89, the exponentiation results in inf (for bfloat16 and float32). Standard softmax implementations avoid this by subtracting the maximum score before exponentiating; the incremental algorithm needs an equivalent trick.
To resolve this problem, the authors introduce an additional scalar m*:
- It keeps track of the maximum score that the incremental algorithm has seen so far.
- Whenever a larger score appears, the running sums of exponentiated values are renormalized accordingly (see the sketch below).
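A sketch of the numerically stable variant, following the paper's description (again my own naming): m_star tracks the running maximum, all exponentials are taken relative to it, and both accumulators are renormalized whenever it grows.

```python
import jax.numpy as jnp

def attention_single_query_stable(q, K, V):
    """Incremental attention with a running maximum m_star for numerical stability."""
    v_star = jnp.zeros(V.shape[-1])   # running sum of exp(s_i - m_star) * v_i
    s_star = 0.0                      # running sum of exp(s_i - m_star)
    m_star = -jnp.inf                 # largest score seen so far
    for k_i, v_i in zip(K, V):
        s_i = jnp.dot(q, k_i)
        m_new = jnp.maximum(m_star, s_i)
        correction = jnp.exp(m_star - m_new)   # renormalize the old sums if the max grew
        v_star = v_star * correction + jnp.exp(s_i - m_new) * v_i
        s_star = s_star * correction + jnp.exp(s_i - m_new)
        m_star = m_new
    return v_star / s_star
```

Because every exponent is s_i - m_new ≤ 0, no intermediate value can overflow, yet the final v_star / s_star is unchanged.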