Skip to content

Flash Attention

1. Overview

FlashAttention is a fast and memory-efficient attention algorithm that computes exact attention without materializing the full \(N \times N\) attention matrix. It's especially critical for long sequences (4k+ tokens) in modern LLMs.

Key insight: The bottleneck in attention isn't compute - it's memory bandwidth (moving data between GPU memory hierarchies).



2. Why Standard Attention is Slow

Standard attention formula: $$ \text{Attention}(Q, K, V) = \text{Softmax}\left(\frac{QK^T}{\sqrt{d}}\right)V $$


Problem 1: Quadratic Memory Growth

For sequence length \(N = 16{,}384\) in FP16:

  • Attention matrix: \(N^2 = 268M\) elements
  • Memory: \(268M \times 2\) bytes \(≈ 512\) MB (per layer, per head!)

Problem 2: Excessive Memory Traffic

Standard attention performs multiple memory-heavy steps:

  1. Compute \(QK^T\) → write to global memory
  2. Read \(QK^T\) → apply softmax → write back
  3. Read softmax output → compute with \(V\) → write output

Result: GPUs become memory-bound, not compute-bound.


Problem 3: Numerical Instability in FP16

  • Large values in \(QK^T\) cause overflow in \(e^x\)
  • Small values underflow to zero
  • Standard attention often requires FP32, increasing memory usage


3. How FlashAttention Works

FlashAttention uses three key techniques:

3.1 Tiling

Split Q, K, V into small tiles that fit in GPU shared memory (SRAM).

Example:

  • Sequence length: \(N = 16{,}384\)
  • Tile size: \(B = 128\)
  • Memory per tile: \(128 \times 128 = 16{,}384\) elements (vs. \(268M\) for full matrix)
# Conceptual tiling
for q_tile in Q_tiles:
    for k_tile, v_tile in zip(K_tiles, V_tiles):
        partial_scores = q_tile @ k_tile.T
        # accumulate incrementally

3.2 Kernel Fusion

Fuse all operations into a single kernel to keep intermediate results in fast shared memory:

  1. Matrix multiplication (\(Q \cdot K^T\))
  2. Scaling (\(1/\sqrt{d}\))
  3. Softmax
  4. Weighted sum with \(V\)

Standard attention writes/reads from global memory between each step. FlashAttention does everything in one pass.


3.3 Online Softmax

Compute softmax incrementally across tiles without storing the full attention matrix.

Numerically stable approach:

  1. Maintain running maximum \(m\) across tiles
  2. Compute: \(e^{x_i - m}\) (prevents overflow)
  3. Maintain running sum of exponentials
  4. Accumulate weighted output incrementally

Example with 2 tiles:

Tile 1: [0.1, 0.5, 0.3], Tile 2: [0.2, 0.4, 0.1]

Processing:

  1. Tile 1: \(m = 0.5\), shifted exps: \([e^{-0.4}, e^{0}, e^{-0.2}]\), running sum \(s_1\)
  2. Tile 2: update \(m\), reweight previous results, add new exps, update sum \(s_2\)
  3. Final: divide accumulated output by \(s_2\)

Result: Exact same output as standard attention, but in FP16/BF16 without overflow.



4. Performance Impact

Memory Complexity

  • Standard: \(O(N^2)\)
  • FlashAttention: \(O(N \cdot B)\) where \(B\) is tile size

Speedup

  • 2–4x faster for long sequences on modern GPUs
  • Enables 2–4x longer sequences or larger batch sizes

Usage

from flash_attn import flash_attn_func

# q, k, v shape: (batch, seq_len, num_heads, head_dim)
output = flash_attn_func(q, k, v, dropout_p=0.0, causal=False)

When FlashAttention Helps Most

✅ Long sequences (2k+ tokens)
✅ FP16/BF16 precision
✅ Modern NVIDIA GPUs with fast shared memory

❌ Very short sequences
❌ CPU-based inference
❌ Custom attention patterns not supported by the kernels



5. Interview Questions

Q1: Why is FlashAttention faster than standard attention?

