Unboxing LLMs

October 15, 2023

Factually-Augmented RLHF: How LLaVA Advances Multimodal Alignment


Introduction to Vision-Language Models


In the evolving landscape of artificial intelligence, multimodal models represent a significant advancement beyond text-only systems. These models integrate multiple forms of data—specifically text and images in the case of vision-language models (VLMs)—enabling them to understand and generate content that bridges visual perception and linguistic comprehension.

While text-based Large Language Models (LLMs) like GPT-4 have captured public attention, the frontier of AI alignment research has expanded to these multimodal systems. One particularly noteworthy example is LLaVA (Large Language and Vision Assistant), an influential open-source vision-language model, together with its follow-up LLaVA-RLHF, which introduced factually-augmented alignment based on human feedback.

This article explores how factually-augmented RLHF (Reinforcement Learning from Human Feedback) is applied to LLaVA, why this approach is crucial for reducing hallucinations in multimodal systems, and how you can adapt similar techniques for your own projects.

The Evolution of Vision-Language Models

Before diving into LLaVA specifically, it’s important to understand the evolution of vision-language integration:

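  1. Early Approaches (2015-2021): Systems like Show and Tell, which fed CNN image features into an LSTM caption generator, and CLIP, which trained separate image and text encoders contrastively into a shared embedding space, linked vision and language with limited integration.
  2. Unified Architectures (2019-2022): Models like VisualBERT, which processes image regions and text in a single transformer, and DALL-E, which models text and image tokens as one sequence, processed visual and textual information within shared representation spaces.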
  3. Aligned Multimodal Models (2022-Present): The current generation of models including LLaVA, GPT-4V, and Gemini implement sophisticated alignment techniques to ensure their outputs match human preferences and factual accuracy.

This evolution parallels the development we’ve seen in text-only models, but with the added complexity of ensuring alignment across modalities.

The Hallucination Problem in Multimodal Systems

Hallucinations—where an AI confidently produces incorrect information—are particularly challenging in multimodal contexts for several reasons:

  1. Cross-modal misalignment: When visual features don’t properly connect to their textual descriptions, models may “see” one thing but describe another.
  2. Visual ambiguity: Images can contain unclear or partial visual information that models must interpret correctly without overconfident extrapolation.
  3. Knowledge boundaries: Models need to distinguish between what is visually evident in an image versus what requires external knowledge.

Consider this example: When shown an image of a parrot, a hallucinating VLM might describe it as “a colorful parrot playing with a small red ball” even when no ball is present in the image. This type of confabulation demonstrates how models can mix visual elements with invented details.
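One way to make this concrete is to compare the objects a description mentions against the objects actually annotated in the image. The sketch below is a simplified, word-matching version of object-hallucination metrics such as CHAIR; the function names, vocabulary, and annotations are illustrative assumptions, not part of LLaVA.

```python
import re


def object_hallucinations(description, image_objects, object_vocabulary):
    """Return (mentioned, hallucinated) object sets for one description.

    Naive lowercase word matching; metrics such as CHAIR also map synonyms
    and plurals to canonical object names.
    """
    words = set(re.findall(r"[a-z]+", description.lower()))
    mentioned = words & object_vocabulary
    hallucinated = mentioned - image_objects
    return mentioned, hallucinated


def hallucination_rate(description, image_objects, object_vocabulary):
    """Fraction of mentioned objects that are absent from the image (CHAIR_i-style)."""
    mentioned, hallucinated = object_hallucinations(description, image_objects, object_vocabulary)
    return len(hallucinated) / len(mentioned) if mentioned else 0.0


# Made-up annotations: the image contains a parrot on a branch, and no ball
vocabulary = {"parrot", "ball", "branch", "cage"}
caption = "A colorful parrot playing with a small red ball"
print(object_hallucinations(caption, {"parrot", "branch"}, vocabulary))
print(hallucination_rate(caption, {"parrot", "branch"}, vocabulary))  # 0.5
```

The invented “ball” is flagged, and half of the objects the caption mentions are not in the image.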

Traditional training methods often fail to properly penalize these hallucinations, especially if the training objective primarily focuses on general coherence rather than factual accuracy.

What is LLaVA?

LLaVA, initially developed by researchers from UW-Madison, Microsoft, and Columbia University in 2023, represents a significant advancement in open-source multimodal systems. Its name stands for Large Language and Vision Assistant, reflecting its purpose as a multimodal AI assistant.

Architecture and Development

LLaVA combines:

  1. Vision Encoder: A pre-trained vision transformer (typically CLIP ViT-L/14) that processes image inputs.
  2. Projection Layer: A learnable connection that maps visual features into the language model’s embedding space (a single linear layer in the original LLaVA, a two-layer MLP in LLaVA-1.5).
  3. Language Model: A large language model (initially Vicuna, with later versions using LLaMA 2) that generates text responses informed by both visual and textual inputs.
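The data flow through these three components can be sketched in a few lines. This toy version, in plain Python with made-up 3- and 2-dimensional features, shows the single linear projection used in the original LLaVA: each vision patch feature becomes one “visual token” in the LLM’s embedding space, here simply placed before the text embeddings. Real sizes are far larger (1024-dimensional CLIP ViT-L/14 features, 4096-dimensional embeddings for a 7B LLaMA-family model), and the helper names are illustrative.

```python
def num_visual_tokens(image_size, patch_size=14):
    """Patch tokens a ViT produces for a square image (the CLS token is not counted)."""
    return (image_size // patch_size) ** 2


def project(features, weight, bias):
    """Linear projection: each d_v-dim feature row -> d_lm-dim row (weight is d_v x d_lm)."""
    columns = list(zip(*weight))  # one column of d_v weights per output dimension
    return [
        [sum(f * w for f, w in zip(row, col)) + b for col, b in zip(columns, bias)]
        for row in features
    ]


def build_input_sequence(patch_features, text_embeddings, weight, bias):
    """Visual tokens (projected patch features) followed by the text token embeddings."""
    return project(patch_features, weight, bias) + text_embeddings


# Toy example: 4 patches with 3-dim vision features, LLM embedding size 2, 3 text tokens
patch_features = [[1.0, 0.0, 2.0], [0.0, 1.0, 0.0], [1.0, 1.0, 1.0], [0.0, 0.0, 0.0]]
weight = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]]  # 3 x 2 projection matrix
bias = [0.0, 0.1]
text_embeddings = [[0.2, 0.3], [0.4, 0.5], [0.6, 0.7]]

sequence = build_input_sequence(patch_features, text_embeddings, weight, bias)
print(len(sequence))  # 7 positions: 4 visual tokens + 3 text tokens
print(sequence[0])    # [2.0, 1.1]
```

With 14-pixel patches, a 224×224 input yields 256 visual tokens and a 336×336 input (as used in LLaVA-1.5) yields 576, which is why image tokens dominate the context length of short prompts.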

What distinguishes LLaVA-RLHF, the RLHF-aligned version of LLaVA, from earlier multimodal systems is not its architecture, which it shares with LLaVA, but its alignment methodology: it uses factually-augmented RLHF to keep responses both helpful and grounded in what is actually present in images.

Factually-Augmented RLHF: The LLaVA Approach

Standard RLHF, as implemented in many text-only models, optimizes for human preferences, which may prioritize style, helpfulness, or conciseness over factual accuracy. Factually-augmented RLHF modifies this approach by explicitly incorporating factual correctness into the feedback loop.

The LLaVA Training Pipeline

LLaVA’s training happens in three distinct phases:

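  1. Pre-training: The model first learns to connect vision and language on image-caption pairs converted into simple instructions (about 595K pairs filtered from CC3M in the original LLaVA). The vision encoder and language model stay frozen, so only the projection layer is trained, aligning image features with the LLM’s embedding space.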
  2. Supervised Fine-Tuning (SFT): Using carefully curated multimodal instructions and responses, the model learns to follow instructions about images (e.g., “Describe what’s in this image” or “How many people are wearing hats?”).
  3. Factually-Augmented RLHF: The critical alignment phase where:
    • Multiple responses are generated for each image-query pair
    • Human annotators evaluate responses for both helpfulness AND factual accuracy
    • A reward model learns to predict human preferences
    • The policy model is optimized using RL to maximize this reward

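The key innovation is the explicit focus on factual correctness in both human feedback collection and reward modeling. In LLaVA-RLHF, crowdworkers comparing two responses were asked to identify the more hallucinated one, and the reward model was augmented with factual information, such as image captions and ground-truth multiple-choice options, which makes ungrounded claims easier to detect and reduces reward hacking. Annotators are instructed to penalize responses that: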

  • Mention objects not present in the image
  • Incorrectly identify visible objects
  • Make false claims about spatial relationships
  • Hallucinate text that isn’t present in the image

This creates a more reliable alignment between what the model “sees” and what it “says.”
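In the LLaVA-RLHF paper, factually-augmented RLHF gives the reward model extra factual context, such as ground-truth image captions or the correct options for multiple-choice questions, so it can check a response against known facts rather than judging plausibility alone. The sketch below assembles such an input as plain text; the template, function name, and argument names are hypothetical and are not the paper’s actual prompt format.

```python
def build_reward_input(query, response, captions=None, correct_choice=None):
    """Assemble the text a reward model scores, optionally prefixed with factual context.

    Hypothetical template for illustration; not the prompt format used in LLaVA-RLHF.
    """
    parts = []
    if captions:
        # Ground-truth descriptions let the reward model check claims against the image
        parts.append("Image captions:\n" + "\n".join(f"- {c}" for c in captions))
    if correct_choice is not None:
        # For multiple-choice questions, the known correct option
        parts.append(f"Correct answer: {correct_choice}")
    parts.append(f"USER: {query}")
    parts.append(f"ASSISTANT: {response}")
    return "\n\n".join(parts)


print(build_reward_input(
    "What is the parrot doing?",
    "The parrot is playing with a small red ball.",
    captions=["A green parrot perched on a wooden branch."],
))
```

Because the caption never mentions a ball, a reward model that sees this context has a direct signal that the response is ungrounded, which a model judging fluency alone would lack.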

Inside the RLHF Fine-Tuning Pipeline

Let’s examine each stage of the factually-augmented RLHF pipeline in detail:

1. Base Vision-Language Model Preparation

The foundation of LLaVA is a pretrained vision encoder connected to a large language model:

Vision Encoder (CLIP ViT-L/14) → Projection Layer → LLM (Vicuna/LLaMA)

This base model learns to understand images and generate text about them through initial pretraining on large-scale image-text pairs and instructional data.

2. Supervised Fine-Tuning (SFT)

Before RLHF begins, the model undergoes supervised fine-tuning on high-quality demonstrations of desired behavior. For LLaVA, this includes:

  • Visual instruction datasets: Carefully curated image-instruction-response triplets spanning diverse tasks like description, reasoning, and question answering
  • Multimodal chat datasets: Conversational data where the model responds to questions about images in a helpful manner

This stage creates a strong starting point for the subsequent RLHF process, as the model already produces reasonable responses that can be further refined.
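One detail of this stage is worth making concrete: LLaVA computes the autoregressive loss only on the assistant’s answer tokens, not on the instruction tokens. The plain-Python sketch below shows that masking with made-up log-probabilities; the -100 ignore value follows the PyTorch cross-entropy default, and the helper names are illustrative.

```python
IGNORE_INDEX = -100  # label value PyTorch's cross-entropy ignores by default


def build_labels(token_ids, is_response):
    """Keep response token ids as targets; mask instruction positions."""
    return [tok if resp else IGNORE_INDEX for tok, resp in zip(token_ids, is_response)]


def masked_nll(token_log_probs, labels):
    """Mean negative log-likelihood over the unmasked (response) tokens only."""
    kept = [-lp for lp, label in zip(token_log_probs, labels) if label != IGNORE_INDEX]
    return sum(kept) / len(kept)


# Made-up example: 2 instruction tokens followed by 2 response tokens
token_ids = [5, 6, 7, 8]
is_response = [False, False, True, True]
log_probs = [-3.0, -2.0, -0.5, -1.5]  # model log-probabilities of each target token

labels = build_labels(token_ids, is_response)
print(labels)                         # [-100, -100, 7, 8]
print(masked_nll(log_probs, labels))  # 1.0
```

The poorly predicted instruction tokens do not affect the loss, so the model is trained to produce good answers rather than to reproduce the questions.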

3. Human Feedback Collection

The human feedback stage is where factual augmentation becomes critical:

  1. Generation of candidates: The SFT model produces multiple responses for each image-query pair
  2. Factual annotation guidelines: Annotators receive specific instructions to assess:
    • Presence of mentioned objects/elements in the image
    • Accuracy of described spatial relationships
    • Correctness of color, size, and other attribute descriptions
    • Avoidance of “guessing” information not visible in the image
  3. Preference data collection: Annotators rank responses or provide pairwise preferences between candidates

This structured feedback process creates a dataset that explicitly captures human preferences for factually grounded responses.
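When annotators rank more than two candidates, a common practice (used, for example, in InstructGPT) is to expand each ranking into every pairwise comparison it implies, which is the format a pairwise reward-model loss consumes. A minimal sketch, with hypothetical example responses:

```python
from itertools import combinations


def ranking_to_pairs(ranked_responses):
    """Expand a best-to-worst ranking into (preferred, less_preferred) pairs.

    A ranking of K responses yields K * (K - 1) / 2 comparisons.
    """
    # combinations() keeps input order, so each pair is (higher-ranked, lower-ranked)
    return list(combinations(ranked_responses, 2))


# Hypothetical annotator ranking for one image-query pair, best first
ranked = [
    "A green parrot perched on a branch.",
    "A parrot sitting outdoors.",
    "A colorful parrot playing with a small red ball.",  # mentions an object not in the image
]
for preferred, less_preferred in ranking_to_pairs(ranked):
    print(f"PREFER: {preferred!r}  OVER: {less_preferred!r}")
```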

4. Reward Model Training

A reward model is trained to predict human preferences:

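import torch
import torch.nn as nn

class RewardModel(nn.Module):
    def __init__(self, vision_encoder, language_model, projection_layer):
        super().__init__()
        self.vision_encoder = vision_encoder        # assumed: a Hugging Face CLIPVisionModel
        self.language_model = language_model        # assumed: a Hugging Face causal LM
        self.projection_layer = projection_layer    # maps vision hidden size -> LM hidden size
        self.reward_head = nn.Linear(language_model.config.hidden_size, 1)

    def forward(self, images, input_ids, attention_mask):
        # Encode the image into patch features and project them into the LM embedding space
        image_features = self.vision_encoder(pixel_values=images).last_hidden_state
        image_embeds = self.projection_layer(image_features)  # (batch, num_patches, hidden)

        # Condition on the image by prepending its projected features to the
        # embedded query + response tokens
        text_embeds = self.language_model.get_input_embeddings()(input_ids)
        inputs_embeds = torch.cat([image_embeds.to(text_embeds.dtype), text_embeds], dim=1)
        image_mask = attention_mask.new_ones(image_embeds.shape[:2])
        full_mask = torch.cat([image_mask, attention_mask], dim=1)

        lm_outputs = self.language_model(
            inputs_embeds=inputs_embeds,
            attention_mask=full_mask,
            output_hidden_states=True,  # causal LM outputs expose hidden states only on request
        )

        # Score the final hidden state of the last non-padding token (assumes right padding)
        final_hidden_states = lm_outputs.hidden_states[-1]
        last_index = full_mask.sum(dim=1) - 1
        batch_index = torch.arange(full_mask.size(0), device=full_mask.device)
        reward = self.reward_head(final_hidden_states[batch_index, last_index])
        return reward.squeeze(-1)  # one scalar reward per example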

The reward model is trained using a pairwise preference objective. For each input $x$ (an image and a query) with preferred response $y_w$ and less preferred response $y_l$, we minimize:

$$\mathcal{L}(r_\theta) = -\mathbb{E}_{(x,y_w,y_l) \sim \mathcal{D}} \left[ \log\sigma\left(r_\theta(x,y_w) - r_\theta(x,y_l)\right) \right]$$

Where:

  • $r_\theta$ is the reward model with parameters $\theta$
  • $x$ is the input (image and query)
  • $\sigma$ is the sigmoid function
  • $\mathcal{D}$ is the human preference dataset

This reward model learns to assign higher scores to responses that humans judged as more factually accurate and helpful.
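To see how this loss behaves, here is the per-pair term -log(sigmoid(r_w - r_l)) in plain Python with made-up reward scores. It is small when the preferred response already scores higher and grows roughly linearly with the margin when the ordering is wrong. The function name is illustrative.

```python
import math


def pairwise_preference_loss(reward_preferred, reward_less_preferred):
    """Bradley-Terry loss -log(sigmoid(r_w - r_l)) for one preference pair.

    Uses the identity -log(sigmoid(m)) = softplus(-m), computed in an
    overflow-safe form: max(-m, 0) + log1p(exp(-|m|)).
    """
    margin = reward_preferred - reward_less_preferred
    return max(-margin, 0.0) + math.log1p(math.exp(-abs(margin)))


print(pairwise_preference_loss(2.0, 0.5))  # ~0.2014: preferred response already scores higher
print(pairwise_preference_loss(0.5, 2.0))  # ~1.7014: ordering is wrong, larger loss
print(pairwise_preference_loss(1.0, 1.0))  # ~0.6931 (log 2): no preference expressed
```

Only the difference between the two scores matters, so adding a constant to every reward leaves the loss unchanged; in practice the loss is averaged over a batch of pairs.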

5. Reinforcement Learning Optimization

Finally, the main model is fine-tuned using an RL algorithm (typically PPO) to maximize the reward predicted by the reward model:

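import torch

# Simplified, sequence-level PPO loop. `policy_model.generate_with_log_probs`,
# `compute_log_probs`, `ref_model` (a frozen copy of the SFT model), `reward_model`,
# `training_data`, and `optimizer` are assumed to be defined elsewhere.
clip_epsilon = 0.2  # PPO clipping range
kl_coef = 0.1       # strength of the KL penalty toward the reference model
ppo_epochs = 4      # optimization passes over each batch of samples

for batch in training_data:
    images, queries = batch["images"], batch["queries"]

    with torch.no_grad():
        # Sample responses and record their summed log-probs under the current (old) policy
        responses, response_masks, log_probs_old = policy_model.generate_with_log_probs(images, queries)

        # Score the responses with the frozen reward model
        rewards = reward_model(images, responses, response_masks)

        # Penalize drift from the reference model (per-sequence KL estimate)
        log_probs_ref = ref_model.compute_log_probs(images, queries, responses)
        shaped_rewards = rewards - kl_coef * (log_probs_old - log_probs_ref)

        # Normalize into a simple advantage estimate (no learned value function here)
        advantages = (shaped_rewards - shaped_rewards.mean()) / (shaped_rewards.std() + 1e-8)

    for _ in range(ppo_epochs):
        # Recompute log-probs under the policy being updated
        log_probs_new = policy_model.compute_log_probs(images, queries, responses)

        # Probability ratio and clipped surrogate objective
        ratios = torch.exp(log_probs_new - log_probs_old)
        clipped_ratios = torch.clamp(ratios, 1 - clip_epsilon, 1 + clip_epsilon)

        # PPO maximizes the clipped objective, so minimize its negative
        policy_loss = -torch.min(ratios * advantages, clipped_ratios * advantages).mean()

        optimizer.zero_grad()
        policy_loss.backward()
        optimizer.step()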

The PPO algorithm optimizes the policy (model parameters) to maximize expected reward while ensuring that updates don’t change the policy too drastically in a single step. The core PPO objective can be expressed as:

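$$\mathcal{L}^{CLIP}(\theta) = \mathbb{E}_t \left[ \min\left(r_t(\theta) A_t,\ \text{clip}(r_t(\theta), 1-\epsilon, 1+\epsilon) A_t\right) \right]$$

The objective is maximized, where:

  • $r_t(\theta) = \frac{\pi_\theta(a_t \mid s_t)}{\pi_{\theta_{old}}(a_t \mid s_t)}$ is the probability ratio between the current and the old policy
  • $A_t$ is the advantage estimate; in RLHF it is derived from the reward-model score, usually combined with a KL penalty that keeps $\pi_\theta$ close to the SFT (reference) policy, and with a learned value function as a baseline
  • $\epsilon$ is the clipping parameter (typically 0.2)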
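The effect of clipping is easiest to see on single numbers. This plain-Python sketch evaluates the per-sample term inside the expectation for a few made-up ratio and advantage values, with epsilon = 0.2; the function name is illustrative.

```python
def ppo_clipped_term(ratio, advantage, epsilon=0.2):
    """Per-sample PPO term min(r * A, clip(r, 1 - eps, 1 + eps) * A); PPO maximizes its mean."""
    clipped_ratio = min(max(ratio, 1 - epsilon), 1 + epsilon)
    return min(ratio * advantage, clipped_ratio * advantage)


# (ratio, advantage) pairs with made-up values
for ratio, advantage in [(1.5, 1.0), (0.5, 1.0), (0.5, -1.0), (1.5, -1.0)]:
    print(ratio, advantage, ppo_clipped_term(ratio, advantage))
```

With a positive advantage, raising the ratio above 1 + epsilon earns nothing extra (1.5 is capped at 1.2); with a negative advantage, lowering it below 1 - epsilon earns nothing extra (0.5 is treated as 0.8). In the opposite directions the unclipped, more pessimistic value is kept, so the policy is still penalized for moving the wrong way.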

What makes this factually-augmented is that the reward model specifically guides the policy toward outputs that are grounded in the actual visual content, not just outputs that sound plausible or helpful.

Measuring Success: Evaluating Factual Correctness

How do we know if factually-augmented RLHF actually improves multimodal alignment? Researchers measure it with hallucination-focused benchmarks, including:

  1. POPE (Polling-based Object Probing Evaluation): Asks yes/no questions of the form “Is there a <object> in the image?” about objects that are present and objects that are absent, and checks whether models claim objects that aren’t there.
  2. MME: A comprehensive benchmark for multimodal LLMs that uses yes/no questions to test perception and cognition across a range of visual tasks.
  3. MMHal-Bench: A benchmark introduced with LLaVA-RLHF that measures hallucination in open-ended answers, with GPT-4 rating responses against image annotations.
  4. Human evaluation studies: Structured studies where evaluators compare model outputs on factual correctness dimensions.
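POPE-style scoring is simple to reproduce once you have the model’s yes/no answers and the ground truth. POPE reports accuracy, precision, recall, F1, and the proportion of “yes” answers; the false positive rate (a “yes” for an absent object) is added in this sketch because it directly measures object hallucination. The function name and example answers are made up.

```python
def pope_metrics(predictions, ground_truth):
    """Score yes/no answers to "Is there a <object> in the image?" probes.

    Both arguments are lists of "yes"/"no"; a "yes" for an absent object
    is a false positive, i.e. a hallucinated object.
    """
    pairs = list(zip(predictions, ground_truth))
    tp = sum(p == "yes" and g == "yes" for p, g in pairs)
    fp = sum(p == "yes" and g == "no" for p, g in pairs)
    tn = sum(p == "no" and g == "no" for p, g in pairs)
    fn = sum(p == "no" and g == "yes" for p, g in pairs)
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return {
        "accuracy": (tp + tn) / len(pairs),
        "precision": precision,
        "recall": recall,
        "f1": f1,
        "yes_ratio": (tp + fp) / len(pairs),
        "false_positive_rate": fp / (fp + tn) if fp + tn else 0.0,
    }


# Made-up answers for four probes about one image
print(pope_metrics(["yes", "yes", "no", "yes"], ["yes", "no", "no", "yes"]))
```

A model that answers “yes” to everything gets perfect recall but a high yes ratio and false positive rate, which is why POPE reports several metrics together.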

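Results reported for LLaVA-RLHF show clear gains in factual grounding over non-RLHF baselines: according to its authors, it reaches 94% of text-only GPT-4’s performance on LLaVA-Bench (versus 87% for the previous best methods) and improves on other baselines by 60% on MMHal-Bench.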

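The table below illustrates the kind of trend these evaluations track as alignment stages are added:

| Model | POPE false positive rate | MME factual score | Human-eval hallucination rate |
|---|---|---|---|
| Base VLM | 23.5% | 62.4 | 18.7% |
| With SFT | 18.2% | 68.9 | 12.3% |
| With RLHF | 9.7% | 76.3 | 5.2% |
| With Factual RLHF | 4.3% | 82.5 | 2.8% |

Note: These are illustrative figures rather than values taken from the LLaVA or LLaVA-RLHF papers (MME, for instance, does not report a single “factual score”); consult the papers for measured results.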

Practical Applications and Real-World Impact

The improvements in factual grounding enabled by LLaVA’s approach have significant implications for real-world applications:

Assistive Technology

Vision-language models with strong factual grounding can provide more reliable assistance to visually impaired users by accurately describing surroundings, reading text, and identifying objects.

Content Moderation

Factually-aligned multimodal models can better distinguish between actual and fabricated content in images, potentially helping identify deepfakes or manipulated media.

Educational Tools

These models can provide more reliable image-based learning experiences, explaining visual concepts without introducing misleading information.

Visual Documentation

Industries like insurance, real estate, and construction benefit from accurate visual documentation and description without hallucinated details.

Implementing Your Own Factually-Augmented RLHF Pipeline

If you’re interested in implementing a similar approach for your own multimodal models, here are key practical considerations:

Data Requirements

  1. Diverse image-text pairs: Gather a wide range of images with varying complexity, content, and context.
  2. Structured annotation guidelines: Develop clear criteria for what constitutes factual accuracy in your domain.
  3. Preference data collection: Design efficient collection methods that specifically target factual alignment.
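Annotation guidelines are easier to enforce when each judgment is stored in a structured record. The schema below is a hypothetical sketch: the field names and error categories are assumptions, with the categories mirroring the annotation criteria listed earlier in this article.

```python
from dataclasses import dataclass, field

# Hypothetical error categories, mirroring the annotation criteria listed earlier
HALLUCINATION_TYPES = {
    "nonexistent_object",      # mentions an object not present in the image
    "misidentified_object",    # names a visible object incorrectly
    "wrong_spatial_relation",  # false claim about where things are
    "hallucinated_text",       # reads text that is not in the image
}


@dataclass
class PreferenceRecord:
    """One pairwise judgment for an image-query pair (hypothetical schema)."""
    image_path: str
    query: str
    preferred: str
    less_preferred: str
    # Hallucination types the annotator flagged in the less preferred response
    flagged_errors: set = field(default_factory=set)

    def validate(self):
        unknown = self.flagged_errors - HALLUCINATION_TYPES
        if unknown:
            raise ValueError(f"unknown error types: {sorted(unknown)}")
        if self.preferred == self.less_preferred:
            raise ValueError("the two responses must differ")


record = PreferenceRecord(
    image_path="images/parrot_001.jpg",
    query="What is the parrot doing?",
    preferred="The parrot is perched on a branch.",
    less_preferred="The parrot is playing with a small red ball.",
    flagged_errors={"nonexistent_object"},
)
record.validate()  # passes; an unlisted error type or identical responses would raise ValueError
```

Keeping the flagged error types alongside each pair also lets you measure which kinds of hallucination dominate your data before training a reward model on it.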

Technical Implementation

The following code sketch outlines the structure of a factually-augmented RLHF pipeline; several components are simplified and the dataset names are placeholders:

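import copy

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from datasets import load_dataset
from transformers import AutoModelForCausalLM, CLIPVisionModel, Trainer, TrainingArguments

# 1. Define the multimodal model architecture
class VisionLanguageModel(nn.Module):
    def __init__(self, vision_encoder_id, language_model_id):
        super().__init__()
        self.vision_encoder = CLIPVisionModel.from_pretrained(vision_encoder_id)
        self.language_model = AutoModelForCausalLM.from_pretrained(language_model_id)

        # Projection from vision to language embedding space
        self.projection = nn.Linear(
            self.vision_encoder.config.hidden_size,
            self.language_model.config.hidden_size
        )

    def encode_images(self, pixel_values):
        vision_outputs = self.vision_encoder(pixel_values=pixel_values)
        # Simplification: keep only the CLS token; LLaVA feeds every patch token to the LLM
        image_features = vision_outputs.last_hidden_state[:, 0]
        return self.projection(image_features)

    def forward(self, pixel_values, input_ids, attention_mask, labels=None):
        image_features = self.encode_images(pixel_values)

        # Prepend the projected image feature as a single "visual token"
        embeds = self.language_model.get_input_embeddings()(input_ids)
        image_embeds = image_features.unsqueeze(1).to(embeds.dtype)
        extended_embeds = torch.cat([image_embeds, embeds], dim=1)
        extended_attention_mask = F.pad(attention_mask, (1, 0), value=1)

        # The image position has no target token: label it -100 so the LM loss ignores it
        if labels is not None:
            labels = F.pad(labels, (1, 0), value=-100)

        return self.language_model(
            inputs_embeds=extended_embeds,
            attention_mask=extended_attention_mask,
            labels=labels,
            output_hidden_states=True,  # the reward model reads the final hidden states
        )

# 2. Define the reward model
class RewardModel(nn.Module):
    def __init__(self, vision_language_model):
        super().__init__()
        self.vision_language_model = vision_language_model
        self.reward_head = nn.Linear(
            self.vision_language_model.language_model.config.hidden_size,
            1
        )

    def forward(self, pixel_values, input_ids, attention_mask):
        outputs = self.vision_language_model(pixel_values, input_ids, attention_mask)
        hidden = outputs.hidden_states[-1]  # (batch, 1 + seq_len, hidden_size)
        # Last non-padding text token (assumes right padding). The image token shifts
        # every text position by one, so text index (mask_sum - 1) becomes mask_sum.
        last_index = attention_mask.sum(dim=1)
        batch_index = torch.arange(hidden.size(0), device=hidden.device)
        return self.reward_head(hidden[batch_index, last_index]).squeeze(-1)

# 3. Load and prepare datasets (placeholder names). Each dataset is assumed to be
#    preprocessed already: images into `pixel_values`, text into padded token ids.
def prepare_datasets():
    pretrain_dataset = load_dataset("your_image_text_dataset", split="train")
    sft_dataset = load_dataset("your_instruction_dataset", split="train")
    preference_dataset = load_dataset("your_preference_dataset", split="train")
    for dataset in (pretrain_dataset, sft_dataset, preference_dataset):
        dataset.set_format("torch")
    return pretrain_dataset, sft_dataset, preference_dataset

# 4. Training functions
def supervised_train(model, dataset, output_dir, epochs=1, projection_only=False):
    # Next-token prediction on image-text data (expects pixel_values, input_ids,
    # attention_mask, and labels columns); used for both pretraining and SFT.
    for name, param in model.named_parameters():
        if name.startswith("vision_encoder"):
            param.requires_grad = False  # the vision encoder stays frozen, as in LLaVA
        elif name.startswith("language_model"):
            param.requires_grad = not projection_only  # LLaVA stage 1 trains only the projection

    args = TrainingArguments(
        output_dir=output_dir,
        num_train_epochs=epochs,
        per_device_train_batch_size=8,  # illustrative hyperparameters
        learning_rate=2e-5,
    )
    trainer = Trainer(model=model, args=args, train_dataset=dataset)
    trainer.train()
    return model

def train_reward_model(base_model, preference_dataset, epochs=1, batch_size=4):
    # Copy the SFT model so that reward training does not modify the policy's weights.
    # With Fact-RLHF, the token ids would also include factual context such as captions.
    reward_model = RewardModel(copy.deepcopy(base_model))
    optimizer = torch.optim.AdamW(
        [p for p in reward_model.parameters() if p.requires_grad], lr=5e-5
    )
    loader = DataLoader(preference_dataset, batch_size=batch_size, shuffle=True)

    reward_model.train()
    for _ in range(epochs):
        for batch in loader:
            # Same image and query, two candidate responses
            preferred_reward = reward_model(
                batch["pixel_values"],
                batch["preferred_input_ids"],
                batch["preferred_attention_mask"],
            )
            less_preferred_reward = reward_model(
                batch["pixel_values"],
                batch["less_preferred_input_ids"],
                batch["less_preferred_attention_mask"],
            )

            # Bradley-Terry loss for pairwise preferences
            loss = -F.logsigmoid(preferred_reward - less_preferred_reward).mean()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

    return reward_model

def sequence_log_probs(model, pixel_values, input_ids, attention_mask, response_mask):
    # Summed log-probs of the response tokens. Logits at extended position k (position 0
    # is the image) predict text token k, so logits[:, 1:-1] align with input_ids[:, 1:].
    logits = model(pixel_values, input_ids, attention_mask).logits[:, 1:-1]
    log_probs = torch.log_softmax(logits.float(), dim=-1)
    token_log_probs = log_probs.gather(-1, input_ids[:, 1:].unsqueeze(-1)).squeeze(-1)
    return (token_log_probs * response_mask[:, 1:]).sum(dim=-1)

def sample_responses(policy_model, pixel_values, query_input_ids):
    # Hypothetical helper (not implemented here): sample one response per query and
    # return right-padded query+response `input_ids`, `attention_mask`, and a
    # `response_mask` that is 1 only on response tokens.
    raise NotImplementedError

def ppo_training(policy_model, reward_model, dataset, epochs=1, ppo_epochs=4,
                 batch_size=8, lr=1e-6, clip_epsilon=0.2, kl_coef=0.1):
    # Minimal sequence-level PPO without a value model. Libraries such as TRL provide
    # full PPO trainers, but their APIs change across versions and are built around
    # text-only causal language models.
    ref_model = copy.deepcopy(policy_model).eval()  # frozen reference for the KL penalty
    reward_model.eval()
    optimizer = torch.optim.AdamW(
        [p for p in policy_model.parameters() if p.requires_grad], lr=lr
    )
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    for _ in range(epochs):
        for batch in loader:
            pixel_values = batch["pixel_values"]
            with torch.no_grad():
                input_ids, attention_mask, response_mask = sample_responses(
                    policy_model, pixel_values, batch["query_input_ids"]
                )
                inputs = (pixel_values, input_ids, attention_mask)
                log_probs_old = sequence_log_probs(policy_model, *inputs, response_mask)
                log_probs_ref = sequence_log_probs(ref_model, *inputs, response_mask)
                rewards = reward_model(*inputs)
                # KL-penalized reward, normalized into a simple advantage estimate
                shaped = rewards - kl_coef * (log_probs_old - log_probs_ref)
                advantages = (shaped - shaped.mean()) / (shaped.std() + 1e-8)

            for _ in range(ppo_epochs):
                log_probs_new = sequence_log_probs(policy_model, *inputs, response_mask)
                ratios = torch.exp(log_probs_new - log_probs_old)
                clipped = torch.clamp(ratios, 1 - clip_epsilon, 1 + clip_epsilon)
                loss = -torch.min(ratios * advantages, clipped * advantages).mean()

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

    return policy_model

# 5. Full pipeline
def train_factually_augmented_vlm():
    # Llama 2 weights are gated on the Hugging Face Hub (license acceptance required)
    model = VisionLanguageModel("openai/clip-vit-large-patch14", "meta-llama/Llama-2-7b-hf")

    pretrain_dataset, sft_dataset, preference_dataset = prepare_datasets()

    # Stage 1: align image features with the LLM by training only the projection
    model = supervised_train(model, pretrain_dataset, "out/pretrain", projection_only=True)

    # Stage 2: supervised fine-tuning on visual instructions
    model = supervised_train(model, sft_dataset, "out/sft")

    # Stage 3: reward model on human preferences, then PPO against it
    # (the SFT dataset is assumed to also provide `query_input_ids` prompts)
    reward_model = train_reward_model(model, preference_dataset)
    model = ppo_training(model, reward_model, sft_dataset)

    return model

# Run the full pipeline
if __name__ == "__main__":
    final_model = train_factually_augmented_vlm()
    # VisionLanguageModel is a plain nn.Module (no save_pretrained), so save its weights
    torch.save(final_model.state_dict(), "your_factually_aligned_vlm.pt")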

Computational Considerations

  • RLHF typically requires significant computational resources, especially for multimodal models
  • Consider starting with a smaller model for prototyping
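  • Replacing human annotators with AI feedback (RLAIF, Reinforcement Learning from AI Feedback) can reduce annotation costs, although it does not make the RL stage itself cheaper; methods such as Direct Preference Optimization (DPO) cut compute by training directly on preference pairs without a separate reward model or an RL sampling loop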
  • Distributed training is practically essential for full-scale RLHF
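A back-of-envelope memory estimate shows why. The sketch assumes a common PPO setup with four model copies (trainable policy and value models, frozen reference and reward models), mixed-precision Adam at roughly 16 bytes per trainable parameter, and bf16 weights for the frozen copies; these are assumptions, not LLaVA-RLHF’s exact configuration, and activations are ignored.

```python
def ppo_weight_memory_gb(num_params, trainable_copies=2, frozen_copies=2,
                         trainable_bytes_per_param=16, frozen_bytes_per_param=2):
    """Back-of-envelope GPU memory (GB = 1e9 bytes) for weights and optimizer state.

    Assumes policy and value models trained with mixed-precision Adam
    (~16 bytes/param: bf16 weights and gradients, fp32 master weights and two
    fp32 Adam moments), plus frozen reference and reward models in bf16 (2 bytes/param).
    Activations, generation caches, and framework overhead are not included.
    """
    bytes_per_param = (trainable_copies * trainable_bytes_per_param
                       + frozen_copies * frozen_bytes_per_param)
    return num_params * bytes_per_param / 1e9


print(ppo_weight_memory_gb(7e9))  # 252.0 for a 7B model, before activations
```

Even before activations, the estimate for a 7B model is several times the memory of a single 80 GB GPU, which is why sharded training or parameter-efficient methods such as LoRA are common in RLHF setups.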

Future Directions and Open Challenges

While LLaVA represents significant progress, several challenges and opportunities remain:

Cross-modal Grounding

Improving the connection between visual elements and language remains challenging, especially for complex scenes or abstract concepts.

Evaluation Metrics

Developing better automated metrics for factual correctness is crucial for scaling evaluation beyond human assessment.

Domain Adaptation

Adapting factually-augmented RLHF to specialized domains like medical imaging or scientific visualization presents unique challenges.

Efficiency Improvements

Making the RLHF process more computationally efficient would enable broader adoption and experimentation.


Conclusion

Factually-augmented RLHF, as exemplified by LLaVA-RLHF, represents a significant advancement in multimodal AI alignment. By explicitly incorporating factual correctness into the feedback loop, these approaches produce models that not only generate helpful responses but do so with greater fidelity to visual reality.

As vision-language models become increasingly integrated into real-world applications, ensuring their factual grounding becomes ever more critical. The techniques developed for LLaVA provide a valuable blueprint for aligning multimodal AI systems with human expectations around factual accuracy.

For researchers and practitioners working on AI alignment, the LLaVA-RLHF approach demonstrates that we can significantly reduce hallucinations through careful design of the feedback process. This principle, explicitly evaluating factual correctness in addition to helpfulness or style, can be adapted to many AI alignment scenarios, from purely textual to increasingly multimodal settings.

As the field continues to evolve, we can expect further refinements in factually-augmented RLHF techniques, bringing us closer to vision-language models that truly understand what they see.

Posted in AI / ML, LLM Advanced, LLM Research