Unboxing LLMs

July 3, 2023

Activation Checkpointing: Trading Computation for Memory in Deep Learning


Introduction

Modern deep learning models have grown to incredible sizes – from millions to billions of parameters. Training these massive models presents a significant challenge: GPU memory consumption.

During training, neural networks need to store not just model parameters but also intermediate activation values for backpropagation, creating a memory bottleneck that can prevent researchers from training larger models or using reasonable batch sizes.

[Figure: Standard Training (High Memory)]

Activation Checkpointing (also called Gradient Checkpointing) is an elegant solution to this challenge. Rather than storing all intermediate activations during the forward pass, this technique strategically saves only a subset at “checkpoint” locations, then regenerates the others when needed during backpropagation.

This approach trades additional computation (recomputing activations) for significant memory savings. For many large-scale models, such as transformers, the technique has become essential, enabling training that would otherwise be impossible on available hardware.

[Figure: With Checkpointing (Lower Memory)]

How Activation Checkpointing Works

The Memory Problem in Backpropagation

To understand activation checkpointing, we first need to recognize why training neural networks is memory-intensive:

  1. Forward Pass: As input data flows through the network, each layer produces intermediate activation outputs.
  2. Storage Requirement: These activations must be stored in memory because they’re needed to compute gradients during backpropagation.
  3. Memory Bottleneck: For deep networks, storing all these activations can consume gigabytes of GPU memory, especially with large batch sizes.

For a network with L layers and batch size N, memory consumption for activations grows as O(N × L), potentially exceeding available GPU memory for large models.
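To make the O(N × L) scaling concrete, here is a back-of-the-envelope estimator. It assumes, purely for illustration, that every layer stores the same number of activation elements per sample; the function name and the example sizes are hypothetical.

```python
def activation_bytes(batch_size, num_layers, elems_per_layer, bytes_per_elem=4):
    """Rough activation memory for backprop without checkpointing.

    Assumes each of the num_layers layers keeps elems_per_layer activation
    elements per sample, each taking bytes_per_elem bytes (4 for FP32).
    """
    return batch_size * num_layers * elems_per_layer * bytes_per_elem


GIB = 2 ** 30
# Hypothetical size: 2**20 activation elements per sample per layer
print(activation_bytes(8, 48, 2 ** 20) / GIB)   # 1.5
print(activation_bytes(16, 48, 2 ** 20) / GIB)  # doubling N doubles memory: 3.0
print(activation_bytes(8, 96, 2 ** 20) / GIB)   # doubling L doubles memory: 3.0
```

The product form is the whole point: memory is linear in both batch size and depth, so neither can grow without the other shrinking once the GPU is full.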

Core Mechanics of Checkpointing

Activation checkpointing addresses this problem through a clever strategy:

  1. Strategic Segmentation
    The network is divided into segments, with activation values stored only at segment boundaries (“checkpoints”).
  2. Selective Storage During Forward Pass
    • Only checkpoint activations are retained in memory
    • Intermediate activations between checkpoints are computed and then discarded
  3. Recomputation During Backward Pass
    • When backpropagation reaches a segment, the forward pass for that segment is recomputed
    • This regenerates the previously discarded intermediate activations
    • Gradients are then calculated using these recomputed values
[Figure: Sequence Diagram]

This approach dramatically reduces peak memory usage at the cost of additional forward-pass computation.
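The three steps above can be sketched in plain Python. This is a toy, hypothetical illustration rather than how any framework implements it: scalar functions stand in for layers and activation tensors, each layer carries its own derivative, and "memory" is counted as the number of values held at once.

```python
import math

# Toy "network": a chain of scalar layers, each given as (f, df/dx).
# Scalars stand in for activation tensors so the bookkeeping stays visible.
LAYERS = [
    (math.sin, math.cos),
    (lambda x: x * x, lambda x: 2 * x),
    (math.tanh, lambda x: 1 - math.tanh(x) ** 2),
] * 4  # 12 layers


def grad_full(x):
    """Standard backprop: keep every layer's input from the forward pass."""
    inputs = []
    for f, _ in LAYERS:
        inputs.append(x)
        x = f(x)
    grad = 1.0
    for (_, df), xi in zip(reversed(LAYERS), reversed(inputs)):
        grad *= df(xi)  # chain rule, last layer first
    return x, grad, len(inputs)


def grad_checkpointed(x, segment):
    """Keep only segment-boundary inputs; recompute the rest during backward."""
    checkpoints = {}
    for i, (f, _) in enumerate(LAYERS):
        if i % segment == 0:
            checkpoints[i] = x  # checkpoint at a segment boundary
        x = f(x)  # all other intermediate values are discarded
    out, grad, peak = x, 1.0, len(checkpoints)
    for start in sorted(checkpoints, reverse=True):
        end = min(start + segment, len(LAYERS))
        h = checkpoints.pop(start)  # not needed again after this segment
        seg_inputs = []
        for f, _ in LAYERS[start:end]:  # recompute this segment's forward pass
            seg_inputs.append(h)
            h = f(h)
        peak = max(peak, len(checkpoints) + len(seg_inputs))
        for (_, df), xi in zip(reversed(LAYERS[start:end]), reversed(seg_inputs)):
            grad *= df(xi)  # gradients from the recomputed values
    return out, grad, peak
```

With 12 layers and segments of 4, grad_checkpointed(1.2, 4) holds at most 6 values instead of 12 and returns the same output and gradient as grad_full(1.2), at the price of re-running each segment's forward pass once.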

The Memory-Computation Tradeoff

Let’s quantify the benefits and costs:

  • Without Checkpointing:
    • Memory usage for activations: O(N × L)
    • Computation cost: 1 forward pass + 1 backward pass
  • With Checkpointing:
    • Memory usage: Can be reduced to O(N × √L) by using √L checkpoints
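    • Computation cost: 1 forward pass + 1 extra forward pass (recomputation during backprop) + 1 backward pass
    • Computational overhead: the forward work doubles; since a backward pass typically costs about twice a forward pass, total compute per training step rises by roughly 33%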
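Where the √L figure comes from, as a short derivation under simplifying assumptions: every layer's activations take the same memory a per sample, and the L layers are split into k equal segments. During the backward pass the k checkpoints stay resident while one segment of L/k layers is recomputed at a time, so

```latex
M(k) \approx N a \left( k + \frac{L}{k} \right),
\qquad
\frac{dM}{dk} = N a \left( 1 - \frac{L}{k^{2}} \right) = 0
\;\Longrightarrow\; k^{*} = \sqrt{L},
\qquad
M(k^{*}) \approx 2 N a \sqrt{L} = O\!\left( N \sqrt{L} \right).
```

Each segment is recomputed exactly once, which is where the single extra forward pass in the cost comparison comes from.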
[Figure: Memory Usage Comparison]

This tradeoff is highly favorable for memory-constrained scenarios, where the alternative would be to reduce batch size or model size.

Optimal Checkpoint Placement

The placement of checkpoints significantly impacts both memory savings and computational overhead. Three common strategies are:

  1. Uniform Checkpointing
    Place checkpoints at equal intervals throughout the network. With √L evenly spaced checkpoints, memory complexity reduces to O(N × √L).
  2. Nested Checkpointing
    Apply checkpointing recursively, further reducing memory requirements at the cost of more recomputation.
  3. Selective Checkpointing
    Place checkpoints at layers with the largest activation memory footprints.
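The √L rule for uniform checkpointing can be checked numerically. This sketch assumes equal-size layer activations and counts memory in units of one layer's activations: the checkpoints stay resident while one segment is rematerialized. peak_units and best_segment are hypothetical helper names.

```python
import math


def peak_units(num_layers, segment):
    """Approximate peak stored activations under uniform checkpointing.

    Assumes equal-size layer activations, a checkpoint every `segment`
    layers, and a backward pass that rematerializes one segment at a time.
    """
    num_checkpoints = math.ceil(num_layers / segment)
    return num_checkpoints + segment


def best_segment(num_layers):
    """Brute-force the segment length that minimizes peak memory."""
    return min(range(1, num_layers + 1), key=lambda s: peak_units(num_layers, s))
```

For example, best_segment(16) returns 4 (a peak of 8 units rather than the 16 needed to store every layer), and best_segment(100) returns 10; in both cases the optimum is √L.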
[Figure: Uniform Checkpointing]

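Research such as Checkmate (Jain et al.) has shown that optimal checkpoint placement is model-specific and can be found by solving an optimization problem over the computation graph. For transformer architectures, checkpointing at transformer-block boundaries is a common practice that balances memory savings and simplicity.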

PyTorch

PyTorch provides native support for activation checkpointing through the torch.utils.checkpoint module:

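import torch
import torch.nn as nn
import torch.utils.checkpoint as checkpoint

class FeedForward(nn.Module):
    # Position-wise MLP used inside each transformer block
    def __init__(self, dim, mult=4):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, mult * dim),
            nn.GELU(),
            nn.Linear(mult * dim, dim),
        )

    def forward(self, x):
        return self.net(x)

class CheckpointedTransformerBlock(nn.Module):
    def __init__(self, dim, heads):
        super().__init__()
        # nn.MultiheadAttention stands in for a custom SelfAttention module
        self.attention = nn.MultiheadAttention(dim, heads, batch_first=True)
        self.feed_forward = FeedForward(dim)
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)

    def _forward(self, x):
        # Regular pre-norm transformer block implementation
        h = self.norm1(x)
        attn_out, _ = self.attention(h, h, h, need_weights=False)
        x = x + attn_out
        x = x + self.feed_forward(self.norm2(x))
        return x

    def forward(self, x):
        # Only the block input is saved; activations inside _forward are
        # discarded and recomputed during the backward pass
        return checkpoint.checkpoint(self._forward, x, use_reentrant=False)

# Usage in a transformer model
class MemoryEfficientTransformer(nn.Module):
    def __init__(self, dim, depth, heads):
        super().__init__()
        self.layers = nn.ModuleList([
            CheckpointedTransformerBlock(dim, heads)
            for _ in range(depth)
        ])

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x

model = MemoryEfficientTransformer(dim=64, depth=4, heads=4)
x = torch.randn(2, 16, 64)  # (batch, sequence, dim)
model(x).sum().backward()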

TensorFlow/Keras

TensorFlow offers similar functionality through its tf.recompute_grad decorator:

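import tensorflow as tf

# Weights live outside the checkpointed function
w1 = tf.Variable(tf.random.normal([512, 512]))
w2 = tf.Variable(tf.random.normal([512, 512]))

@tf.recompute_grad
def checkpointed_layer(x):
    # The intermediate h is not kept for backprop; it is recomputed
    # from x when gradients are requested
    h = tf.nn.relu(tf.matmul(x, w1))
    return tf.matmul(h, w2)

# Usage in a model
def model(x):
    # ... other layers
    x = checkpointed_layer(x)
    # ... more layers
    return x

x = tf.random.normal([8, 512])
with tf.GradientTape() as tape:
    loss = tf.reduce_sum(model(x))
grads = tape.gradient(loss, [w1, w2])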

HuggingFace Transformers

The HuggingFace Transformers library exposes gradient checkpointing as a built-in switch on its models, which is widely used when training large language models:

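from transformers import GPT2LMHeadModel

# Load pretrained GPT-2 weights
model = GPT2LMHeadModel.from_pretrained("gpt2")

# Enable gradient checkpointing: each transformer block's activations
# are recomputed during the backward pass instead of being stored.
# With the Trainer API, TrainingArguments(gradient_checkpointing=True) does the same.
model.gradient_checkpointing_enable()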

Real-World Impact: Case Studies

Training GPT-3 Scale Models

The 175B-parameter GPT-3 model would be practically untrainable without activation checkpointing. With standard techniques, just storing the activations would require over 1TB of GPU memory. Checkpointing reduced this requirement to a manageable level that fit within a distributed training setup.

Vision Transformers (ViT)

For large-scale vision transformers processing high-resolution images, activation checkpointing has enabled training of deeper architectures that would otherwise be memory-prohibitive, especially when fine-tuning with limited hardware.

Benchmarks and Performance Metrics

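| Model Size | Batch Size | Memory Without Checkpointing | Memory With Checkpointing | Computational Overhead |
|---|---|---|---|---|
| 125M params | 32 | 16 GB | 8 GB | ~30% |
| 1.5B params | 16 | Out of memory (24 GB GPU) | 20 GB | ~35% |
| 6B params | 8 | Out of memory (40 GB GPU) | 32 GB | ~40% |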

These numbers vary by model architecture and implementation, but illustrate the typical memory savings and computational trade-offs.

Advanced Techniques and Optimizations

Combining with Mixed Precision Training

Activation checkpointing pairs exceptionally well with mixed precision training (FP16/BF16). The combination of these techniques can provide:
– Memory reduction from checkpointing.
– Memory reduction from 16-bit storage.
– Computational speedups from 16-bit math*.
– Ability to train significantly larger models.
*: The speedup is not guaranteed. Mixed-precision recipes keep numerically sensitive operations (such as reductions, softmax, and normalization) in FP32, and large gains depend on GPUs with 16-bit Tensor Core support.

Selective Activation Recomputation

Not all parts of a network are equally expensive to recompute. Sophisticated implementations selectively checkpoint based on:
– Computational complexity of layers
– Memory footprint of activations
– Critical path analysis in the computation graph
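These criteria can be turned into a simple policy. The following is a hypothetical greedy heuristic, not the algorithm of any particular framework: it drops the saved activations with the highest memory-to-recompute-cost ratio first, until the remaining stored activations fit a memory budget. The op names and sizes are made up for illustration, and every op is assumed to have a nonzero recompute cost.

```python
from dataclasses import dataclass


@dataclass
class Op:
    name: str
    act_mb: float  # memory held by this op's saved activation (hypothetical)
    gflops: float  # cost to recompute this op (hypothetical, must be > 0)


def choose_recompute(ops, budget_mb):
    """Greedily pick ops to recompute, most memory per unit of compute first.

    Returns (names to recompute, MB still stored). If the budget cannot be
    met, every op ends up in the recompute list.
    """
    stored = sum(op.act_mb for op in ops)
    recompute = []
    for op in sorted(ops, key=lambda o: o.act_mb / o.gflops, reverse=True):
        if stored <= budget_mb:
            break
        recompute.append(op.name)
        stored -= op.act_mb
    return recompute, stored


ops = [
    Op("qkv_proj", act_mb=96, gflops=50),
    Op("attn_scores", act_mb=512, gflops=1),
    Op("softmax", act_mb=512, gflops=0.5),
    Op("mlp", act_mb=128, gflops=100),
]
print(choose_recompute(ops, budget_mb=300))  # (['softmax', 'attn_scores'], 224)
```

In this made-up example the policy recomputes the two large-but-cheap attention intermediates and keeps the matmul outputs, which are expensive to regenerate.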

Integration with Model Parallelism

For extremely large models trained across many GPUs, checkpointing is typically combined with model parallelism:
– Tensor parallelism splits individual operations across devices
– Pipeline parallelism splits the model into stages
– Activation checkpointing works within each GPU’s portion of the model

Alternatives and Complementary Approaches

While activation checkpointing is powerful, it’s often used alongside other memory optimization techniques:

  1. Activation Offloading: Moving activations to CPU memory temporarily
  2. Gradient Accumulation: Splitting batches into micro-batches to reduce memory requirements
  3. Reversible Layers: Special layer designs whose inputs can be reconstructed exactly from their outputs, so activations need not be stored
  4. CPU Offloading: Storing model parameters or optimizer states in CPU memory
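The reversibility idea can be illustrated with the additive coupling used by RevNet-style reversible networks. This is a minimal sketch: F and G stand in for arbitrary sub-networks, and make_reversible_block is a hypothetical helper. Note that the inverse evaluates F and G again, so reversible layers save memory by not storing activations rather than by avoiding computation.

```python
def make_reversible_block(F, G):
    """Additive coupling: (x1, x2) -> (y1, y2), invertible for any F and G."""

    def forward(x1, x2):
        y1 = x1 + F(x2)
        y2 = x2 + G(y1)
        return y1, y2

    def inverse(y1, y2):
        x2 = y2 - G(y1)  # undo the second update first
        x1 = y1 - F(x2)  # then the first
        return x1, x2

    return forward, inverse


forward, inverse = make_reversible_block(lambda v: 3 * v * v, lambda v: 2 * v + 1)
y1, y2 = forward(4, 5)
print(y1, y2)           # 79 164
print(inverse(y1, y2))  # (4, 5)
```

Integers are used here so the reconstruction is exact; with floating-point tensors it is exact only up to rounding error.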

When to Use Activation Checkpointing

Consider activation checkpointing when:

  • You’re training very deep networks with many layers
  • Your model is generating “CUDA out of memory” errors
  • You need to increase batch size for better training dynamics
  • You have computation capacity to spare but limited GPU memory
  • You’re working with transformer-based architectures

Avoid activation checkpointing when:

  • Training speed is the primary concern and memory is abundant
  • Your model is already computationally bottlenecked
  • The model is small enough to fit comfortably in memory

Conclusion

Activation checkpointing is an important algorithmic innovation enabling the current era of LLMs and deep neural networks. By trading computation for memory in a controlled manner, it has extended the frontier of what’s possible in DL research and applications.

As models continue to grow in size and complexity, expect checkpointing techniques to evolve further, with more sophisticated strategies for balancing memory usage and computational efficiency.

References and Further Reading

  • Chen et al. (2016), “Training Deep Nets with Sublinear Memory Cost”
  • Gruslys et al. (2016), “Memory-Efficient Backpropagation Through Time”
  • Jain et al. (2020), “Checkmate: Breaking the Memory Wall with Optimal Tensor Rematerialization”
  • PyTorch Documentation: torch.utils.checkpoint
  • TensorFlow Documentation: tf.recompute_grad
Posted in AI / ML, LLM Intermediate