Gradient Checkpointing for LLM Training¶
1. Overview¶
Gradient checkpointing (also called activation checkpointing or activation recomputation) is a memory optimization technique that trades compute for memory by recomputing intermediate activations during the backward pass instead of storing them.
The Memory Problem¶
During backpropagation, PyTorch keeps every intermediate activation from the forward pass that is needed to compute gradients. For large models, these activations consume more memory than the parameters themselves.
Memory breakdown for 7B LLaMA (batch_size=8, seq_len=2048):
| Component | Memory | Percentage |
|---|---|---|
| Parameters (BF16) | 14 GB | 12% |
| Gradients | 14 GB | 12% |
| Optimizer states | 56 GB | 47% |
| Activations | ~35 GB | 29% |
| Total | ~119 GB | 100% |
Without gradient checkpointing, even a single forward/backward pass can exceed GPU memory.
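A quick back-of-envelope check of the table above (a sketch; it assumes BF16 weights and gradients plus two FP32 Adam states per parameter, matching the table, and exact totals vary by implementation):
import torch  # not strictly needed here; shown for consistency with later examples

# Rough memory arithmetic for a 7B model
params = 7e9
BF16, FP32 = 2, 4  # bytes per element

weights = params * BF16      # 14 GB
grads = params * BF16        # 14 GB
optim = params * FP32 * 2    # 56 GB (Adam momentum + variance)

for name, nbytes in [("weights", weights), ("grads", grads), ("optimizer", optim)]:
    print(f"{name}: {nbytes / 1e9:.0f} GB")
Activations are the one component that scales with batch size and sequence length rather than parameter count, which is why they are the target of checkpointing.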
2. How Backpropagation Uses Activations¶
2.1 Standard Backpropagation¶
# Forward pass
x1 = layer1(x0) # Store x1 for backward
x2 = layer2(x1) # Store x2 for backward
x3 = layer3(x2) # Store x3 for backward
loss = criterion(x3, y)
# Backward pass
grad_x3 = grad_loss
grad_x2 = layer3.backward(grad_x3, x2) # Uses stored x2
grad_x1 = layer2.backward(grad_x2, x1) # Uses stored x1
grad_x0 = layer1.backward(grad_x1, x0) # Uses stored x0
Memory: All activations \(x_1, x_2, x_3, ...\) must be kept in memory.
For a Transformer with L layers, each layer stores:
- Query, Key, Value projections
- Attention scores
- Attention outputs
- MLP intermediate activations
Total activation memory scales as: \(O(L \times B \times S \times d)\)
- \(L\) = number of layers
- \(B\) = batch size
- \(S\) = sequence length
- \(d\) = hidden dimension
2.2 Why Activations Dominate¶
Example: LLaMA-7B single layer (batch=4, seq_len=2048)
| Component | Shape | Memory (BF16) |
|---|---|---|
| Parameters | (4096, 4096) | 32 MB |
| Attention activations | (4, 2048, 4096) | 64 MB |
| MLP activations | (4, 2048, 11008) | 176 MB |
| Per-layer total | - | 240 MB |
| 32 layers | - | 7.7 GB |
With batch_size=8, this doubles to ~15 GB just for activations.
3. Gradient Checkpointing: Core Idea¶
3.1 The Trade-off¶
Instead of storing all activations:
- Store activations only at checkpoint boundaries (e.g., every K layers)
- During backward pass, recompute activations between checkpoints
# Forward pass (checkpoint every 2 layers)
x1 = layer1(x0)
x2 = layer2(x1) # Checkpoint: Save x2
del x1 # Free memory
x3 = layer3(x2)
x4 = layer4(x3) # Checkpoint: Save x4
del x3 # Free memory
# Backward pass for layers 3-4
x3 = layer3(x2) # RECOMPUTE x3
grad_x4 = grad_loss
grad_x3 = layer4.backward(grad_x4, x3)
grad_x2 = layer3.backward(grad_x3, x2)
# Backward pass for layers 1-2
x1 = layer1(x0) # RECOMPUTE x1
# grad_x2 was already computed above
grad_x1 = layer2.backward(grad_x2, x1)
grad_x0 = layer1.backward(grad_x1, x0)
Result:
- Memory: \(O(L/K)\) instead of \(O(L)\)
- Compute: up to +33% (each segment's forward is recomputed once, adding roughly one extra full forward pass overall)
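A minimal runnable illustration of the trade-off, using torch.utils.checkpoint.checkpoint_sequential to split a toy stack into segments (layer sizes are arbitrary; exact peak-memory numbers vary by GPU and PyTorch version):
import torch
from torch import nn
from torch.utils.checkpoint import checkpoint_sequential

layers = nn.Sequential(
    *[nn.Sequential(nn.Linear(4096, 4096), nn.GELU()) for _ in range(16)]
).cuda()
x = torch.randn(64, 4096, device="cuda", requires_grad=True)

def peak_mib(run_forward):
    torch.cuda.reset_peak_memory_stats()
    run_forward().sum().backward()
    return torch.cuda.max_memory_allocated() / 2**20

baseline = peak_mib(lambda: layers(x))
# 4 segments: only segment-boundary activations are stored; the rest are recomputed
checkpointed = peak_mib(lambda: checkpoint_sequential(layers, 4, x, use_reentrant=False))
print(f"baseline: {baseline:.0f} MiB  checkpointed: {checkpointed:.0f} MiB")
The checkpointed run should show a visibly lower peak, at the cost of re-running each segment's forward during backward.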
4. Implementation Strategies¶
4.1 PyTorch Native Checkpointing¶
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class TransformerLayer(nn.Module):
    def __init__(self, d_model=4096, d_ff=11008):
        super().__init__()
        # Simplified stand-ins for the attention and MLP sublayers
        self.attention = nn.Linear(d_model, d_model)
        self.mlp = nn.Sequential(
            nn.Linear(d_model, d_ff), nn.SiLU(), nn.Linear(d_ff, d_model)
        )

    def forward(self, x):
        # Normal forward pass
        attn_out = self.attention(x)
        mlp_out = self.mlp(attn_out)
        return mlp_out

class TransformerWithCheckpointing(nn.Module):
    def __init__(self, num_layers=32):
        super().__init__()
        self.layers = nn.ModuleList([TransformerLayer() for _ in range(num_layers)])

    def forward(self, x):
        for layer in self.layers:
            # Checkpoint each layer: store only its input, recompute the rest
            x = checkpoint(layer, x, use_reentrant=False)
        return x
Key parameter: use_reentrant=False
- Newer, more memory-efficient implementation
- Avoids issues with control flow and RNGs
- Recommended for all new code
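Usage is unchanged from a plain model; recomputation happens transparently inside backward() (shapes follow the stand-in layer above):
model = TransformerWithCheckpointing(num_layers=4).cuda()
x = torch.randn(2, 128, 4096, device="cuda")
out = model(x)
out.sum().backward()  # each layer's forward is re-run here to rebuild activations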
4.2 Selective Checkpointing¶
Not all layers benefit equally. Checkpoint strategically:
class SelectiveCheckpointTransformer(nn.Module):
def __init__(self, num_layers=32, checkpoint_every=4):
super().__init__()
self.layers = nn.ModuleList([TransformerLayer() for _ in range(num_layers)])
self.checkpoint_every = checkpoint_every
def forward(self, x):
for i, layer in enumerate(self.layers):
if i % self.checkpoint_every == 0:
x = checkpoint(layer, x, use_reentrant=False)
else:
x = layer(x) # No checkpointing
return x
Trade-off tuning:
- checkpoint_every=1: max memory savings (~50%), +33% compute
- checkpoint_every=4: moderate savings (~30%), +10% compute
- checkpoint_every=∞ (never checkpoint): no savings, baseline speed
4.3 HuggingFace Transformers¶
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")

# Enable gradient checkpointing; the KV cache is incompatible with
# recomputation during training, so disable it as well
model.gradient_checkpointing_enable()
model.config.use_cache = False

# Train normally
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
for batch in dataloader:
    outputs = model(**batch)
    outputs.loss.backward()
    optimizer.step()
    optimizer.zero_grad()
Under the hood: Checkpoints every Transformer block by default.
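When training through the Trainer API instead, the same switch is exposed as a TrainingArguments flag:
from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="out",
    gradient_checkpointing=True,  # same effect as gradient_checkpointing_enable()
)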
4.4 DeepSpeed Integration¶
import deepspeed

ds_config = {
"train_batch_size": 16,
"gradient_accumulation_steps": 4,
"activation_checkpointing": {
"partition_activations": True, # Shard across GPUs
"cpu_checkpointing": True, # Offload to CPU
"contiguous_memory_optimization": True,
"number_checkpoints": 32, # Checkpoint every layer
}
}
model_engine, optimizer, _, _ = deepspeed.initialize(
model=model,
config=ds_config,
)
Advanced features:
- Activation partitioning: Shard checkpoints across GPUs
- CPU offloading: Store checkpoints in CPU memory
- Contiguous memory: Reduce fragmentation
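Inside the model itself, DeepSpeed ships a drop-in replacement for torch.utils.checkpoint that honors the settings above (a sketch; CheckpointedStack is a hypothetical wrapper, not a DeepSpeed class):
import deepspeed
import torch.nn as nn

class CheckpointedStack(nn.Module):
    """Hypothetical wrapper that checkpoints each wrapped layer."""
    def __init__(self, layers):
        super().__init__()
        self.layers = nn.ModuleList(layers)

    def forward(self, x):
        for layer in self.layers:
            # Same call shape as torch.utils.checkpoint.checkpoint, but applies
            # the partitioning/offload options from the activation_checkpointing config
            x = deepspeed.checkpointing.checkpoint(layer, x)
        return x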
5. Memory Savings Analysis¶
5.1 Theoretical Bounds¶
For a model with \(L\) layers and checkpointing every \(K\) layers:
Activation memory: $$ M_{\text{checkpoint}} = O\left(\left(\frac{L}{K} + K\right) \times B \times S \times d\right) $$
Optimal checkpointing interval: $$ K_{\text{opt}} = \sqrt{L} $$
This minimizes total memory while keeping recomputation reasonable.
Example: 32-layer model
- No checkpointing: \(M \propto 32\)
- Checkpoint every 6 layers: \(M \propto 32/6 + 6 \approx 11.3\) (65% reduction)
- Optimal (\(\sqrt{32} \approx 6\)): Best memory/compute trade-off
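The arithmetic behind this example, as a tiny script (assuming the \(L/K + K\) cost model above):
import math

def memory_factor(num_layers, k):
    # L/K stored checkpoints, plus one segment of K layers live during recompute
    return num_layers / k + k

L = 32
for k in (1, 2, 4, 6, 8, 16, 32):
    print(f"K={k}: memory ∝ {memory_factor(L, k):.1f}")
print(f"K_opt = sqrt(L) ≈ {math.sqrt(L):.1f}")  # ≈ 5.7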
5.2 Real-World Measurements¶
LLaMA-7B (32 layers, batch=4, seq_len=2048) on A100:
| Strategy | Activation Memory | Peak Memory | Throughput | Compute Overhead |
|---|---|---|---|---|
| No checkpointing | 15.2 GB | 85 GB | 1.0× | 0% |
| Checkpoint every 8 layers | 5.8 GB | 75 GB | 0.92× | +8% |
| Checkpoint every 4 layers | 4.1 GB | 70 GB | 0.85× | +15% |
| Checkpoint every layer | 1.9 GB | 62 GB | 0.75× | +33% |
Sweet spot: Checkpoint every 4-8 layers for most use cases.
5.3 Scaling with Sequence Length¶
Activation memory scales quadratically with sequence length, because attention materializes an \(S \times S\) score matrix per head:
$$ M_{\text{activations}} = O\left(L \times B \times \left(S \times d + H \times S^2\right)\right) $$
Where \(H\) = number of attention heads.
Impact:
| Sequence Length | Activation Memory (32 layers) | With Checkpointing (every layer) |
|---|---|---|
| 512 | 1.2 GB | 0.3 GB |
| 2048 | 15.2 GB | 1.9 GB |
| 8192 | 240 GB | 30 GB |
| 32768 | 3.8 TB | 480 GB |
For long-context training, gradient checkpointing is essential.
6. Best Practices¶
6.1 When to Use Gradient Checkpointing¶
✅ Use when:
- Training large models (>1B parameters)
- Long sequence lengths (>2048)
- Limited GPU memory
- Batch size is already at minimum (can't reduce further)
❌ Don't use when:
- Small models (<100M parameters)
- GPU memory is not a bottleneck
- Inference (no backward pass needed)
- Extremely tight latency requirements
6.2 Combining with Other Techniques¶
Recommended stack for memory efficiency:
# 1. Mixed precision (BF16 needs no GradScaler; a scaler is only for FP16)
import torch
from torch.amp import autocast

# 2. Gradient checkpointing
model.gradient_checkpointing_enable()

# 3. Gradient accumulation
accumulation_steps = 4

# 4. Memory-efficient optimizer
import bitsandbytes as bnb
optimizer = bnb.optim.Adam8bit(model.parameters(), lr=1e-4)

for i, batch in enumerate(dataloader):
    with autocast("cuda", dtype=torch.bfloat16):
        outputs = model(**batch)
        loss = outputs.loss / accumulation_steps
    loss.backward()
    if (i + 1) % accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()
Memory savings stack:
- Mixed precision: -40%
- Gradient checkpointing: -50% (activations)
- 8-bit optimizer: -75% (optimizer states)
- Combined: Enables 3-4× larger models
6.3 Debugging Checkpointing Issues¶
Common problems:
- RNG state mismatch

# Problem: Dropout behaves differently during recomputation
# Solution: Use use_reentrant=False
checkpoint(fn, x, use_reentrant=False)

- In-place operations

# Problem: In-place ops break gradient computation
x += residual  # ❌ Breaks checkpointing
# Solution: Use out-of-place operations
x = x + residual  # ✅ Works with checkpointing

- Custom CUDA kernels

# Must define the backward pass explicitly
class MyCheckpointedFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return my_cuda_kernel_forward(x)

    @staticmethod
    def backward(ctx, grad_output):
        x, = ctx.saved_tensors
        return my_cuda_kernel_backward(grad_output, x)
7. Advanced Topics¶
7.1 CPU Offloading (DeepSpeed)¶
For extreme memory constraints, offload checkpoints to CPU:
# DeepSpeed config
"activation_checkpointing": {
"cpu_checkpointing": True, # Store checkpoints in CPU RAM
"contiguous_memory_optimization": True,
}
Trade-off:
- Memory: Can handle arbitrarily large models
- Speed: ~2-3× slower due to PCIe transfers
Use case: Training 70B+ models on consumer GPUs.
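Plain PyTorch offers a framework-agnostic version of this idea through saved-tensor hooks (a sketch; model and batch are assumed to be defined as in the earlier examples):
import torch

# Every tensor autograd saves during this forward is offloaded to pinned
# CPU memory, then copied back to the GPU on demand during backward
with torch.autograd.graph.save_on_cpu(pin_memory=True):
    loss = model(**batch).loss
loss.backward()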
7.2 Selective Activation Checkpointing (SAC)¶
Idea: Only checkpoint memory-heavy operations (attention), not cheap ones (LayerNorm).
class SmartCheckpointLayer(nn.Module):
    def __init__(self, d_model=4096, d_ff=11008):
        super().__init__()
        # Stand-in sublayers, as in the earlier examples
        self.attention = nn.Linear(d_model, d_model)
        self.layer_norm = nn.LayerNorm(d_model)
        self.mlp = nn.Sequential(nn.Linear(d_model, d_ff), nn.SiLU(), nn.Linear(d_ff, d_model))

    def forward(self, x):
        # Checkpoint expensive attention
        attn_out = checkpoint(self.attention, x, use_reentrant=False)
        # Don't checkpoint cheap operations
        norm_out = self.layer_norm(attn_out)
        # Checkpoint expensive MLP
        mlp_out = checkpoint(self.mlp, norm_out, use_reentrant=False)
        return mlp_out
Benefit: roughly 80% of the memory savings at only ~15% compute overhead.
7.3 Gradient Checkpointing + FSDP¶
Fully Sharded Data Parallel (FSDP) + checkpointing enables massive models:
import functools
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    apply_activation_checkpointing, checkpoint_wrapper,
)

# Shard (and later checkpoint) at the granularity of Transformer blocks
auto_wrap_policy = functools.partial(
    transformer_auto_wrap_policy,
    transformer_layer_cls={TransformerBlock},
)

model = FSDP(
    model,
    auto_wrap_policy=auto_wrap_policy,
    use_orig_params=True,  # expose original parameters (e.g., for torch.compile)
)

# Wrap each Transformer block with activation checkpointing
apply_activation_checkpointing(
    model,
    checkpoint_wrapper_fn=checkpoint_wrapper,
    check_fn=lambda m: isinstance(m, TransformerBlock),
)
Combined with CPU offloading, this stack enables training 175B+ models on 8× A100s.
8. Quick Reference¶
Decision Tree¶
Is GPU memory sufficient?
├─ Yes → Don't use checkpointing (faster training)
└─ No → Can you reduce batch size?
├─ No (already batch_size=1) → Enable checkpointing
└─ Yes → Try batch reduction first
└─ Still OOM? → Enable checkpointing
Checkpointing Intervals¶
| Model Size | Recommended Interval | Memory Savings | Compute Overhead |
|---|---|---|---|
| <1B | Every 8-16 layers | 30-40% | +5-10% |
| 1-10B | Every 4-8 layers | 50-60% | +10-20% |
| 10-100B | Every 1-2 layers | 70-80% | +25-35% |
| >100B | Every layer + CPU offload | 85-90% | +50-100% |
Compatibility Matrix¶
| Technique | Compatible | Notes |
|---|---|---|
| Mixed precision | ✅ | No interaction |
| Flash Attention | ✅ | Complementary (use both) |
| FSDP | ✅ | Enables massive models |
| DDP | ✅ | Works normally |
| DeepSpeed ZeRO | ✅ | Can combine with CPU offload |
| Model parallelism | ✅ | Checkpoint per device |
| Quantization | ✅ | Checkpoint quantized activations |
9. Further Reading¶
- Gradient checkpointing original paper: Chen et al., Training Deep Nets with Sublinear Memory Cost (2016)
- PyTorch documentation: torch.utils.checkpoint
- Flash Attention: Dao, FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning (2023)
- DeepSpeed activation checkpointing: DeepSpeed docs
- Megatron-LM: Narayanan et al., Efficient Large-Scale Language Model Training on GPU Clusters Using Megatron-LM (2021)