
November 29, 2023

Mastering GPU Memory Constraints for Large-Scale Model Training

1. Introduction

The ambition driving deep neural networks—particularly the behemoths we call Transformers—inevitably slams into the hard limits of physics and silicon. These models, swelling with billions of parameters, possess a voracious appetite for GPU memory. Even on the ostensibly high-end battlegrounds like NVIDIA A100 (80GB) or H100 GPUs, memory is often the primary bottleneck, the rude awakening when trying to wrangle massive models and the statistically satisfying, but physically demanding, large batch sizes required for stable training.

Let’s ground this in reality:

  • GPT-3, at 175 billion parameters, needs roughly 700GB for its FP32 weights alone (175B × 4 bytes). Add gradients and Adam's optimizer states and the persistent training state approaches 2.8TB. Forget fitting that on a single card, or even a modest cluster.
  • Even a relatively commonplace BERT-Large (340M parameters) needs ~1.3GB for its FP32 weights alone. Factor in gradients and an optimizer like AdamW, and the footprint grows to roughly 5.4GB before you’ve even considered the activations.

We aren’t entirely helpless against this onslaught. A toolkit of tactics exists to mitigate these memory pressures, essentially finding clever ways to sidestep the raw physical limitations:

  1. Gradient Accumulation: The art of pretending you have a bigger batch size than your memory can actually hold at once.
  2. Low Memory Optimization (LOMO): A more aggressive strategy questioning the necessity of hefty optimizer states, especially during fine-tuning.
  3. Mixed Precision Training: Trading numerical precision for memory, hoping the slight fuzziness doesn’t derail training.
  4. Activation Checkpointing: A brute-force trade: burn more compute cycles to re-calculate activations later, freeing up memory now.

This piece dives into the first two—Gradient Accumulation and LOMO. We’ll dissect how they attack different facets of the memory problem and offer concrete implementations for those navigating the choppy waters of large model training on resource-constrained hardware.


2. Sources of GPU Memory Usage

To fight the beast, you must first understand its anatomy. Where does all that precious GPU VRAM actually go during training? Four main culprits consume your memory budget:

2.1 Parameters

  • Model Weights: The actual learned knowledge. In Transformers, this means embedding tables, the intricate machinery of multi-head attention, feed-forward layers, layer norms – the whole apparatus.
  • The scale is often difficult to internalize. We throw around terms like ‘billions of parameters’ casually, but the physical storage cost is substantial:
    • BERT-Large: 340M parameters ≈ 1.3GB (FP32)
    • GPT-2 XL: 1.5B parameters ≈ 6GB (FP32)
    • Llama 2 70B: 70B parameters ≈ 280GB (FP32) – just for the weights.
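
The figures above are simply parameter count times bytes per value. A minimal sketch of that arithmetic follows; the helper name `weight_memory_gb` is made up for illustration, and 1 GB is taken as 1e9 bytes, matching the rounding used in this article.

```python
# Bytes per value for common training dtypes.
BYTES_PER_VALUE = {"fp32": 4, "fp16": 2, "bf16": 2}

def weight_memory_gb(num_params: float, dtype: str = "fp32") -> float:
    """Storage for the weights alone, in GB (1 GB = 1e9 bytes)."""
    return num_params * BYTES_PER_VALUE[dtype] / 1e9

# Parameter counts taken from the list above.
for name, n in [("BERT-Large", 340e6), ("GPT-2 (1.5B)", 1.5e9), ("Llama 2 70B", 70e9)]:
    print(f"{name}: {weight_memory_gb(n):.2f} GB fp32, "
          f"{weight_memory_gb(n, 'bf16'):.2f} GB bf16")
```

Halving the bytes per value halves the weight storage, which is why the choice of dtype matters before any other optimization is considered.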

2.2 Gradients

  • Backpropagation isn’t free. To know how to update each parameter, you need to calculate its gradient. This means storing a shadow tensor, identical in size to your model weights, holding these gradient values (unless you’re doing exotic on-the-fly fusion).
  • For every parameter p, you need memory for \nabla p. Effectively, this doubles the parameter footprint during the backward pass.
  • Our BERT-Large example? That’s another 1.3GB consumed solely by gradients.

2.3 Optimizer States

  • This is where things often get truly painful. Optimizers aren’t just applying gradients; many maintain their own internal state to guide the learning process.
  • Adam and AdamW: The workhorses of modern deep learning. They keep two auxiliary tensors per parameter (first-moment and second-moment estimates). That adds 2x the parameter size on top of parameters and gradients, bringing the persistent total to 4x the weights alone.
    • A 1B parameter model using Adam demands roughly:
      • 4GB (parameters, FP32)
      • 4GB (gradients, FP32)
      • 8GB (optimizer states, FP32)
      • Total: ~16GB before activations even enter the picture.
  • SGD: The venerable stochastic gradient descent. With momentum, it needs one extra tensor (1x parameter size). Plain SGD? No extra state, a spartan existence.
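
The accounting in this section can be folded into one small estimator. This is a sketch under the same assumptions as the bullets above: everything is stored in FP32, gradients match the weights in size, and the optimizer keeps zero, one, or two extra tensors per parameter. The function and dictionary names are illustrative, not from any library.

```python
# Extra per-parameter state tensors kept by each optimizer.
OPTIMIZER_STATE_TENSORS = {"adam": 2, "adamw": 2, "sgd_momentum": 1, "sgd": 0}

def training_state_gb(num_params, optimizer="adamw", bytes_per_value=4):
    """Persistent training state in GB, excluding activations."""
    weights = num_params * bytes_per_value / 1e9
    gradients = weights  # one gradient value per parameter
    optimizer_state = OPTIMIZER_STATE_TENSORS[optimizer] * weights
    return {
        "weights": weights,
        "gradients": gradients,
        "optimizer": optimizer_state,
        "total": weights + gradients + optimizer_state,
    }

for opt in ("adam", "sgd_momentum", "sgd"):
    print(opt, training_state_gb(1e9, opt))
```

For the 1B-parameter example, the Adam entry reproduces the 4 + 4 + 8 = 16GB breakdown listed above.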

2.4 Activations

  • The intermediate results calculated during the forward pass. These need to be kept around because they’re essential inputs for calculating gradients during the backward pass. Their memory footprint is sensitive to:

    • Batch size: More samples processed in parallel means more activations stored simultaneously.
    • Sequence length / Resolution: Longer sequences in NLP or higher-resolution images in vision dramatically increase activation size.
    • Model depth: Deeper networks inherently generate more intermediate activations.
  • For large Transformers, especially with long contexts, activations can easily become the dominant memory hog, dwarfing even parameters and optimizer states.

    • Training BERT-Large (seq len 512, batch 32) can devour 30-45GB just for activations.
  • The scaling here is often brutal: linear with batch size, but potentially quadratic with sequence length in attention mechanisms.
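
To make the quadratic term concrete, here is a rough sketch that counts only the attention score matrices (one seq_len × seq_len matrix per head, per layer, per sample). It is deliberately incomplete: it ignores every other activation, and fused attention kernels that never materialize the full score matrix would not pay this cost. The function name is hypothetical; the 16 heads and 24 layers correspond to BERT-Large.

```python
def attention_scores_gb(batch, heads, seq_len, layers, bytes_per_value=4):
    """Memory for the attention score matrices only, in GB (1 GB = 1e9 bytes)."""
    return batch * heads * seq_len * seq_len * layers * bytes_per_value / 1e9

base = attention_scores_gb(batch=32, heads=16, seq_len=512, layers=24)
double_batch = attention_scores_gb(batch=64, heads=16, seq_len=512, layers=24)
double_seq = attention_scores_gb(batch=32, heads=16, seq_len=1024, layers=24)

print(f"baseline: {base:.2f} GB")
print(f"2x batch -> {double_batch / base:.0f}x, 2x sequence length -> {double_seq / base:.0f}x")
```

Doubling the batch doubles this term, while doubling the sequence length quadruples it.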

2.5 Memory Usage Visualization

This diagram offers a conceptual breakdown of where the memory goes when training a large model with a typical Adam-style optimizer:

[Figure: GPU Memory Usage Breakdown]


3. Gradient Accumulation: Virtual Large Batches

3.1 How It Works

The standard training dance is: load batch, compute loss, backpropagate to get gradients, update weights. Repeat. Gradient Accumulation introduces a pause before the update step. It’s a simple, yet effective, trick:

  1. Decide on your desired effective batch size, B_{\textrm{eff}}. This is the batch size you wish you could run if memory were infinite.
  2. Divide this into k smaller micro-batches that do fit into memory.
  3. Process each micro-batch sequentially: forward pass, loss computation, backward pass to get gradients.
  4. Instead of updating weights immediately, add the gradients from the current micro-batch to a running sum of gradients.
  5. Only after processing all k micro-batches, perform a single optimizer step using the accumulated gradients. Then, zero out the accumulated gradients and start the next cycle.

You simulate the gradient statistics of a large batch without ever needing to hold that entire batch and its associated activations in memory simultaneously.
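
Why this works can be written out in one line. The sketch below assumes a mean-reduced loss and k equally sized micro-batches; the symbols b (micro-batch size), \ell_i (per-sample loss), and M_j (the j-th micro-batch) are introduced here for illustration and do not appear elsewhere in the article.

```latex
% B_eff = k * b, with micro-batches M_1, ..., M_k of size b each
\nabla_\theta L_{B_{\textrm{eff}}}
  = \frac{1}{B_{\textrm{eff}}} \sum_{i=1}^{B_{\textrm{eff}}} \nabla_\theta \ell_i
  = \frac{1}{k} \sum_{j=1}^{k} \left( \frac{1}{b} \sum_{i \in M_j} \nabla_\theta \ell_i \right)
  = \sum_{j=1}^{k} \nabla_\theta \left( \frac{L_{M_j}}{k} \right)
```

The last form is what the training loop computes: each micro-batch loss is divided by k before the backward pass, and the gradients are summed across micro-batches.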

3.2 Benefits

  • Simulates Larger Batches: Critical for training stability in scenarios demanding large batches (e.g., specific optimizers like LAMB, or large-scale pretraining where noise reduction is paramount).
  • Preserves Training Dynamics: The resulting gradients are (mathematically, barring floating-point nuances) identical to processing the full effective batch at once. Training behavior remains largely consistent.
  • Avoids Peak Memory Spikes: Your peak memory usage is dictated by the micro-batch size, not the much larger effective batch size.
  • Plays Nice with Distributed Training: Can be easily combined with data parallelism across multiple GPUs, allowing for heterogeneity in device memory capacities.
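
The "identical gradients" claim is easy to check numerically. The following is a small sketch on CPU with a toy linear model (the model, data, and sizes are arbitrary choices for illustration): it compares the gradient from one full batch of 16 against the gradient accumulated over 4 micro-batches of 4.

```python
import torch

torch.manual_seed(0)
model = torch.nn.Linear(8, 1)
x, y = torch.randn(16, 8), torch.randn(16, 1)
loss_fn = torch.nn.MSELoss()  # mean reduction over the batch

# Reference: gradient of the full batch in a single pass.
model.zero_grad()
loss_fn(model(x), y).backward()
full_grad = model.weight.grad.clone()

# Accumulated: 4 equal micro-batches, each loss scaled by 1 / accum_steps.
accum_steps = 4
model.zero_grad()
for xb, yb in zip(x.chunk(accum_steps), y.chunk(accum_steps)):
    (loss_fn(model(xb), yb) / accum_steps).backward()  # .grad is summed across calls
accum_grad = model.weight.grad.clone()

print("max abs difference:", (full_grad - accum_grad).abs().max().item())
```

The match relies on a mean-reduced loss and equal-sized micro-batches. Layers that use batch statistics, such as BatchNorm, only ever see the micro-batch and will not reproduce full-batch behavior exactly.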

3.3 Example Implementation

Here’s a conceptual PyTorch implementation illustrating the core logic:

import torch
# Assumed loss for illustration (classification); replace with the loss for your task
loss_function = torch.nn.CrossEntropyLoss()

def train_with_gradient_accumulation(model, optimizer, dataloader,
                                    accum_steps=4, device="cuda", max_norm=1.0):
    model.train()
    model.to(device) # Ensure model is on the correct device

    running_loss = 0.0
    global_step = 0

    # Initialize gradients to zero before the loop
    optimizer.zero_grad()

    for step, batch in enumerate(dataloader):
        # Assuming batch is a tuple or dict of tensors
        # Move data to the target device
        if isinstance(batch, (tuple, list)):
            inputs, targets = batch
            inputs = inputs.to(device)
            targets = targets.to(device)
        elif isinstance(batch, dict):
            # Handle dictionary-based batches (common with Hugging Face datasets)
            inputs = {k: v.to(device) for k, v in batch.items() if k != 'labels'}
            targets = batch['labels'].to(device)
        else:
            # Basic fallback if batch is just input tensors
            inputs = batch.to(device)
            targets = None # Adjust if targets are handled differently

        # Forward pass
        if isinstance(batch, dict):
             # Handle potential keyword arguments for model forward pass
            outputs = model(**inputs)
            # Adjust loss calculation if needed based on model output structure
            loss = loss_function(outputs.logits if hasattr(outputs, 'logits') else outputs, targets)
        elif targets is not None:
            outputs = model(inputs)
            loss = loss_function(outputs, targets)
        else:
            # No separate targets: the model must compute its own loss
            # (e.g., Hugging Face models that receive labels as part of their inputs)
            outputs = model(inputs)
            if not hasattr(outputs, 'loss') or outputs.loss is None:
                raise ValueError("Batch has no targets and the model output has no .loss")
            loss = outputs.loss


        # Normalize loss to average over the effective batch
        # Loss is accumulated over accum_steps, so average appropriately
        loss = loss / accum_steps

        # Backward pass computes gradients and accumulates them
        # (because optimizer.zero_grad() is only called after optimizer.step())
        loss.backward()

        running_loss += loss.item() # Track accumulated, normalized loss

        # Perform optimizer step only after accumulating gradients for accum_steps
        if (step + 1) % accum_steps == 0:
            # Optional: Gradient clipping applied to the accumulated gradients
            if max_norm is not None:
                 torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_norm)

            # Update parameters based on accumulated gradients
            optimizer.step()

            # Reset gradients for the next accumulation cycle
            optimizer.zero_grad()

            # Logging and metrics calculation
            global_step += 1
            if global_step % 100 == 0: # Log every 100 effective steps
                # running_loss sums (loss / accum_steps) over every micro-batch since the
                # last log, i.e. the sum of 100 per-effective-step mean losses.
                average_loss = running_loss / 100
                print(f"Effective Step {global_step}: Average Loss = {average_loss:.4f}")
                running_loss = 0.0 # Reset loss accumulator for the next logging interval

(Note: Ensure loss_function and data loading logic match your specific task and model.)

3.3.1 Gradient Accumulation Visualization

This sequence diagram contrasts standard updates with the gradient accumulation flow:

[Figure: sequence diagram, standard updates vs. gradient accumulation]

3.4 Calculating Memory Savings

Let’s revisit the 1B parameter model scenario, aiming for an effective batch size of 128 on a 24GB GPU, where processing a single example costs roughly 0.4GB in activations and temporary storage (excluding persistent model/optimizer states).

Without gradient accumulation:

  • Memory for batch 128 activations/temps: 128 samples * 0.4 GB/sample = 51.2GB. This alone exceeds the 24GB available, even before adding model/optimizer state memory. ❌

With gradient accumulation (e.g., micro-batch size 16):

  • Accumulation steps needed: k = 128 / 16 = 8 steps.
  • Peak memory for activations/temps per micro-batch: 16 samples * 0.4 GB/sample = 6.4GB.
  • This 6.4GB plus the ~16GB of FP32 parameters, gradients, and Adam states from Section 2.3 comes to about 22.4GB, which fits within the 24GB budget with little headroom. ✓

Gradient accumulation trades a bit of wall-clock time (due to sequential processing) for significantly lower peak memory usage.
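
The calculation above generalizes into a small planning helper. This is a sketch, not a profiler: it assumes activation memory grows linearly with micro-batch size, takes the persistent state (16GB for the 1B-parameter FP32 Adam example from Section 2.3) as given, and ignores allocator fragmentation and framework overhead. The function name and signature are hypothetical.

```python
def plan_accumulation(effective_batch, per_sample_gb, state_gb, budget_gb):
    """Return (micro_batch, accum_steps) for the largest micro-batch that fits
    in memory and divides the effective batch size evenly."""
    headroom_gb = budget_gb - state_gb           # what is left for activations
    max_fit = int(headroom_gb / per_sample_gb)   # samples that fit at once
    if max_fit < 1:
        raise ValueError("Not even one sample fits next to the persistent state.")
    micro_batch = max(
        d for d in range(1, min(max_fit, effective_batch) + 1)
        if effective_batch % d == 0
    )
    return micro_batch, effective_batch // micro_batch

micro_batch, accum_steps = plan_accumulation(
    effective_batch=128, per_sample_gb=0.4, state_gb=16, budget_gb=24
)
print(f"micro-batch {micro_batch}, accumulate {accum_steps} steps")
```

In practice it is worth leaving a safety margin below the nominal budget, since measured peak memory usually exceeds a back-of-the-envelope estimate.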


4. Low Memory Optimization (LOMO)

4.1 Motivation and Background

Fine-tuning large pretrained models presents its own memory headaches. While you’re not training from scratch, you still need decent batch sizes for stable adaptation, and the sheer size of the base model dominates memory. Techniques like LoRA help by only training small adapters, but what if you need to fine-tune more parameters, or even the whole model?

This is where optimizers like Adam become a major liability. Their memory overhead (those first and second moment buffers) is substantial. The LOMO paper (“Full Parameter Fine-tuning for Large Language Models with Limited Resources”, Lv et al., 2023) makes a provocative argument: for fine-tuning, maybe we don’t need the sophisticated adaptive machinery of Adam. The model is already in a good region of the loss landscape; perhaps plain old SGD is sufficient.

LOMO pushes this idea further by proposing not just using SGD (eliminating optimizer state memory), but also “fusing” the gradient calculation and parameter update into a single conceptual step, minimizing the time gradients need to live in memory.

4.2 Core Principles

  1. Eliminate Optimizer States: The most direct saving. By switching from Adam/AdamW to plain SGD, you immediately reclaim the memory used for the moment estimates (typically 2x the parameter size). This assumes that for fine-tuning, the benefits of adaptive gradients are less critical than during pretraining.
  2. Fused Gradient Computation and Update: Instead of calculating all gradients in a full backward pass and then applying the optimizer step, LOMO aims to compute the gradient for a parameter (or layer) and immediately apply the SGD update (param = param - lr * grad). Once updated, the gradient is discarded. This avoids holding the full set of gradients in memory simultaneously.
  3. Fine-Tuning Context: The core assumption is that pretrained models are less likely to be stuck in problematic saddle points or plateaus where Adam’s adaptivity shines. Simple SGD, possibly with a well-tuned learning rate schedule, can suffice for the relatively small adjustments needed during fine-tuning.

4.2.1 LOMO Process Visualization

This flowchart highlights the conceptual difference:

[Figure: Standard Adam Optimization flowchart]

The key is minimizing the resident time of gradients and eliminating the optimizer state buffers entirely.

4.3 Memory Comparison

Let’s consider a 7B parameter model (like Llama 2 7B), focusing on the persistent state memory (activations are separate and depend on batch/sequence length):

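| Component        | AdamW (FP32) | SGD w/ Momentum (FP32) | LOMO (SGD, FP32)    |
|------------------|--------------|------------------------|---------------------|
| Parameters       | 28GB         | 28GB                   | 28GB                |
| Gradients (Peak) | 28GB         | 28GB                   | ~Small (Layer-wise) |
| Optimizer States | 56GB         | 28GB                   | 0GB                 |
| Subtotal (State) | 112GB        | 84GB                   | ~28GB + LayerGrad   |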

Note: LOMO’s peak gradient memory is much smaller as only one layer’s (or block’s) gradients exist at a time before being used and discarded. The total state memory reduction compared to AdamW is dramatic.

4.4 Implementation Example

Achieving the true fused update means hooking into the autograd system so that each parameter is updated the moment its gradient is ready, inside the backward pass. A highly simplified Python sketch that illustrates the update rule (but not the memory saving) might look like this:

import torch
# Assume loss_function is defined elsewhere

# !!! WARNING: This is a conceptual illustration ONLY. !!!
# !!! It applies the SGD update AFTER a full backward pass, so every gradient !!!
# !!! is still materialized at once and no memory is saved. Use the official LOMO repo. !!!
def lomo_training_step_conceptual(model, inputs, targets, learning_rate=1e-5, device="cuda"):
    model.train()
    model.to(device)

    # Move data to device (simplified)
    inputs = inputs.to(device)
    targets = targets.to(device)

    # 1. Forward pass to compute loss
    outputs = model(inputs)
    loss = loss_function(outputs, targets) # Assume targets is not None here

    # 2. Conceptual "fused" backward pass and update
    # Clear any stale gradients first
    for param in model.parameters():
        if param.grad is not None:
            param.grad.zero_() # Or set to None

    # Trigger backward pass to compute gradients layer by layer (conceptually)
    loss.backward() # Standard backward computes all gradients

    # In a real implementation, the update would happen *during* backward.
    # Here, we simulate the immediate update *after* standard backward.
    with torch.no_grad():
        for param in model.parameters():
            if param.grad is not None:
                # Apply SGD update
                param.data.sub_(param.grad, alpha=learning_rate)
                # Free gradient memory immediately (in theory)
                param.grad = None # Or param.grad.zero_()

    return loss.item()

Crucial Caveat: This Python code is purely illustrative. It produces ordinary SGD updates, but because loss.backward() finishes before any parameter is touched, peak gradient memory is the same as in standard training. Real LOMO applies each update inside the backward pass, so a gradient can be freed as soon as it has been used. Please refer to the official implementation: https://github.com/OpenLMLab/LOMO.

4.5 Limitations and Considerations

LOMO is a powerful tool, but it’s not a free lunch:

  1. Training Stability: Plain SGD can be more sensitive to learning rate choices and potentially less stable than Adam, especially if the fine-tuning task involves significant domain shifts.
  2. Convergence Speed: May require more training steps or careful learning rate scheduling to match the final performance achieved by Adam-based methods.
  3. Implementation Hurdle: Getting the fused operations right requires non-trivial engineering effort, often beyond pure Python.
  4. Best Use Case: Primarily designed and validated for fine-tuning scenarios. Its suitability for training models from scratch is less established.

5. Practical Considerations and Combined Strategies

5.1 Combining Techniques for Maximum Efficiency

Often, the most memory-constrained scenarios demand a multi-pronged attack. Combining these techniques is where the real magic (or desperate engineering) happens:

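  1. Gradient Accumulation + LOMO: A potent combination for fine-tuning massive models. Use LOMO to drastically cut the state memory per micro-batch, and use gradient accumulation to simulate the large effective batch size needed for stable fine-tuning. One caveat: a fused update consumes each gradient as soon as it is produced, so accumulation either needs a persistent gradient buffer (giving back part of LOMO's saving) or turns into k smaller SGD steps rather than one large-batch step. This pairing can make fine-tuning 70B+ models feasible on hardware that would otherwise be hopelessly inadequate.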
  2. Mixed Precision + Gradient Accumulation: A standard and highly effective combo. Mixed precision runs most forward and backward computation in FP16/BF16, which roughly halves activation memory. With automatic mixed precision the parameters, gradients, and optimizer states typically stay in FP32, so the savings on persistent state are smaller than on activations. Adding gradient accumulation on top allows you to push effective batch sizes much higher within the memory saved by mixed precision. PyTorch’s torch.cuda.amp makes this relatively straightforward:
import torch
from torch.cuda.amp import autocast, GradScaler
# Assume loss_function, model, optimizer, dataloader are defined

# Initialize GradScaler for managing mixed precision scaling
scaler = GradScaler()

def train_with_mixed_precision_and_grad_accum(model, optimizer, dataloader,
                                             accum_steps=4, device="cuda", max_norm=1.0):
    model.train()
    model.to(device)
    optimizer.zero_grad()

    running_loss = 0.0
    global_step = 0

    for step, batch in enumerate(dataloader):
        # Move data to device (handle tuple/dict as before)
        if isinstance(batch, (tuple, list)):
            inputs, targets = batch
            inputs = inputs.to(device)
            targets = targets.to(device)
        elif isinstance(batch, dict):
            inputs = {k: v.to(device) for k, v in batch.items() if k != 'labels'}
            targets = batch['labels'].to(device)
        else: # Fallback
            inputs = batch.to(device)
            targets = None

        # Autocast context manager enables mixed precision for forward pass
        with autocast():
            if isinstance(batch, dict):
                 outputs = model(**inputs)
                 loss = loss_function(outputs.logits if hasattr(outputs, 'logits') else outputs, targets)
            elif targets is not None:
                 outputs = model(inputs)
                 loss = loss_function(outputs, targets)
            else: # Model computes loss internally
                 outputs = model(inputs)
                 loss = outputs.loss

            loss = loss / accum_steps # Normalize loss for accumulation

        # scaler.scale() scales the loss for gradient computation
        scaler.scale(loss).backward()

        running_loss += loss.item() # Track normalized loss

        if (step + 1) % accum_steps == 0:
            # scaler.unscale_() unscales gradients before clipping/optimizer step
            scaler.unscale_(optimizer)

            # Clip gradients (optional)
            if max_norm is not None:
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_norm)

            # scaler.step() performs optimizer step, checking for inf/NaNs
            scaler.step(optimizer)

            # scaler.update() updates the scale factor for next iteration
            scaler.update()

            # Zero gradients for the next accumulation cycle
            optimizer.zero_grad()

            global_step += 1
            if global_step % 100 == 0:
                # running_loss is the sum of 100 per-effective-step mean losses
                average_loss = running_loss / 100
                print(f"Effective Step {global_step}: Average Loss = {average_loss:.4f}")
                running_loss = 0.0
  3. Activation Checkpointing + LOMO: Activation checkpointing directly tackles the activation memory bottleneck by recomputing them during the backward pass instead of storing them all. Combining this with LOMO’s state reduction allows tackling truly enormous models, especially deep Transformers where activation memory scales significantly with depth.
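
The checkpointing half of that last combination can be sketched with `torch.utils.checkpoint`. The example below covers only activation checkpointing, not LOMO, and uses a toy stack of linear blocks as a stand-in for Transformer blocks (the class name and sizes are illustrative). It confirms that gradients with and without checkpointing agree, which is the property that makes the trade purely compute for memory.

```python
import torch
from torch.utils.checkpoint import checkpoint

class CheckpointedStack(torch.nn.Module):
    def __init__(self, width=32, depth=4):
        super().__init__()
        self.blocks = torch.nn.ModuleList([
            torch.nn.Sequential(torch.nn.Linear(width, width), torch.nn.GELU())
            for _ in range(depth)
        ])
        self.use_checkpointing = True

    def forward(self, x):
        for block in self.blocks:
            if self.use_checkpointing and self.training:
                # Keep only the block input; recompute the block during backward.
                x = checkpoint(block, x, use_reentrant=False)
            else:
                x = block(x)
        return x

torch.manual_seed(0)
model = CheckpointedStack()
x = torch.randn(8, 32)

model(x).sum().backward()
grads_ckpt = [p.grad.clone() for p in model.parameters()]

model.zero_grad()
model.use_checkpointing = False
model(x).sum().backward()
grads_plain = [p.grad.clone() for p in model.parameters()]

print(all(torch.allclose(a, b, atol=1e-6) for a, b in zip(grads_ckpt, grads_plain)))
```

Checkpointing every block maximizes the memory saving at the cost of roughly one extra forward pass; checkpointing only some blocks is a middle ground.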

5.1.1 Combined Strategies Visualization

Think of applying these techniques as layers of defense against memory overflow:

[Figure: Combined Memory Optimization Strategy]

This staged approach lets you introduce complexity only as needed, starting with the easiest wins (mixed precision) and progressing to more involved strategies.

5.2 Real-World Case Studies

Abstract principles are fine, but seeing them in action clarifies their value:

Case Study 1: Fine-tuning a 7B Model on a Consumer GPU

  • Constraint: Single NVIDIA RTX 3090 (24GB VRAM).
  • Goal: Fine-tune Llama 2 7B for instruction following.
  • Solution:
    • Mixed Precision (BF16) to cut baseline memory roughly in half.
    • LOMO optimizer to eliminate Adam state (~56GB savings).
    • Gradient Accumulation (micro-batch size 1, accumulate 32 steps) to handle sequence length and effective batch size.
  • Outcome: Successful fine-tuning achieved, peak memory kept around ~15-18GB, well within the 24GB limit. Impossible without this combination.

Case Study 2: Training a >1B Parameter Vision Transformer

  • Constraint: Single NVIDIA A100 (40GB VRAM).
  • Goal: Train a 2B parameter ViT on high-resolution images.
  • Challenge: Activations explode with image size and model depth.
  • Solution:
    • Gradient Accumulation (micro-batch 2, accumulate 16 steps) for effective batch size 32.
    • Activation Checkpointing applied strategically to the Transformer blocks (e.g., every second block).
  • Outcome: Training became feasible, fitting a model ~3x larger than possible without checkpointing, albeit with increased training time due to recomputation.

5.3 Trade-offs and Decision Matrix

Choosing the right technique(s) involves balancing memory savings against potential impacts on speed and complexity:

| Technique                | Memory Reduction Potential          | Training Speed Impact  | Implementation Difficulty | Ideal Scenario                                     |
|--------------------------|-------------------------------------|------------------------|---------------------------|----------------------------------------------------|
| Gradient Accumulation    | Medium (Activations/Temps)          | Slightly Slower        | Low                       | Need larger effective batch size                   |
| LOMO                     | High (Gradients + Optimizer State)  | Potentially Slower     | Medium-High               | Fine-tuning very large models, state is bottleneck |
| Mixed Precision          | Medium (mainly activations)         | Often Faster           | Low                       | Almost always beneficial, start here               |
| Activation Checkpointing | High (Activations)                  | Slower (Recomputation) | Low-Medium                | Deep networks where activations dominate memory    |
| All Combined             | Very High                           | Slower than baseline   | High                      | Pushing absolute hardware limits (e.g., 70B+ models) |

6. Emerging Techniques and Future Directions

The arms race against memory limits continues. Gradient accumulation and LOMO are powerful, but they are part of a broader ecosystem of efficiency techniques:

6.1 Parameter-Efficient Fine-Tuning (PEFT)

Methods like LoRA, QLoRA (Quantized LoRA), AdaLoRA, etc., take a different philosophical approach: instead of optimizing the full model state, drastically reduce the state being optimized.

  • Freeze the vast majority of pretrained weights.
  • Train only small, add-on adapter modules or low-rank updates.
  • Often combined with 4-bit quantization (QLoRA) for further parameter memory reduction.
  • These PEFT methods are often used in conjunction with gradient accumulation and sometimes even LOMO-like optimizers for maximal efficiency.
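
A quick count shows why low-rank updates shrink the trainable state so sharply. The sketch assumes the usual LoRA factorization of a weight update into two matrices of shapes (d_out × r) and (r × d_in); the hidden size of 4096 and rank of 8 are illustrative values, not taken from a specific configuration.

```python
def lora_trainable_params(d_in, d_out, rank):
    """Parameters in the two low-rank factors that replace a full weight update."""
    return rank * (d_in + d_out)

d = 4096  # hypothetical hidden size of one square projection matrix
full = d * d
lora = lora_trainable_params(d, d, rank=8)
print(f"full: {full:,}  LoRA r=8: {lora:,}  ratio: {lora / full:.4%}")
```

Gradients and optimizer states are only needed for the trainable factors, so they shrink by the same ratio; the frozen base weights still have to be held in memory, which is where quantization comes in.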

6.2 Model Sharding and Offloading

When a single GPU simply isn’t enough, even with optimizations, distributed strategies become necessary:

  • ZeRO (Zero Redundancy Optimizer): A family of techniques (popularized by DeepSpeed) that shard optimizer states, gradients, and even parameters across multiple GPUs in a data-parallel group, ensuring no single GPU needs to hold everything.
  • CPU/NVMe Offloading: For parameters or optimizer states not immediately needed, ZeRO variants can offload them to CPU RAM or even fast NVMe storage, trading communication latency for massive memory capacity expansion.
  • These are complex system-level optimizations, often complementary to the single-GPU techniques discussed here.
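
The effect of sharding can be approximated with the same FP32 accounting used earlier in this article. This is a sketch of the idea only: ZeRO stage 1 shards optimizer states, stage 2 adds gradients, and stage 3 adds parameters. It ignores activations, communication buffers, and the different byte counts of real mixed-precision setups. The function name is hypothetical.

```python
def zero_per_gpu_gb(num_params, num_gpus, stage, bytes_per_value=4, optimizer_tensors=2):
    """Approximate persistent state per GPU in GB under ZeRO stages 0-3."""
    weights = num_params * bytes_per_value / 1e9
    gradients = weights
    optimizer_state = optimizer_tensors * weights
    if stage >= 1:
        optimizer_state /= num_gpus  # stage 1: shard optimizer states
    if stage >= 2:
        gradients /= num_gpus        # stage 2: also shard gradients
    if stage >= 3:
        weights /= num_gpus          # stage 3: also shard parameters
    return weights + gradients + optimizer_state

for stage in range(4):
    print(f"stage {stage}: {zero_per_gpu_gb(7e9, num_gpus=8, stage=stage):.1f} GB per GPU")
```

Stage 0 corresponds to plain data parallelism, where every GPU holds a full copy of everything.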

6.3 Specialized Hardware Considerations

The optimal strategy can depend on the underlying hardware:

  • Multi-GPU Clusters: Naturally lend themselves to data parallelism combined with gradient accumulation and potentially ZeRO.
  • Heterogeneous Systems (CPU+GPU): Can exploit CPU offloading more effectively.
  • Apple Silicon (M1/M2/M3): The unified memory architecture changes the game, reducing the penalty for “offloading” as CPU and GPU RAM are the same pool. Custom optimization strategies leveraging this are emerging.

7. Conclusion

Let’s be clear: GPU memory is a fundamental battleground in the quest to build and train ever-larger deep learning models. The staggering parameter counts of modern Transformers impose brutal constraints. Understanding precisely where memory evaporates—parameters, gradients, optimizer states, activations—is the first step towards intelligent mitigation.

Techniques like Gradient Accumulation offer a pragmatic way to achieve the statistical benefits of large batches without paying the full memory price upfront, acting as a crucial lever when batch size dynamics matter. For the specific, yet common, challenge of fine-tuning these behemoths, LOMO presents a more radical, opinionated approach – shedding the memory-hungry baggage of adaptive optimizers by betting that simple SGD suffices when you’re already close to a good solution.

In the trenches, wrestling with real-world hardware limits, the most effective solutions rarely rely on a single trick. Combining these strategies—layering mixed precision, gradient accumulation, perhaps LOMO or activation checkpointing—becomes necessary engineering triage. It’s about finding the least painful trade-offs between memory footprint, computational overhead, and implementation complexity to get the job done.

Mastering these memory optimization techniques becomes a strategic necessity. They are the tools that democratize access, allowing researchers and engineers without nation-state compute budgets to engage with models that would otherwise remain theoretical curiosities. As models inevitably continue their climb towards astronomical scale, the ingenuity applied to managing these finite memory resources will be just as critical as the architectural breakthroughs themselves. The ghost in the machine might be getting smarter, but it still needs space to think.
