FlashAttention is not an approximate attention method; its core idea is carefully accounting for reads and writes to the different levels of fast and slow GPU memory.
Surprisingly, the time we spend on computation is much smaller than the time we spend reading from and writing to memory.
Main goal: avoid reading and writing the attention matrix to and from HBM. This requires:
- Computing the softmax reduction without access to the whole input.
- Not storing the large intermediate attention matrix for the backward pass.
How to achieve this?
- Restructure the attention computation to split the input into blocks and make several passes over the input blocks, thus incrementally performing the softmax reduction (also known as tiling); see the sketch after this list.
- Store the softmax normalization factor from the forward pass to quickly recompute attention on-chip in the backward pass, which is faster than the standard approach of reading the intermediate attention matrix from HBM.
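Below is a minimal NumPy sketch of this tiled forward pass with an online softmax. The function name, block size, and single-head shapes are illustrative choices, not the actual FlashAttention kernel; it also returns the log-sum-exp statistic that the second bullet refers to.

```python
# Minimal sketch of blockwise (tiled) attention with an online softmax.
# Assumes a single head; block size and shapes are illustrative, not the real kernel's.
import numpy as np

def tiled_attention(Q, K, V, block=128):
    N, d = Q.shape
    O = np.zeros((N, d))
    m = np.full(N, -np.inf)      # running row-wise max of the scores
    l = np.zeros(N)              # running row-wise sum of exp(scores - max)

    for start in range(0, N, block):              # loop over key/value blocks
        Kb, Vb = K[start:start + block], V[start:start + block]
        S = Q @ Kb.T / np.sqrt(d)                 # scores for this block only (N x block)

        m_new = np.maximum(m, S.max(axis=1))      # updated running max
        alpha = np.exp(m - m_new)                 # rescale previously accumulated sums/outputs
        P = np.exp(S - m_new[:, None])            # unnormalized probabilities for this block
        l = alpha * l + P.sum(axis=1)
        O = alpha[:, None] * O + P @ Vb
        m = m_new

    O = O / l[:, None]                            # final normalization
    logsumexp = m + np.log(l)                     # statistic saved for the backward pass
    return O, logsumexp

# Sanity check against the naive full-matrix softmax.
rng = np.random.default_rng(0)
Q, K, V = (rng.standard_normal((512, 64)) for _ in range(3))
O_tiled, _ = tiled_attention(Q, K, V)
S = Q @ K.T / np.sqrt(64)
P = np.exp(S - S.max(axis=1, keepdims=True))
O_ref = (P / P.sum(axis=1, keepdims=True)) @ V
assert np.allclose(O_tiled, O_ref, atol=1e-6)
```

Note that only the running max and sum per row need to be kept, never the full N×N score matrix, which is what lets the real kernel stay in on-chip SRAM.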
Background
GPUs have a memory hierarchy: a large but relatively slow HBM and a small amount of fast on-chip SRAM. GPUs use a massive number of threads to execute an operation (called a kernel). Each kernel loads inputs from HBM into registers and SRAM, computes, then writes outputs back to HBM.
Depending on the balance of computation and memory accesses, operations can be classified as:
- Compute-bound: the time taken by the operation is determined by how many arithmetic operations there are, while time accessing HBM is much smaller. Typical examples are matrix multiply with a large inner dimension and convolution with a large number of channels.
- Memory-bound: the time taken by the operation is determined by the number of memory accesses, while the time spent in computation is much smaller. Examples include most other operations: elementwise ops (e.g., activation, dropout) and reductions (e.g., sum, softmax, batch norm, layer norm). A back-of-envelope comparison of the two regimes follows this list.
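The sketch below is a rough arithmetic-intensity comparison (FLOPs per byte moved to or from HBM) under assumed fp16 data and illustrative sizes; the exact numbers depend on hardware, but the gap between the two regimes is the point.

```python
# Rough back-of-envelope comparison of arithmetic intensity (FLOPs per byte of HBM traffic).
# Assumes fp16 tensors and ignores caching effects; sizes are illustrative.
BYTES = 2  # fp16

def matmul_intensity(M, K, N):
    flops = 2 * M * K * N                      # one multiply-add per inner-dim element per output
    bytes_moved = BYTES * (M * K + K * N + M * N)
    return flops / bytes_moved

def elementwise_intensity(numel):
    flops = numel                              # e.g. one op per element (activation)
    bytes_moved = BYTES * 2 * numel            # read input, write output
    return flops / bytes_moved

print(f"matmul 4096^3:       {matmul_intensity(4096, 4096, 4096):,.0f} FLOPs/byte")  # ~1,365
print(f"elementwise (any N): {elementwise_intensity(1 << 20):.2f} FLOPs/byte")       # 0.25
```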
Kernel fusion
The most common approach to accelerate memory-bound operations is kernel fusion: if there are multiple operations applied to the same input, the input can be loaded once from HBM, instead of multiple times for each operation.
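A small PyTorch sketch of this idea, assuming PyTorch 2.x where torch.compile is available: the unfused version launches one kernel per op and materializes every intermediate in HBM, while the compiler can fuse the pointwise chain into a single kernel that reads x once and writes the result once. Shapes and the constants are illustrative.

```python
# Kernel fusion sketch: unfused pointwise chain vs. a compiled (fused) version.
import torch

def unfused(x, bias):
    y = x + bias        # kernel 1: read x, write intermediate
    y = torch.relu(y)   # kernel 2: read intermediate, write another
    return y * 0.5      # kernel 3: read again, write the output

fused = torch.compile(unfused)  # pointwise ops can be fused into one kernel

device = "cuda" if torch.cuda.is_available() else "cpu"
x = torch.randn(4096, 4096, device=device)
bias = torch.randn(4096, device=device)
assert torch.allclose(unfused(x, bias), fused(x, bias), atol=1e-6)
```

Fusion helps, but for attention the intermediate matrices still get written to HBM for the backward pass, which is the problem FlashAttention targets.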
Standard Attention Implementation
As we can see from the figure, Q, K, V, and O are much smaller than the intermediate matrices S and P, which are N×N (quadratic in the sequence length).
Algorithm (standard attention). Require: matrices Q, K, V ∈ R^{N×d} in HBM. A code sketch follows the steps.
- Load Q, K by blocks from HBM, compute S = QK^T, write S to HBM.
- Read S from HBM, compute P = softmax(S), write P to HBM.
- Load P and V by blocks from HBM, compute O = PV, write O to HBM.
- Return O.
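A straightforward NumPy sketch of the steps above: it materializes the full N×N matrices S and P, which is exactly the HBM traffic FlashAttention avoids. Function name and shapes are illustrative.

```python
# Standard (naive) attention: materializes the N x N matrices S and P.
import numpy as np

def standard_attention(Q, K, V):
    S = Q @ K.T                                     # N x N scores (scaling by 1/sqrt(d) omitted
                                                    # to match the algorithm above)
    P = np.exp(S - S.max(axis=1, keepdims=True))    # numerically stable softmax
    P = P / P.sum(axis=1, keepdims=True)            # N x N probabilities, written back to HBM
    return P @ V                                    # N x d output

rng = np.random.default_rng(0)
Q, K, V = (rng.standard_normal((512, 64)) for _ in range(3))
O = standard_attention(Q, K, V)
```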