Answer: The bottleneck is memory bandwidth, not compute. Standard attention writes intermediate results (attention matrix, softmax output) to slow GPU global memory and reads them back multiple times. FlashAttention uses tiling and kernel fusion to keep all intermediate computations in fast shared memory, drastically reducing memory traffic.


Q2: Does FlashAttention approximate attention?

Answer: No, it computes exact attention. It produces identical results to standard attention by using online softmax to correctly compute the softmax normalization across tiles without storing the full attention matrix.


Q3: Explain online softmax. Why is it needed?

Answer: When processing tiles, we can't store the full \(N \times N\) attention matrix. Online softmax maintains a running maximum and running sum across tiles to compute the exact softmax incrementally. This also provides numerical stability in FP16/BF16 by shifting scores before exponentiation to prevent overflow: \(e^{x_i - m}\) instead of \(e^{x_i}\).


Q4: What is the memory complexity of FlashAttention vs standard attention?

Answer:

  • Standard attention: \(O(N^2)\) to store full attention matrix
  • FlashAttention: \(O(N \cdot B)\) where \(B\) is tile size (typically 128-256)
  • For \(N=16k\), this reduces memory from ~512MB to ~4-8MB per head

Q5: Can FlashAttention be used for any attention mechanism?

Answer: FlashAttention works best for standard scaled dot-product attention. Variants exist for:

  • Causal (autoregressive) attention ✅
  • Cross-attention ✅
  • Sparse attention patterns ⚠️ (limited support, depends on sparsity structure)

Custom attention patterns may require specialized kernels.


Q6: Why does FlashAttention require modern GPUs?

Answer: FlashAttention relies on:

  1. Fast shared memory (SRAM) - to store tiles and perform fused operations
  2. High memory bandwidth - to maximize benefit from reduced memory traffic
  3. Tensor cores - for fast matrix multiplications

Older GPUs or CPUs don't have the same memory hierarchy, so the benefits are minimal.


Q7: Walk through how FlashAttention processes a single tile pair.

Answer:

  1. Load \(Q_{tile}\) and \(K_{tile}\) into shared memory
  2. Compute scores: \(S = Q_{tile} \cdot K_{tile}^T / \sqrt{d}\)
  3. Track running max \(m\) for numerical stability
  4. Compute: \(e^{S - m}\) (stays in shared memory)
  5. Update running sum for normalization
  6. Load \(V_{tile}\), compute weighted sum, accumulate to output
  7. Move to next tile, repeat

All intermediate values stay in fast SRAM, not global memory.


Q8: What trade-offs does FlashAttention make?

Answer:

  • ✅ Gains: 2-4x speedup, drastically reduced memory
  • ⚠️ Complexity: More complex implementation than standard attention
  • ⚠️ Flexibility: Limited support for custom sparse attention patterns
  • ⚠️ Hardware: Requires modern GPUs to realize full benefits

The trade-offs are generally worth it for production LLM serving and training.


Q9: How does tiling affect the computational complexity?

Answer: Tiling doesn't change the computational complexity (still \(O(N^2)\) FLOPs), but it changes the I/O complexity:

  • Standard: \(O(N^2)\) memory reads/writes
  • FlashAttention: \(O(N^2/B)\) memory reads/writes, where \(B\) is tile size

Since memory bandwidth is the bottleneck, this provides significant speedup.


Q10: Can you explain the difference between shared memory and global memory?

Answer:

  • Shared memory (SRAM): Fast (~20 TB/s), small (~100 KB per SM), explicitly managed
  • Global memory (HBM): Slower (~1-2 TB/s), large (16-80 GB), high latency

FlashAttention keeps working data in shared memory to minimize expensive global memory accesses. This is the key to its performance gains.


6. Key Takeaways for Interviews

  1. Main idea: Avoid materializing the \(N \times N\) attention matrix by using tiling and kernel fusion
  2. Core techniques: Tiling + Kernel Fusion + Online Softmax
  3. Why it's fast: Reduces memory bandwidth usage (the real bottleneck)
  4. Memory savings: \(O(N^2) → O(N \cdot B)\)
  5. Exact computation: Not an approximation—produces identical results to standard attention
  6. Numerical stability: Online softmax enables stable FP16/BF16 computation for long sequences