Skip to content

Flash Attention 2

1. Overview

FlashAttention-2 is an improved version of FlashAttention that achieves 2x speedup over FlashAttention-1 through better GPU utilization. It maintains the same exact attention computation while being even faster and more efficient.

Key improvement: Better work partitioning across GPU threads to reduce idle time and maximize hardware utilization.



2. What Was Wrong with FlashAttention-1?

Despite being much faster than standard attention, FlashAttention-1 had suboptimal GPU utilization:

Problem 1: Poor Work Partitioning

  • Each thread block processed one query tile across all key/value tiles
  • Led to unbalanced workload and thread block idle time
  • Didn't fully saturate GPU compute resources

Problem 2: Non-Coalesced Memory Accesses

  • Memory accesses weren't optimally aligned for GPU memory coalescing
  • Caused unnecessary memory bandwidth waste

Problem 3: Limited Parallelism

  • Parallelism was only across batch, heads, and query sequence
  • Didn't parallelize across key/value sequence dimension


3. Key Improvements in FlashAttention-2

3.1 Better Parallelism Strategy

FlashAttention-1: Parallelize over (batch, heads, query_tiles)

Thread Block 1 → processes Q_tile_1 across all K,V tiles
Thread Block 2 → processes Q_tile_2 across all K,V tiles

FlashAttention-2: Parallelize over (batch, heads, query_tiles, kv_tiles)

Thread Block 1 → processes (Q_tile_1, K_tile_1, V_tile_1)
Thread Block 2 → processes (Q_tile_1, K_tile_2, V_tile_2)
Thread Block 3 → processes (Q_tile_2, K_tile_1, V_tile_1)

Benefit: More thread blocks doing work simultaneously → better GPU occupancy → less idle time


3.2 Improved Work Partitioning Within Thread Blocks

FlashAttention-1: Each warp handled different queries within a tile - Led to imbalanced work when softmax required different amounts of computation

FlashAttention-2: Each warp handles same query, split across K dimension - More balanced work distribution - Better load balancing across warps


3.3 Memory Access Optimizations

  • Improved memory coalescing patterns
  • Better cache utilization
  • Reduced redundant memory loads


4. Performance Impact

Speedup Over FlashAttention-1

  • ~2x faster on average for typical sequence lengths
  • Up to 2.3x on A100 GPUs for long sequences
  • Better scaling with sequence length

GPU Utilization

  • FlashAttention-1: ~35-50% of peak FLOPS
  • FlashAttention-2: ~50-70% of peak FLOPS

Memory Efficiency

  • Same \(O(N \cdot B)\) memory complexity
  • Better bandwidth utilization due to improved access patterns


5. Implementation Details

Thread Block Structure

# Conceptual partitioning
for batch_idx in batches:
    for head_idx in heads:
        for q_tile_idx in query_tiles:
            for kv_tile_idx in kv_tiles:  # NEW: also parallelize here
                # Each (q_tile, kv_tile) pair gets its own thread block
                thread_block.process(Q[q_tile_idx], K[kv_tile_idx], V[kv_tile_idx])
                # Accumulate partial results

Synchronization

  • Requires careful synchronization when accumulating partial outputs
  • Uses atomic operations or reduction trees to combine results from different KV tiles


6. Interview Questions

Q1: What's the main difference between FlashAttention-1 and FlashAttention-2?

Answer: FlashAttention-2 improves parallelism by also parallelizing across the key/value sequence dimension, not just query sequence. This means more thread blocks work simultaneously, reducing idle time and achieving ~2x speedup while computing the exact same result.


Q2: Does FlashAttention-2 change the algorithm or just the implementation?

Answer: It's purely an implementation improvement. The algorithm (tiling, online softmax, kernel fusion) remains the same. FlashAttention-2 just distributes work more efficiently across GPU threads to maximize hardware utilization.


Q3: Why does parallelizing across KV tiles improve performance?

Answer: In FlashAttention-1, each thread block processes one Q tile sequentially across all KV tiles. This limits parallelism. FlashAttention-2 launches separate thread blocks for each (Q_tile, KV_tile) pair, enabling many more blocks to run concurrently, better saturating the GPU's compute resources.


Q4: What's the trade-off with this increased parallelism?

Answer: More synchronization overhead. Since multiple thread blocks now compute partial outputs for the same query tile (from different KV tiles), we need to carefully accumulate and normalize these partial results. However, the performance gain from parallelism far outweighs this cost.


Q5: How does FlashAttention-2 handle the accumulation of partial results?

Answer: Each thread block computes a partial attention output for its (Q_tile, KV_tile) pair along with partial softmax statistics (max, sum). These partials are then combined using: - Atomic operations, or - Reduction trees, or - Final pass to accumulate stored partials

The online softmax technique ensures correct normalization.


Q6: What GPU features does FlashAttention-2 rely on more heavily?

Answer: - High thread block occupancy - needs many concurrent blocks - Fast atomic operations - for accumulating partials - Shared memory bandwidth - still crucial like FA-1 - Warp-level primitives - for efficient intra-block communication

Modern GPUs (A100, H100) have better support for these, maximizing FA-2's benefits.


Q7: Why doesn't FlashAttention-2 achieve 100% of peak FLOPS?

Answer: Several factors: - Memory bandwidth still matters (can't fully hide all memory latency) - Synchronization overhead from accumulating partials - Load imbalance across thread blocks (some finish before others) - Non-uniform work per tile (softmax computation varies)

50-70% utilization is actually quite good for memory-intensive operations.


Q8: How does sequence length affect FA-2's speedup over FA-1?

Answer: FlashAttention-2's advantage increases with sequence length: - Longer sequences → more tiles → more parallelism opportunities - Better amortization of synchronization overhead - At very short sequences, FA-1 and FA-2 are similar (not enough parallelism to exploit)


Q9: Can FlashAttention-2's techniques be applied to other operations?

Answer: Yes! The key insight—parallelizing over both input and computation dimensions—applies to any operation that processes tiles/blocks: - Other attention variants (sparse, local attention) - Convolutions with tiling - Matrix multiplications with block processing

The principle is: maximize parallel work to reduce GPU idle time.


Q10: What's the memory complexity of FlashAttention-2?

Answer: Same as FlashAttention-1: \(O(N \cdot B)\) where \(B\) is tile size. The improvement is in speed (better parallelism and memory access patterns), not memory usage. Both avoid materializing the full \(N \times N\) attention matrix.



7. FlashAttention-1 vs FlashAttention-2 Summary

Aspect FlashAttention-1 FlashAttention-2
Parallelism Over batch, heads, query tiles Over batch, heads, query tiles, KV tiles
Work per block One Q tile × all KV tiles One (Q tile, KV tile) pair
GPU utilization 35-50% of peak FLOPS 50-70% of peak FLOPS
Speedup vs standard 2-4x 4-8x
Speedup vs FA-1 - ~2x
Memory complexity \(O(N \cdot B)\) \(O(N \cdot B)\)
Algorithm Tiling + online softmax + fusion Same
Synchronization Simpler More complex (atomic/reduction)


8. Key Takeaways for Interviews

  1. Main idea: Same algorithm, better parallelism by also parallelizing across KV dimension
  2. Performance: ~2x faster than FA-1 through better GPU utilization
  3. Trade-off: More synchronization overhead, but worth it for the speedup
  4. Memory: Same \(O(N \cdot B)\) complexity, just faster execution
  5. When it matters most: Long sequences where more parallelism can be exploited
  6. Hardware dependency: Benefits scale with GPU's ability to run many thread blocks concurrently