
November 15, 2023

Mastering LLM Inference: From Mechanics to Optimization

1. Introduction: Why Inference Matters

While model training often captures headlines, inference is where AI actually delivers value. Once a Large Language Model (LLM) is trained, we face a critical challenge: how do we generate text efficiently, at scale, and with the right balance of creativity, coherence, and control?

Inference encompasses several interconnected aspects:

  1. Decoding Strategies: The algorithms that determine how we select each successive token
  2. Parameter Control: How we tune settings like temperature to manage the trade-off between deterministic and creative outputs
  3. Hardware Deployment: Techniques to run massive models across distributed infrastructure
  4. Optimization Techniques: Methods like quantization and pruning that make models faster and more efficient

Understanding these mechanics is essential for anyone deploying LLMs in production. This article provides a comprehensive guide to help you configure, optimize, and scale LLM inference for your specific needs.


2. Decoding Strategies: The Art of Token Selection

2.1 Understanding Token-by-Token Generation

Modern LLMs use autoregressive generation – they produce text one token at a time, with each new token depending on all previous ones. This process can be mathematically represented as:

$$p(w_{t+1} \mid w_1, w_2, \dots, w_t)$$

This formula shows how we calculate the probability of each possible next token (w_{t+1}), given all previously generated tokens in the sequence.

For context, tokens aren’t exactly words – they’re subword units that might represent partial words, full words, punctuation, or special symbols. For example, “understanding” might be split into “under” and “standing” as separate tokens.
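A quick way to see this in practice is to run a pretrained tokenizer over a sample sentence. The sketch below uses the Hugging Face GPT-2 tokenizer purely for illustration; the exact splits depend on the tokenizer's vocabulary, so the output shown in the comments is indicative rather than guaranteed.

# Inspecting how a tokenizer splits text into subword units
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")

text = "Understanding tokenization is fundamental."
tokens = tokenizer.tokenize(text)    # subword strings
token_ids = tokenizer.encode(text)   # integer IDs the model actually consumes

print(tokens)     # e.g. ['Under', 'standing', 'Ġtoken', 'ization', ...]
print(token_ids)  # the corresponding vocabulary indices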

2.2 Greedy Decoding: The Straightforward Approach

Greedy decoding selects the single most probable token at each step:

$$\hat{w}_{t+1} = \arg \max_i \, p(\text{token}_i \mid w_{1:t})$$
# Simple implementation of greedy decoding
import torch

def greedy_decode(model, input_ids, max_length):
    for _ in range(max_length):
        outputs = model(input_ids)
        # Pick the single highest-probability token at the last position
        next_token_id = outputs.logits[:, -1, :].argmax(dim=-1)
        input_ids = torch.cat([input_ids, next_token_id.unsqueeze(-1)], dim=-1)
    return input_ids

Advantages:
– Computationally efficient – no need to sample or track multiple candidates
– Deterministic results for the same prompt
– Simple implementation

Disadvantages:
– Often produces repetitive text
– Lacks creativity and diversity in responses
– Can get stuck in “probability traps” where it repeats patterns

Best for: Factual responses, structured outputs (like code or data formats), or when predictability is more important than creativity.


2.3 Temperature-Based Sampling: Controlling Creativity

2.3.1 The Temperature Parameter

Temperature (T) modifies how the model’s probability distribution is sampled by scaling the logits before applying softmax:

$$p_i = \frac{\exp\left(\frac{z_i}{T}\right)}{\sum_j \exp\left(\frac{z_j}{T}\right)}$$

Where z_i represents the raw logit score for token i.

  • Low temperature (0.1-0.7): Creates more focused, deterministic, and “safe” outputs
  • Medium temperature (0.7-1.0): Balanced between creativity and coherence
  • High temperature (>1.0): Produces more diverse, unpredictable, and sometimes incoherent outputs
# Temperature sampling implementation
def temperature_sample(logits, temperature=0.7):
    if temperature == 0:
        # Handle division by zero - equivalent to greedy
        return torch.argmax(logits, dim=-1).unsqueeze(-1)
        
    # Apply temperature scaling
    logits = logits / temperature
    
    # Convert to probabilities
    probs = torch.softmax(logits, dim=-1)
    
    # Sample from the distribution
    next_token = torch.multinomial(probs, num_samples=1)
    return next_token

2.3.2 Top-k Sampling: Limiting the Candidate Pool

Rather than considering all possible tokens (which might include highly improbable choices), top-k sampling restricts selection to only the k most likely tokens:

$$p(\text{token}_i) = \begin{cases} \frac{p_i}{\sum_{j \in K} p_j}, & i \in K, \\ 0, & \text{otherwise}, \end{cases}$$

Where K is the set of the k most probable tokens.

# Top-k sampling implementation
def top_k_sample(logits, k=50):
    # Find values and indices of the k largest entries
    v, idx = torch.topk(logits, k)
    
    # Zero out all values not in top k
    probs = torch.zeros_like(logits)
    probs.scatter_(1, idx, torch.softmax(v, dim=-1))
    
    # Sample from the truncated distribution
    next_token = torch.multinomial(probs, num_samples=1)
    return next_token

The k parameter controls diversity – higher values allow more varied outputs but might include less relevant tokens.

2.3.3 Top-p (Nucleus) Sampling: Dynamic Token Selection

Unlike top-k’s fixed cutoff, nucleus sampling (or top-p) dynamically selects the smallest set of tokens whose cumulative probability exceeds threshold p:

$$\sum_{i \in K} p_i \ge p$$
# Top-p (nucleus) sampling implementation
def top_p_sample(logits, p=0.9):
    # Sort logits in descending order
    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
    cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
    
    # Remove tokens with cumulative probability above the threshold
    sorted_indices_to_remove = cumulative_probs > p
    # Shift the mask right so the first token that crosses the threshold is kept
    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
    sorted_indices_to_remove[..., 0] = False
    
    # Map the mask back from sorted order to the original vocabulary positions
    indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
    probs = torch.softmax(logits, dim=-1)
    probs = probs.masked_fill(indices_to_remove, 0.0)
    
    # Renormalize and sample
    probs = probs / probs.sum(dim=-1, keepdim=True)
    next_token = torch.multinomial(probs, num_samples=1)
    return next_token

Advantages over top-k:
– Adapts to the confidence level of the model at each step
– More context-appropriate diversity
– Better handling of both high and low entropy distributions

Industry practice: Top-p with p=0.9 and a moderate temperature (0.7-0.8) has become a standard configuration for many applications.
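These settings map directly onto the sampling arguments of most generation APIs. As one illustration (not the only way to do it), here is how the defaults above look with the `generate` method from Hugging Face `transformers`; the GPT-2 model and the prompt are placeholders.

# Nucleus sampling with temperature via transformers' generate()
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("The future of AI is", return_tensors="pt")
output_ids = model.generate(
    **inputs,
    do_sample=True,      # enable sampling rather than greedy/beam search
    temperature=0.7,     # soften the distribution
    top_p=0.9,           # nucleus threshold
    top_k=50,            # optional hard cap on the candidate pool
    max_new_tokens=50,
)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))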

2.4 Beam Search: Exploring Multiple Paths

Rather than generating a single token sequence, beam search maintains B parallel sequences (the “beam width”):

  1. Start with a single sequence (the prompt)
  2. For each existing sequence, compute probabilities for all possible next tokens
  3. From all possible continuations, keep only the B sequences with highest overall probability
  4. Repeat until generation is complete
# Simplified beam search implementation (assumes batch size 1, no length normalization)
def beam_search(model, input_ids, beam_width=4, max_length=50):
    # Initialize with the prompt sequence
    sequences = [(input_ids, 0.0)]  # (sequence, cumulative log-probability)
    
    for _ in range(max_length):
        candidates = []
        
        # Expand each current sequence
        for seq, score in sequences:
            outputs = model(seq)
            logits = outputs.logits[:, -1, :]
            
            # Get top-k probabilities and tokens for each sequence
            probs, next_tokens = torch.topk(torch.softmax(logits, dim=-1), k=beam_width)
            
            # Add all possible continuations to candidates
            for prob, token in zip(probs[0], next_tokens[0]):
                candidates.append((
                    torch.cat([seq, token.unsqueeze(0).unsqueeze(0)], dim=1),
                    score + torch.log(prob).item()
                ))
        
        # Select top sequences for next iteration
        sequences = sorted(candidates, key=lambda x: x[1], reverse=True)[:beam_width]
    
    # Return the highest scored complete sequence
    return sequences[0][0]

Applications:
– Machine translation
– Text summarization
– Any task requiring precision in output


Limitations:
– Memory intensive – scales with beam width
– Can produce overly generic outputs due to the “beam search curse”
– Slower than single-path methods

2.5 Advanced Traversal Techniques

Beyond standard approaches, specialized applications might use:

  • Depth-First Search (DFS): Explores one branch fully before backtracking
  • Breadth-First Search (BFS): Examines all possibilities at each depth before moving deeper
  • Monte Carlo Tree Search (MCTS): Uses random sampling to balance exploration and exploitation

These techniques are particularly valuable for:
– Coding assistants that need to explore solution spaces
– Game-playing AI that must evaluate multiple move sequences
– Planning systems that require structured, multi-step reasoning
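
As a minimal, hedged illustration of tree-style exploration, the sketch below performs a bounded depth-first search over the top-k continuations at each step and keeps the path with the highest total log-probability. The `model` is assumed to be a causal LM returning `.logits`, and the depth, branching factor, and scoring rule are deliberately simplistic (cost grows as k^depth forward passes).

# Hypothetical DFS over token continuations: descend depth-first through the
# top-k next tokens and keep the best-scoring complete path
import torch

def dfs_decode(model, input_ids, depth=5, k=3):
    best = {"score": float("-inf"), "ids": input_ids}

    def recurse(ids, score, remaining):
        if remaining == 0:
            if score > best["score"]:
                best["score"], best["ids"] = score, ids
            return
        with torch.no_grad():
            logits = model(ids).logits[:, -1, :]
        log_probs = torch.log_softmax(logits, dim=-1)
        top_lp, top_idx = torch.topk(log_probs, k)
        for lp, tok in zip(top_lp[0], top_idx[0]):
            new_ids = torch.cat([ids, tok.view(1, 1)], dim=1)
            recurse(new_ids, score + lp.item(), remaining - 1)

    recurse(input_ids, 0.0, depth)
    return best["ids"]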


3. Practical Decoding Techniques and Considerations

3.1 Hybrid Approaches

Best results often come from combining multiple techniques:

# Combined top-k and top-p with temperature
def generate_text(model, prompt, max_length=100, temperature=0.7, top_k=50, top_p=0.9):
    input_ids = tokenizer.encode(prompt, return_tensors="pt")
    
    for _ in range(max_length):
        with torch.no_grad():
            outputs = model(input_ids)
            logits = outputs.logits[:, -1, :] / temperature
            
            # Apply top-k filtering
            top_k_logits, top_k_indices = torch.topk(logits, top_k)
            
            # Apply top-p filtering
            probs = torch.softmax(top_k_logits, dim=-1)
            cumulative_probs = torch.cumsum(probs, dim=-1)
            
            # Remove tokens exceeding the probability mass threshold
            mask = cumulative_probs < top_p
            mask[..., 1:] = mask[..., :-1].clone()
            mask[..., 0] = True
            
            # Filter and normalize
            top_k_logits = top_k_logits.masked_fill(~mask, -float('inf'))
            filtered_probs = torch.softmax(top_k_logits, dim=-1)
            
            # Sample
            next_token_index = torch.multinomial(filtered_probs, 1)
            next_token = top_k_indices[0, next_token_index[0]]
            
            input_ids = torch.cat([input_ids, next_token.view(1, 1)], dim=1)
            
            # Check for stopping criteria
            if next_token.item() == tokenizer.eos_token_id:
                break
                
    return tokenizer.decode(input_ids[0])

3.2 Preventing Repetition and Improving Quality

Several techniques help avoid common generation problems:

  • Repetition Penalties: Dampen the logits of tokens that have already appeared in the context
  • Frequency Penalties: Penalize tokens in proportion to how often they appear in the generated text (sketched after the repetition-penalty example below)
  • Presence Penalties: Apply a flat penalty to any token that has appeared at least once, regardless of how often
# Implementing repetition penalty (CTRL-style logit dampening)
def apply_repetition_penalty(logits, input_ids, penalty=1.2):
    # Get the unique token IDs that have already appeared
    token_ids = torch.unique(input_ids)
    
    # Dampen the logits of seen tokens: divide positive logits, multiply negative ones
    # (simply dividing a negative logit would *increase* its probability)
    for token_id in token_ids:
        score = logits[0, token_id]
        logits[0, token_id] = score / penalty if score > 0 else score * penalty
        
    return logits
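
Frequency and presence penalties can be sketched in a similar additive form. The version below follows the commonly documented formula logit -= count * frequency_penalty + presence_penalty * (count > 0), applied over the tokens generated so far; the parameter values are illustrative.

# Hedged sketch of additive frequency and presence penalties
import torch

def apply_frequency_presence_penalties(logits, generated_ids,
                                       frequency_penalty=0.5,
                                       presence_penalty=0.5):
    # Count how often each token has appeared in the generated sequence
    token_ids, counts = torch.unique(generated_ids, return_counts=True)
    
    for token_id, count in zip(token_ids, counts):
        logits[0, token_id] -= count.item() * frequency_penalty  # scales with repetitions
        logits[0, token_id] -= presence_penalty                  # flat penalty for any appearance
        
    return logits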

3.3 Controlled Generation with Stop Criteria

Define specific conditions to end generation:

  • Token-based stops: End on specific tokens like <|endoftext|> or \n\n
  • Pattern matching: Stop when certain text patterns appear
  • Maximum length: Set hard limits on output length
  • Semantic stops: More advanced systems might stop when the output satisfies certain criteria
# Example stop conditions
stop_strings = ["\n\n", "###", "END"]

# Check for stop conditions in generated text
def check_stop_conditions(text):
    for stop_string in stop_strings:
        if stop_string in text:
            # Truncate at the stop string
            return text[:text.find(stop_string)]
    return None  # No stop condition met

4. Scaling for Large Models: Sharding and Distribution

4.1 Understanding Sharding Necessity

Many state-of-the-art LLMs (from several billion to hundreds of billions of parameters) exceed the memory capacity of a single GPU, at least at FP16 precision. For instance:

| Model Size | FP16 Memory Footprint | GPUs Required (48GB each) |
|------------|-----------------------|---------------------------|
| 7B         | ~14GB                 | 1                         |
| 13B        | ~26GB                 | 1                         |
| 70B        | ~140GB                | 3+                        |
| 175B       | ~350GB                | 8+                        |

This necessitates distributing model parameters across multiple devices.
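
The arithmetic behind the table is simple: each parameter costs two bytes in FP16, and serving adds working memory for activations and the KV cache. A rough back-of-the-envelope helper (the overhead factor is an assumption, not a measured value):

# Rough estimate of the GPUs needed to hold a model at a given precision
import math

def estimate_gpus(num_params_billion, bytes_per_param=2, gpu_memory_gb=48,
                  overhead_factor=1.2):
    # bytes_per_param: 2 for FP16/BF16, 1 for INT8, 0.5 for 4-bit
    weight_gb = num_params_billion * bytes_per_param      # 1B params ≈ 1 GB per byte/param
    total_gb = weight_gb * overhead_factor                # crude allowance for activations / KV cache
    return weight_gb, max(1, math.ceil(total_gb / gpu_memory_gb))

print(estimate_gpus(70))   # (140, 4): ~140 GB of FP16 weights, 4 x 48 GB GPUs with headroom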

4.2 Key Sharding Strategies

4.2.1 Tensor Parallelism

Tensor Parallelism splits individual weight matrices across multiple GPUs:

[Diagram: a single weight matrix split across multiple GPUs]

This approach requires communication between GPUs during forward and backward passes, but enables efficient parallelism.

# Example: tensor-parallel inference with DeepSpeed-Inference
import deepspeed
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("llama-13b")

# Shard the model's weight matrices across 2 GPUs
ds_engine = deepspeed.init_inference(
    model,
    mp_size=2,                        # tensor-parallel degree
    dtype=torch.float16,
    replace_with_kernel_inject=True   # use DeepSpeed's optimized transformer kernels
)
model = ds_engine.module

4.2.2 Pipeline Parallelism

Pipeline Parallelism splits the model by layers, assigning consecutive blocks of layers to different GPUs.

This approach requires less communication but can lead to GPU idle time while waiting for activations from previous layers.

# Layer-wise model parallelism with Hugging Face Accelerate (via transformers):
# device_map="auto" places consecutive layers on different devices, a simple
# non-pipelined form of splitting the model by layers
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "llama-70b",
    device_map="auto",   # requires the accelerate package
    torch_dtype="auto"
)

4.2.3 Sequence Parallelism

Sequence Parallelism splits along the sequence length dimension, processing different parts of the sequence on different GPUs.


This is especially useful for processing very long contexts.

4.3 Frameworks for Distributed Inference

| Framework | Key Features | Best For |
|-----------|--------------|----------|
| DeepSpeed | Zero Redundancy Optimizer, tensor/pipeline parallelism | Production deployment, training |
| Megatron-LM | Advanced parallelism strategies | Maximum efficiency for very large models |
| Hugging Face Accelerate | Easy to use, good for moderate parallelism | Research, smaller deployments |
| vLLM | PagedAttention for efficient inference, continuous batching | High-throughput inference |
# Simple vLLM example for efficient inference
from vllm import LLM, SamplingParams

# Initialize the model with tensor parallelism
llm = LLM(
    model="meta-llama/Llama-2-70b-chat-hf",
    tensor_parallel_size=4,  # Use 4 GPUs
    gpu_memory_utilization=0.9
)

# Define sampling parameters
sampling_params = SamplingParams(
    temperature=0.7,
    top_p=0.9,
    max_tokens=256
)

# Generate text with efficient batching
outputs = llm.generate(
    ["Explain quantum computing", "Write a short poem about AI"],
    sampling_params
)

5. Optimizing Inference with ONNX and Specialized Runtimes

5.1 ONNX Format and Exports

The Open Neural Network Exchange (ONNX) provides a standardized format for neural networks that enables deployment across different hardware and software platforms.

# Exporting a Transformer model to ONNX
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "gpt2"
model = AutoModelForCausalLM.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Create dummy inputs (adjust sequence length as needed)
dummy_input = tokenizer("Hello, how are you?", return_tensors="pt").input_ids

# Export to ONNX
torch.onnx.export(
    model,
    dummy_input,
    f"{model_id}.onnx",
    input_names=["input_ids"],
    output_names=["logits"],
    dynamic_axes={
        "input_ids": {0: "batch_size", 1: "sequence_length"},
        "logits": {0: "batch_size", 1: "sequence_length"}
    },
    opset_version=15
)

5.2 Optimized Inference Runtimes

Several specialized runtimes can significantly improve inference speed:

| Runtime | Key Benefits | Typical Speedup |
|---------|--------------|-----------------|
| ONNX Runtime | Graph optimizations, cross-platform | 1.2-2x |
| TensorRT | NVIDIA GPU optimization, INT8/FP16 acceleration | 2-5x |
| OpenVINO | Intel CPU/GPU optimization | 1.5-3x on Intel hardware |
| CoreML | Apple Silicon optimization | 2-4x on Apple devices |
# Using ONNX Runtime for inference
import onnxruntime as ort
import numpy as np

# Initialize ONNX Runtime session with optimization level
session_options = ort.SessionOptions()
session_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL

# Create inference session
session = ort.InferenceSession(
    f"{model_id}.onnx",
    session_options,
    providers=["CUDAExecutionProvider", "CPUExecutionProvider"]
)

# Run inference
input_ids = tokenizer("Continue this: Once upon a time", return_tensors="pt").input_ids
ort_inputs = {session.get_inputs()[0].name: input_ids.numpy()}
ort_outputs = session.run(None, ort_inputs)

5.3 Benefits and Limitations

Key Benefits:
– Performance: Significantly faster inference with specialized kernels
– Cross-platform: Deploy to various devices without framework dependencies
– Memory optimization: Reduced memory footprint with optimized graph execution

Limitations:
– Flexibility: Some custom operations may not be supported
– Development overhead: Requires additional testing and validation
– Potential precision loss: Some optimizations trade accuracy for speed


6. Advanced Optimization Techniques

6.1 Weight Pruning and Sparsity

Pruning removes less important weights from the model, making it sparser:

# Simple magnitude-based pruning
def prune_model(model, pruning_threshold=0.01):
    for name, module in model.named_modules():
        if isinstance(module, torch.nn.Linear):
            # Create a mask based on weight magnitude
            mask = torch.abs(module.weight.data) > pruning_threshold
            # Apply the mask
            module.weight.data *= mask
    return model

The benefits of pruning depend on hardware support for sparse operations:

  • Modern NVIDIA GPUs (A100, H100) can accelerate sparse matrix operations
  • Specialized hardware like Cerebras and Graphcore have native sparsity support
  • Without hardware acceleration, pruning may reduce model size but not improve speed
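
Because the speed benefit is hardware-dependent, it is worth verifying how sparse the model actually became before expecting any gains. A small helper for that check (assumes the magnitude-pruning function above, or something like it, has already been applied):

# Measure the fraction of exactly-zero weights across a model's linear layers
import torch

def measure_sparsity(model):
    zeros, total = 0, 0
    for module in model.modules():
        if isinstance(module, torch.nn.Linear):
            zeros += (module.weight.data == 0).sum().item()
            total += module.weight.data.numel()
    return zeros / total if total > 0 else 0.0

# e.g. print(f"Sparsity after pruning: {measure_sparsity(pruned_model):.1%}")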

6.2 Quantization Techniques

6.2.1 Post-Training Quantization

GPTQ (an accurate post-training quantization method for generative pre-trained transformers) and similar techniques convert trained models to lower precision after training:

# Example using GPTQ with AutoGPTQ
from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig
from transformers import AutoTokenizer

model_id = "meta-llama/Llama-2-7b"
quantized_model_dir = "llama-2-7b-4bit"

# Define quantization configuration
quantize_config = BaseQuantizeConfig(
    bits=4,          # Quantize weights to 4-bit
    group_size=128,  # Share quantization parameters within groups of 128 weights
    desc_act=False   # Disable activation-order ("act-order") quantization for faster inference
)

# Load the full-precision model with the quantization configuration attached
model = AutoGPTQForCausalLM.from_pretrained(model_id, quantize_config)

# GPTQ calibrates against example data; a real calibration set should be larger and representative
tokenizer = AutoTokenizer.from_pretrained(model_id)
examples = [tokenizer("GPTQ calibrates quantization on representative text.")]
model.quantize(examples)

# Save quantized model
model.save_quantized(quantized_model_dir)

6.2.2 Quantization Formats and Precision Levels

| Format | Size Reduction | Quality Impact | Hardware Support |
|--------|----------------|----------------|------------------|
| FP16 | 2x (vs FP32) | Minimal | Excellent (all modern GPUs) |
| BF16 | 2x (vs FP32) | Low | Good (A100, H100, TPUs) |
| INT8 | 4x (vs FP32) | Low-Moderate | Very good (most GPUs) |
| INT4 | 8x (vs FP32) | Moderate-High | Limited (specialized hardware) |
| NF4 | 8x (vs FP32) | Moderate | Emerging support |
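
For example, the NF4 format from the table can be requested when loading a model through the bitsandbytes integration in transformers (requires the bitsandbytes package; option names reflect recent library versions and may differ in older ones):

# Loading a model with 4-bit NF4 quantization via bitsandbytes
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",              # NormalFloat4 data type
    bnb_4bit_compute_dtype=torch.bfloat16,  # run matmuls in higher precision
    bnb_4bit_use_double_quant=True          # also quantize the quantization constants
)

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    quantization_config=bnb_config,
    device_map="auto"
)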

6.2.3 Mixed Precision Strategies

Not all parts of a model need the same precision:

# Conceptual example of mixed precision
def apply_mixed_precision(model):
    # Key attention matrices stay in higher precision
    for layer in model.layers:
        # 8-bit for most weights
        layer.mlp = quantize_to_int8(layer.mlp)
        
        # Keep attention in 16-bit for better quality
        layer.attention.query = quantize_to_fp16(layer.attention.query)
        layer.attention.key = quantize_to_fp16(layer.attention.key)
        layer.attention.value = quantize_to_fp16(layer.attention.value)
    
    return model

6.3 KV Cache Optimization

The key-value (KV) cache stores computed attention values to avoid recomputation:

# Pseudocode for KV cache implementation
def generate_with_kv_cache(model, input_ids, max_length):
    # Initial forward pass over the full prompt fills the cache
    outputs = model(input_ids, use_cache=True, past_key_values=None)
    logits = outputs.logits
    kv_cache = outputs.past_key_values
    
    # Generate tokens using cache
    for _ in range(max_length):
        # Get next token
        next_token = sample_token(logits[:, -1, :])
        input_ids = torch.cat([input_ids, next_token.unsqueeze(1)], dim=1)
        
        # Use KV cache for next token prediction
        outputs = model(next_token.unsqueeze(1), use_cache=True, past_key_values=kv_cache)
        logits = outputs.logits
        kv_cache = outputs.past_key_values
    
    return input_ids

Memory considerations:
– KV cache size grows linearly with sequence length (and with batch size)
– For a 70B-parameter model with full multi-head attention and a 32K context, the KV cache can exceed 80GB in FP16 for a single sequence (see the estimator below)
– Techniques like sliding-window attention, grouped-query/multi-query attention, and cache quantization or pruning reduce this overhead
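
The figures above come from straightforward multiplication: two tensors (keys and values) per layer, each holding heads × head_dim values per token. A quick estimator, with illustrative rather than model-specific parameter values:

# Estimate KV cache size in GB for a given configuration
def kv_cache_gb(num_layers, num_kv_heads, head_dim, seq_len,
                batch_size=1, bytes_per_value=2):
    # The factor of 2 accounts for storing both the key and the value tensor
    total_bytes = 2 * num_layers * num_kv_heads * head_dim * seq_len * batch_size * bytes_per_value
    return total_bytes / 1e9

# Full multi-head attention, 70B-class model, 32K context (illustrative numbers)
print(kv_cache_gb(num_layers=80, num_kv_heads=64, head_dim=128, seq_len=32768))  # ~85.9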


7. Complete Inference Pipeline: Putting It All Together

A production-ready LLM inference system typically combines multiple techniques:

# Comprehensive inference pipeline pseudocode
def optimized_inference_pipeline(prompt, model_id, max_length=100):
    # 1. Load and optimize model
    model = load_quantized_model(model_id, bits=4)  # Load 4-bit quantized
    model = shard_model(model, num_gpus=4)  # Distribute across GPUs
    
    # 2. Tokenize input
    input_ids = tokenizer.encode(prompt, return_tensors="pt")
    
    # 3. Configure generation parameters
    gen_config = {
        "temperature": 0.7,
        "top_p": 0.9,
        "top_k": 50,
        "repetition_penalty": 1.2,
        "max_length": max_length,
        "stop_tokens": [tokenizer.eos_token_id]
    }
    
    # 4. Generate with KV cache and continuous batching
    output_ids = generate_with_kv_cache_and_prefill(
        model, 
        input_ids, 
        generation_config=gen_config
    )
    
    # 5. Post-process output
    output_text = tokenizer.decode(output_ids[0])
    output_text = apply_stop_criteria(output_text)
    
    return output_text

7.1 Inference Server Architecture

Modern deployments place a dedicated serving layer between clients and the model workers, typically handling request queuing, batching, and token streaming.

7.2 Advanced Techniques in Production

  • Continuous batching: Process requests as they arrive without waiting for batch completion
  • Request routing: Direct requests to appropriate models based on complexity
  • Adaptive scaling: Dynamically adjust resources based on load
  • Prompt caching: Cache results for common prompts
  • Progressive generation: Stream tokens to clients as they’re generated
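
Progressive streaming, for instance, is available out of the box in transformers via TextIteratorStreamer; the sketch below uses GPT-2 as a stand-in model and runs generation in a background thread so tokens can be forwarded to the client as they are produced.

# Streaming generated tokens as they are produced
from threading import Thread
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("Streaming tokens lets clients", return_tensors="pt")
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

# Generate in a background thread; the streamer yields decoded text as tokens arrive
thread = Thread(target=model.generate,
                kwargs=dict(**inputs, streamer=streamer, do_sample=True, max_new_tokens=50))
thread.start()

for text_chunk in streamer:
    print(text_chunk, end="", flush=True)
thread.join()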

8. Emerging Research and Future Directions

8.1 Specialized Hardware Accelerators

New hardware is being developed specifically for LLM inference:

  • NVIDIA Tensor Cores: Specialized for matrix operations in transformers
  • Google TPUs: Custom ASICs optimized for matrix multiplication
  • Cerebras CS-2: Wafer-scale engine with massive parallelism
  • Groq LPU: Language Processing Unit designed for deterministic, high-throughput inference

8.2 Algorithmic Innovations

  • Speculative decoding: Generate multiple tokens in parallel with a smaller “draft” model whose guesses are verified by the full model (sketched after this list)
  • Sparse attention mechanisms: Reduce computation by focusing only on relevant tokens
  • Multi-query attention: Reuse keys and values across attention heads
  • Mixture-of-experts: Activate only relevant parts of the model for each input
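
Speculative decoding is already exposed in some libraries. In recent versions of transformers, for example, passing a smaller draft model as assistant_model to generate enables assisted generation; treat the snippet below as a hedged sketch, since availability and speedups depend on the library version and the model pair (both models must share a tokenizer).

# Speculative (assisted) decoding with a small draft model
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2-xl")
target_model = AutoModelForCausalLM.from_pretrained("gpt2-xl")  # large target model
draft_model = AutoModelForCausalLM.from_pretrained("gpt2")      # small draft model, same tokenizer

inputs = tokenizer("Speculative decoding works by", return_tensors="pt")

# The draft model proposes several tokens; the target model verifies them in one forward pass
output_ids = target_model.generate(
    **inputs,
    assistant_model=draft_model,
    max_new_tokens=50
)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))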

8.3 Hybrid CPU/GPU Approaches

Not all inference needs to happen on GPUs:

# Example of CPU offloading with Hugging Face
from transformers import AutoModelForCausalLM
import torch

model = AutoModelForCausalLM.from_pretrained(
    "llama-13b",
    device_map="auto",          # Automatically distribute layers across GPUs, CPU, and disk
    offload_folder="offload",   # Disk folder for weights that fit on neither GPU nor CPU
    offload_state_dict=True     # Temporarily offload the state dict to disk while loading
)

# Model will use both GPU and CPU memory

9. Conclusions and Best Practices

Efficient LLM inference requires a holistic approach combining:

  1. Thoughtful decoding strategies: Balance between quality and speed
  2. Hardware-appropriate optimizations: Match techniques to your deployment environment
  3. Continuous monitoring and adaptation: Profile performance and adjust as needed

Remember that the best approach depends on your specific needs:

| Priority | Recommended Techniques |
|----------|------------------------|
| Speed/Cost | Heavy quantization, caching, minimal sampling |
| Quality | Higher precision, carefully tuned sampling |
| Flexibility | Modular architecture with parameter controls |

By understanding the mechanics and trade-offs of LLM inference, you can build systems that deliver powerful AI capabilities while meeting your performance and resource constraints.


