Activation Checkpointing: Trading Computation for Memory in Deep Learning
Introduction
Modern deep learning models have grown to incredible sizes – from millions to billions of parameters. Training these massive models presents a significant challenge: GPU memory consumption.
During training, neural networks need to store not just model parameters but also intermediate activation values for backpropagation, creating a memory bottleneck that can prevent researchers from training larger models or using reasonable batch sizes.

Activation Checkpointing (also called Gradient Checkpointing) is an elegant solution to this challenge. Rather than storing all intermediate activations during the forward pass, this technique strategically saves only a subset at “checkpoint” locations, then regenerates the others when needed during backpropagation.
This approach offers a compelling tradeoff: trading additional computation (recomputing activations) for significant memory savings. For many large-scale models like transformers, this technique has become essential, enabling training that would otherwise be impossible on available hardware.

How Activation Checkpointing Works
The Memory Problem in Backpropagation
To understand activation checkpointing, we first need to recognize why training neural networks is memory-intensive:
- Forward Pass: As input data flows through the network, each layer produces intermediate activation outputs.
- Storage Requirement: These activations must be stored in memory because they’re needed to compute gradients during backpropagation.
- Memory Bottleneck: For deep networks, storing all these activations can consume gigabytes of GPU memory, especially with large batch sizes.
For a network with L layers and batch size N (and roughly constant layer width), memory consumption for activations grows as O(N × L), which can easily exceed available GPU memory for large models.
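To make that scaling concrete, here is a back-of-the-envelope estimate (a sketch with hypothetical layer sizes; it counts only one stored tensor per layer, while real layers keep several intermediate tensors, so actual usage is higher):

```python
# Back-of-the-envelope activation memory for a transformer-like stack
# (hypothetical sizes; counts only one hidden-state tensor per layer, FP32)
batch_size = 32       # N
seq_len = 2048
num_layers = 48       # L
hidden_dim = 4096
bytes_per_value = 4   # FP32

per_layer = batch_size * seq_len * hidden_dim * bytes_per_value
total = per_layer * num_layers
print(f"Per layer:  {per_layer / 1024**3:.1f} GiB")   # ~1 GiB
print(f"All layers: {total / 1024**3:.1f} GiB")       # ~48 GiB
```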
Core Mechanics of Checkpointing
Activation checkpointing addresses this problem through a clever strategy:
- Strategic Segmentation: The network is divided into segments, with activation values stored only at segment boundaries (“checkpoints”).
- Selective Storage During Forward Pass:
  - Only checkpoint activations are retained in memory
  - Intermediate activations between checkpoints are computed and then discarded
- Recomputation During Backward Pass:
  - When backpropagation reaches a segment, the forward pass for that segment is recomputed
  - This regenerates the previously discarded intermediate activations
  - Gradients are then calculated using these recomputed values

This approach dramatically reduces peak memory usage at the cost of performing additional forward pass computation.
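A minimal PyTorch sketch of this behavior, assuming a toy two-segment stack of linear layers; torch.utils.checkpoint.checkpoint handles the discard-and-recompute logic for each wrapped segment:

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

# Toy network split into two segments; only the segment boundaries
# (the inputs to each checkpointed call) are kept during the forward pass.
segment1 = nn.Sequential(nn.Linear(512, 512), nn.ReLU(), nn.Linear(512, 512), nn.ReLU())
segment2 = nn.Sequential(nn.Linear(512, 512), nn.ReLU(), nn.Linear(512, 10))

x = torch.randn(32, 512, requires_grad=True)

# Forward: activations inside each segment are computed, used, and discarded
h = checkpoint(segment1, x, use_reentrant=False)
out = checkpoint(segment2, h, use_reentrant=False)

# Backward: each segment's forward pass is rerun to regenerate the
# discarded activations just before its gradients are computed
out.sum().backward()
```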
The Memory-Computation Tradeoff
Let’s quantify the benefits and costs:
- Without Checkpointing:
  - Memory usage for activations: O(N × L)
  - Computation cost: 1 forward pass + 1 backward pass
- With Checkpointing:
  - Memory usage: can be reduced to O(N × √L) by using √L checkpoints
  - Computation cost: 1 forward pass + (at most 1 extra forward pass of recomputation during backprop) + 1 backward pass
  - Computational overhead: at most one additional full forward pass, i.e., roughly a 33% increase in total compute if the backward pass costs about twice the forward pass (up to 50% if forward and backward cost the same)

This tradeoff is highly favorable for memory-constrained scenarios, where the alternative would be to reduce batch size or model size.
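As a concrete, hypothetical worked example of the √L figure:

```python
import math

L = 64                       # layers (hypothetical)
full = L                     # activation "units" per sample without checkpointing

k = int(math.sqrt(L))        # number of checkpoints: 8
segment = L // k             # layers recomputed per segment: 8
with_ckpt = k + segment      # checkpoints kept + one segment's activations at a time

print(full, with_ckpt)       # 64 vs 16 -> roughly 4x less activation memory
# Cost: at most one extra forward pass (~33% more total compute if backward ~= 2x forward)
```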
Optimal Checkpoint Placement
The placement of checkpoints significantly impacts both memory savings and computational overhead. Three common strategies are:
- Uniform Checkpointing: Place checkpoints at equal intervals throughout the network. With √L evenly spaced checkpoints, memory complexity reduces to O(N × √L).
- Nested Checkpointing: Apply checkpointing recursively, further reducing memory requirements at the cost of more recomputation.
- Selective Checkpointing: Place checkpoints at layers with the largest activation memory footprints.

Research has shown that optimal checkpoint placement is often model-specific. For transformer architectures, checkpointing transformer blocks is a common practice that balances efficiency and simplicity.
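For purely sequential models, PyTorch’s torch.utils.checkpoint.checkpoint_sequential applies the uniform strategy directly: you choose the number of segments (for example, roughly √L) and it places the checkpoints at equal intervals. A minimal sketch with a toy layer stack:

```python
import math
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint_sequential

L = 16  # hypothetical depth
model = nn.Sequential(*[nn.Sequential(nn.Linear(256, 256), nn.ReLU()) for _ in range(L)])

x = torch.randn(8, 256, requires_grad=True)
segments = int(math.sqrt(L))  # ~sqrt(L) uniform segments -> O(N * sqrt(L)) activation memory

out = checkpoint_sequential(model, segments, x, use_reentrant=False)
out.sum().backward()
```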
Implementation in Popular Frameworks
PyTorch
PyTorch provides native support for activation checkpointing through the torch.utils.checkpoint module:
```python
import torch
import torch.nn as nn
import torch.utils.checkpoint as checkpoint

class CheckpointedTransformerBlock(nn.Module):
    def __init__(self, dim, heads):
        super().__init__()
        self.attention = nn.MultiheadAttention(dim, heads, batch_first=True)
        self.feed_forward = nn.Sequential(
            nn.Linear(dim, 4 * dim),
            nn.GELU(),
            nn.Linear(4 * dim, dim),
        )
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)

    def _forward(self, x):
        # Regular pre-norm transformer block implementation
        h = self.norm1(x)
        attn_out, _ = self.attention(h, h, h)
        x = x + attn_out
        x = x + self.feed_forward(self.norm2(x))
        return x

    def forward(self, x):
        # Activations inside _forward are discarded after the forward pass
        # and recomputed during backpropagation
        return checkpoint.checkpoint(self._forward, x, use_reentrant=False)

# Usage in a transformer model
class MemoryEfficientTransformer(nn.Module):
    def __init__(self, dim, depth, heads):
        super().__init__()
        self.layers = nn.ModuleList([
            CheckpointedTransformerBlock(dim, heads)
            for _ in range(depth)
        ])

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x
```
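A quick usage sketch with hypothetical sizes; the recomputation happens transparently inside backward():

```python
model = MemoryEfficientTransformer(dim=512, depth=12, heads=8)
x = torch.randn(4, 128, 512)  # (batch, seq_len, dim)

loss = model(x).mean()
loss.backward()  # each block's _forward is rerun here to rebuild its activations
```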
TensorFlow/Keras
TensorFlow offers similar functionality through its tf.recompute_grad decorator:
```python
import tensorflow as tf

dense = tf.keras.layers.Dense(1024, activation="relu")

@tf.recompute_grad
def checkpointed_layer(x):
    # Activations produced here are recomputed during the backward pass
    return dense(x)

# Usage in a model
def model(x):
    # ... other layers
    x = checkpointed_layer(x)
    # ... more layers
    return x
```
HuggingFace Transformers
The popular HuggingFace library leverages checkpointing for training large language models:
```python
from transformers import GPT2Config, GPT2LMHeadModel

# Enable gradient checkpointing in the config
config = GPT2Config.from_pretrained('gpt2')
config.gradient_checkpointing = True

# Create model with checkpointing enabled
model = GPT2LMHeadModel(config)
```
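In more recent versions of transformers, the config flag is deprecated in favor of a model method:

```python
from transformers import GPT2LMHeadModel

model = GPT2LMHeadModel.from_pretrained('gpt2')
model.gradient_checkpointing_enable()  # turn on activation checkpointing for supported blocks
```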
Real-World Impact: Case Studies
Training GPT-3 Scale Models
The 175B-parameter GPT-3 model would be practically untrainable without activation checkpointing. With standard techniques, just storing the activations would require over 1TB of GPU memory. Checkpointing reduced this requirement to a manageable level that fit within a distributed training setup.
Vision Transformers (ViT)
For large-scale vision transformers processing high-resolution images, activation checkpointing has enabled training of deeper architectures that would otherwise be memory-prohibitive, especially when fine-tuning with limited hardware.
Benchmarks and Performance Metrics
| Model Size | Batch Size | Memory Without Checkpointing | Memory With Checkpointing | Computational Overhead |
|---|---|---|---|---|
| 125M params | 32 | 16 GB | 8 GB | ~30% |
| 1.5B params | 16 | Out of memory (24 GB GPU) | 20 GB | ~35% |
| 6B params | 8 | Out of memory (40 GB GPU) | 32 GB | ~40% |
These numbers vary by model architecture and implementation, but illustrate the typical memory savings and computational trade-offs.
Advanced Techniques and Optimizations
Combining with Mixed Precision Training
Activation checkpointing pairs exceptionally well with mixed precision training (FP16/BF16). The combination of these techniques can provide:
- Memory reduction from checkpointing
- Memory reduction from storing activations in 16 bits
- Computational speedups from 16-bit math*
- Ability to train significantly larger models

*The speedup is not guaranteed: on CUDA, some operations still run in FP32 for numerical stability, although low-precision support continues to expand in newer NVIDIA hardware and software releases.
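A minimal sketch combining the two in PyTorch, reusing the MemoryEfficientTransformer defined earlier and assuming a CUDA device; torch.cuda.amp handles the reduced-precision casting while checkpointing handles the recomputation:

```python
import torch

# Reuses the MemoryEfficientTransformer class from the PyTorch section above
model = MemoryEfficientTransformer(dim=512, depth=12, heads=8).cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
scaler = torch.cuda.amp.GradScaler()

x = torch.randn(4, 128, 512, device='cuda')  # (batch, seq_len, dim)

optimizer.zero_grad()
with torch.cuda.amp.autocast():      # forward pass runs in reduced precision
    loss = model(x).mean()
scaler.scale(loss).backward()        # checkpointed blocks are recomputed here
scaler.step(optimizer)
scaler.update()
```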
Selective Activation Recomputation
Not all parts of a network are equally expensive to recompute. Sophisticated implementations selectively checkpoint based on factors such as the following (a sketch follows the list):
- Computational complexity of layers
- Memory footprint of activations
- Critical path analysis in the computation graph
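A simple sketch of this idea: wrap only the blocks flagged for recomputation (for example, those with the largest estimated activation footprints) and run the rest normally. The SelectivelyCheckpointed wrapper below is a hypothetical helper, not a library API:

```python
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class SelectivelyCheckpointed(nn.Module):
    """Hypothetical wrapper: checkpoint only the blocks flagged in checkpoint_flags."""

    def __init__(self, blocks, checkpoint_flags):
        super().__init__()
        self.blocks = nn.ModuleList(blocks)
        self.checkpoint_flags = checkpoint_flags  # one bool per block

    def forward(self, x):
        for block, use_ckpt in zip(self.blocks, self.checkpoint_flags):
            if use_ckpt:
                # Discard this block's internal activations; recompute them in backward
                x = checkpoint(block, x, use_reentrant=False)
            else:
                x = block(x)
        return x
```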
Integration with Model Parallelism
Checkpointing also matters when training is distributed. For extremely large models trained across multiple GPUs:
- Tensor parallelism splits individual operations across devices
- Pipeline parallelism splits the model into stages
- Activation checkpointing works within each GPU’s portion of the model
Alternatives and Complementary Approaches
While activation checkpointing is powerful, it’s often used alongside other memory optimization techniques:
- Activation Offloading: Moving activations to CPU memory temporarily
- Gradient Accumulation: Splitting batches into micro-batches to reduce memory requirements
- Reversible Layers: Special layer designs whose inputs can be reconstructed exactly from their outputs during the backward pass, so intermediate activations never need to be stored
- CPU Offloading: Storing model parameters or optimizer states in CPU memory
When to Use Activation Checkpointing
Consider activation checkpointing when:
- You’re training very deep networks with many layers
- Your model is generating “CUDA out of memory” errors
- You need to increase batch size for better training dynamics
- You have computation capacity to spare but limited GPU memory
- You’re working with transformer-based architectures
Avoid activation checkpointing when:
- Training speed is the primary concern and memory is abundant
- Your model is already computationally bottlenecked
- The model is small enough to fit comfortably in memory
Conclusion
Activation checkpointing is an important algorithmic innovation underpinning the current era of large language models and deep neural networks. By trading computation for memory in a controlled manner, it has extended the frontier of what’s possible in deep learning research and applications.
As models continue to grow in size and complexity, expect checkpointing techniques to evolve further, with more sophisticated strategies for balancing memory usage and computational efficiency.
References and Further Reading
- Chen et al., “Training Deep Nets with Sublinear Memory Cost”
- Gruslys et al., “Memory-Efficient Backpropagation Through Time”
- Jain et al., “Checkmate: Breaking the Memory Wall with Optimal Tensor Rematerialization”
- PyTorch Documentation: torch.utils.checkpoint
- TensorFlow Documentation: tf.recompute_grad