February 1, 2025

Teaching Small Gemma 2 Models to “Think” Like DeepSeek‑R1

Key Takeaways

  • We can graft DeepSeek‑R1’s chain-of-thought discipline onto a 2B-parameter Gemma 2 model, making sophisticated reasoning run on a single gaming GPU.
  • Two simple tokens, <think> and </think>, are all it takes to teach the model to wall off its internal monologue from its final answer and, crucially, to know when to shut up.
  • QLoRA delivers over 90% of a full fine-tune’s performance at a quarter of the memory cost, a testament to the brutal efficiency of low-rank adapters and 4-bit NF4 quantization.
  • First-principles guardrails (information bottlenecks, attention budgets) not only explain why this works; they predict its failure modes with unnerving accuracy.
  • A single pass through the data yields a surprisingly coherent thinker. A second epoch, or a higher LoRA rank, is the hammer you need to crush infinite loops and other emergent nonsense.

1 Motivation

The prevailing narrative pits gargantuan models like DeepSeek-R1 against everything else. They dominate reasoning benchmarks, but their 70B-parameter appetites demand compute budgets well beyond the hobbyist’s reach. The game here isn’t cloning that raw power; it’s distillation. My goal was to transplant DeepSeek-R1’s thinking pattern, its methodical chain-of-thought (CoT) process, into a pocket-sized model that can live in a notebook, a mobile app, or on an edge device without requiring a nation-state’s power grid. We’re not chasing benchmark supremacy; we’re capturing a specific, valuable cognitive skill.


2 Theoretical Background

2.1 Why Does Chain‑of‑Thought (CoT) Help?

CoT isn’t magic; it’s a form of cognitive offloading. It forces the model to treat its own generated text as an external memory scratchpad, sidestepping the need to compress complex reasoning into the finite space of its internal activations. It functions as an information bottleneck that:

  1. Forces the model to spread its cognitive load over a longer context, instead of trying to solve everything in a single, heroic forward pass.
  2. Reduces the search space. The model only needs to solve one incremental sub-problem at a time, rather than making an impossible leap from question to answer.
  3. Provides a much cleaner gradient signal during fine-tuning. We’re not just rewarding the right answer; we’re rewarding the right path.
  4. Improves knowledge retrieval by breaking down a monolithic problem into smaller queries that are easier to fetch from the model’s parametric memory.

Formally, given an input \mathbf{x}, a reasoning trace \mathbf{r}, and an answer y, CoT training minimizes the joint negative log‑likelihood:

\mathit{L}= - \left[\log P_{\theta}(\mathbf{r}\mid\mathbf{x}) + \log P_{\theta}(y\mid\mathbf{x},\mathbf{r})\right] .

The model learns to generate a plausible thought process and to trust that process when formulating its final answer.
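
To make the objective concrete, here is a minimal PyTorch sketch of how that joint NLL is typically realised: prompt tokens are masked out of the labels (set to -100), so only the reasoning trace and the final answer contribute to the cross-entropy. The function name and tensor layout are illustrative, not part of any library.

import torch
import torch.nn.functional as F

def cot_nll(model, prompt_ids, reasoning_ids, answer_ids):
    # Concatenate [prompt | reasoning | answer] into one causal-LM sequence.
    input_ids = torch.cat([prompt_ids, reasoning_ids, answer_ids]).unsqueeze(0)
    labels = input_ids.clone()
    labels[0, : prompt_ids.numel()] = -100      # prompt positions carry no loss
    logits = model(input_ids).logits            # (1, seq_len, vocab) for an HF-style causal LM
    # Standard next-token shift: position t predicts token t+1.
    return F.cross_entropy(logits[0, :-1].float(),
                           labels[0, 1:],
                           ignore_index=-100)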

2.2 Low‑Rank Adaptation in a Nutshell

LoRA is a clever hack that makes fine-tuning massive models tractable. Instead of updating the entire multi-billion-parameter weight matrix W\in\mathbb R^{d_{\textrm{out}}\times d_{\textrm{in}}}, we freeze it and inject a much smaller, low-rank update. We learn two lean matrices, A\in\mathbb R^{d_{\textrm{out}}\times r} and B\in\mathbb R^{r\times d_{\textrm{in}}}, such that the effective weight during the forward pass is:

\tilde W = W + A B.

The rank r \ll d_{\textrm{in}} becomes a powerful knob for tuning capacity. A higher r captures more task-specific nuance but costs more VRAM. When using QLoRA, this update is applied on top of the quantized 4-bit base weights, leaving the frozen backbone untouched.
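
For intuition, a bare-bones LoRA layer looks something like the sketch below (illustrative only; PEFT’s real implementation adds dropout, dtype handling, and merge/unmerge utilities). With the output-side factor initialised to zero, the adapted layer starts out identical to the frozen original.

import torch
import torch.nn as nn

class LoRALinear(nn.Module):
    """y = x W^T + (alpha / r) * x (A B)^T, with W frozen."""
    def __init__(self, base: nn.Linear, r: int = 16, alpha: int = 16):
        super().__init__()
        self.base = base
        for p in self.base.parameters():
            p.requires_grad_(False)                       # freeze the pre-trained weight
        d_out, d_in = base.out_features, base.in_features
        self.A = nn.Parameter(torch.zeros(d_out, r))      # zero-init: the update starts at 0
        self.B = nn.Parameter(torch.randn(r, d_in) / r)   # small random init
        self.scale = alpha / r

    def forward(self, x):
        delta = self.A @ self.B                           # rank-r update, shape (d_out, d_in)
        return self.base(x) + self.scale * (x @ delta.T)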

2.3 4‑Bit Quantisation & NF4

QLoRA gets its efficiency from NormalFloat 4 (NF4), a non-uniform 4-bit data type designed around the roughly zero-centred normal distribution that pre-trained network weights tend to follow. It compresses the backbone to a mere 4 bits per parameter while gradients and activations flow in a higher-precision compute format (bfloat16). A second “double-quantization” step compresses the quantization constants themselves, saving another ~0.1 GB on a 2B model. The math ensures that the quantization error is bounded and doesn’t catastrophically flatten the loss landscape.
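
The sketch below gives a feel for the mechanics; it is not the bitsandbytes kernel. It builds a 16-entry codebook from quantiles of a standard normal, then quantises each weight block with a per-block absmax scale. The real NF4 table is constructed slightly differently so that zero is exactly representable, and double quantization would further compress the per-block scales.

import torch

def normal_float_codebook(bits: int = 4) -> torch.Tensor:
    # Quantiles of N(0, 1), trimmed away from the infinite tails, scaled to [-1, 1].
    probs = torch.linspace(0.0, 1.0, 2 ** bits + 2)[1:-1]
    levels = torch.distributions.Normal(0.0, 1.0).icdf(probs)
    return levels / levels.abs().max()

def quantize_block(w: torch.Tensor, codebook: torch.Tensor):
    absmax = w.abs().max()                                   # one fp scale per block
    idx = (w.flatten()[:, None] / absmax - codebook[None, :]).abs().argmin(dim=1)
    return idx.to(torch.uint8), absmax                       # 4-bit codes + scale

def dequantize_block(idx: torch.Tensor, absmax: torch.Tensor, codebook: torch.Tensor):
    return codebook[idx.long()] * absmax

codebook = normal_float_codebook()
block = torch.randn(64) * 0.02                               # NF4 typically uses 64-weight blocks
codes, scale = quantize_block(block, codebook)
print("max abs error:", (dequantize_block(codes, scale, codebook) - block).abs().max())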


3 Data Engineering Pipeline

The dataset is the strategy. I constrained this experiment to the cognitivecomputations/dolphin‑r1 reasoning split (roughly 30,000 chain-of-thought examples) to keep the training run to a manageable overnight job on a single GPU.

from datasets import load_dataset
from transformers import AutoTokenizer
import multiprocessing as mp

dataset_name = "cognitivecomputations/dolphin-r1"
config_name  = "reasoning-deepseek"   # dataset configuration with DeepSeek-style traces
ds = load_dataset(dataset_name, name=config_name, split="train[:30000]")

tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b-it")
THINK_OPEN, THINK_CLOSE = "<think>", "</think>"

# --- helper ---

def stitch(row):
    # Wrap the reasoning trace in the sentinels, then append the final answer.
    reasoning = f"{THINK_OPEN}{row['reasoning']}{THINK_CLOSE}\n\n{row['answer']}"
    row["messages"].append({"role": "assistant", "content": reasoning})
    # Render the whole conversation as a single training string.
    row["text"] = tokenizer.apply_chat_template(row["messages"], tokenize=False)
    return row

# datasets manages its own worker processes; no need to spawn a Pool manually.
ds = ds.map(stitch, num_proc=mp.cpu_count(), load_from_cache_file=False)
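
A quick sanity check before moving on: print one stitched example and confirm the reasoning sits between the sentinels, followed by the answer, inside the chat-templated text.

sample = ds[0]["text"]
assert THINK_OPEN in sample and THINK_CLOSE in sample
print(sample[:500])   # templated conversation, with <think>…</think> preceding the answer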

3.1 Tokenizer Surgery

Gemma 2’s tokenizer knows nothing about our sentinels out of the box; we can’t just throw new strings at it and hope for the best. We must perform some light surgery, formally adding <think> and </think> to its lexicon. This gives each sentinel a dedicated id and, once the model’s embedding matrix is resized (Section 4), a dedicated embedding that will be tuned during training.

# Define the special tokens
# THINK_OPEN, THINK_CLOSE are already defined in the script above this block

# Add special tokens to the tokenizer instance
special_tokens_dict = {'additional_special_tokens': [THINK_OPEN, THINK_CLOSE]}
num_added_toks = tokenizer.add_special_tokens(special_tokens_dict)
print(f"Added {num_added_toks} new special tokens: {THINK_OPEN}, {THINK_CLOSE}")

# The model's embedding layer will be resized in the training script (Section 4)
# after the model is loaded and before k-bit training preparation.
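
Optionally, verify that the sentinels now map to single, dedicated ids rather than being split into sub-word pieces:

for tok in (THINK_OPEN, THINK_CLOSE):
    ids = tokenizer.encode(tok, add_special_tokens=False)
    print(tok, "->", ids)        # expect exactly one id per sentinel
    assert len(ids) == 1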

4 Training Recipe

I’m using the QLoRA approach for its sheer memory efficiency. The entire fine-tuning process (the 4-bit 2B backbone, adapters, optimizer state, and activation checkpoints) fits within a lean 8-10 GB footprint on an NVIDIA L4 GPU.

from peft import LoraConfig, prepare_model_for_kbit_training
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from trl import SFTTrainer, SFTConfig
import torch

BASE = "google/gemma-2-2b" # Base model for QLoRA
compute_dtype = torch.bfloat16

# Tokenizer is assumed to be loaded and modified as in Section 3.1
# tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b-it")
# THINK_OPEN, THINK_CLOSE = "<think>", "</think>"
# special_tokens_dict = {'additional_special_tokens': [THINK_OPEN, THINK_CLOSE]}
# tokenizer.add_special_tokens(special_tokens_dict)

bnb_cfg = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=compute_dtype,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
)

model = AutoModelForCausalLM.from_pretrained(BASE,
        quantization_config=bnb_cfg, device_map="auto")

# Resize token embeddings before preparing for k-bit training
# The tokenizer instance (loaded and modified in Section 3) should be in scope
model.resize_token_embeddings(len(tokenizer))

model = prepare_model_for_kbit_training(model)

lora_cfg = LoraConfig(task_type="CAUSAL_LM", r=16, lora_alpha=16, lora_dropout=0.05,
                      target_modules=["q_proj","k_proj","v_proj","o_proj",
                                      "gate_proj","up_proj","down_proj"],
                      # The new sentinel embeddings live here, so train and save these too.
                      modules_to_save=["embed_tokens","lm_head"])

train_args = SFTConfig(output_dir="./LoRA", dataset_text_field="text", max_seq_length=1024,
                       per_device_train_batch_size=2, gradient_accumulation_steps=16,
                       num_train_epochs=1, learning_rate=1e-5, bf16=True,
                       logging_steps=25, save_strategy="epoch")

trainer = SFTTrainer(model=model, train_dataset=ds, peft_config=lora_cfg,
                     tokenizer=tokenizer, args=train_args)
trainer.train()

A note on rank. Don’t be cheap with your LoRA rank r. Bumping it from 16 to 32 bought me another ~3 points of ROUGE-L on thought coherency, for a modest VRAM cost of less than 2 GB on this 2B model. The trade-off is almost always worth it.
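
Two small housekeeping steps worth doing after the block above (a sketch; the ./LoRA/final path is my own choice, not something the trainer produces by default): confirm that only the adapters plus the resized embedding and output head are trainable, and save the modified tokenizer next to the adapter so inference can reload both from a single path.

# Only the LoRA matrices plus embed_tokens / lm_head should show up as trainable.
trainer.model.print_trainable_parameters()

# Persist the adapter and, crucially, the tokenizer with the added sentinels.
trainer.save_model("./LoRA/final")
tokenizer.save_pretrained("./LoRA/final")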


5 Fast Evaluation & Common Failure Modes

I test with targeted, tricky prompts. Hallucinations and logic failures love to hide in the edge cases and details, not in broad, generic queries.

from transformers import AutoTokenizer, AutoModelForCausalLM, StoppingCriteria, StoppingCriteriaList
from peft import PeftModel
import torch

# Define special tokens (ensure these match training)
THINK_OPEN, THINK_CLOSE = "<think>", "</think>"
special_tokens_dict = {'additional_special_tokens': [THINK_OPEN, THINK_CLOSE]}

# Load the tokenizer. If your training checkpoint saved the tokenizer, load it from there:
# tokenizer = AutoTokenizer.from_pretrained("./LoRA/checkpoint-843/")
# Otherwise, initialize from the instruct model and add tokens:
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b-it")
num_added_toks = tokenizer.add_special_tokens(special_tokens_dict)
if num_added_toks > 0:
    print(f"Added {num_added_toks} special tokens to tokenizer for inference.")

# Define compute_dtype, matching training (e.g., torch.bfloat16)
compute_dtype = torch.bfloat16 

# Load the base model (the one QLoRA was applied to)
base_model_name = "google/gemma-2-2b"
base_model = AutoModelForCausalLM.from_pretrained(
    base_model_name,
    torch_dtype=compute_dtype,
    device_map="auto"
)
# Resize embeddings of the base model to match the tokenizer
base_model.resize_token_embeddings(len(tokenizer))

# Load the PeftModel (LoRA adapters) onto the base model
# Replace "./LoRA/checkpoint-729/" with your actual adapter path
model = PeftModel.from_pretrained(base_model, "./LoRA/checkpoint-729/")
model.eval() # Set the model to evaluation mode

# The stock Gemma 2 chat template rejects a "system" role, so fold the
# instruction into the user turn instead.
chat = [
    {"role":"user", "content":"You are a helpful assistant who *thinks* before answering. "
                              "A rooster flew over to the neighbor's yard and laid an egg. Who does the egg belong to?"}
]

prompt = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
# apply_chat_template already prepends <bos>, so avoid adding a second one here.
ids    = tokenizer(prompt, return_tensors="pt", add_special_tokens=False).to(model.device)

# To use stopping criteria, define the class and list (see below)
# stopping_criteria_list = StoppingCriteriaList([ThinkStopCriteria(tokenizer, THINK_CLOSE)])
# print(tokenizer.decode(model.generate(**ids, max_new_tokens=256, stopping_criteria=stopping_criteria_list)[0]))

print(tokenizer.decode(model.generate(**ids, max_new_tokens=256)[0]))

Failure Mode – Infinite Thinking

If the model gets stuck in a <think></think> loop and never produces an answer:

  1. Increase adapter capacity. Raise the LoRA rank or train for a second epoch. The model needs more capacity to learn the “stop thinking” signal.
  2. Balance the dataset. Keep the reasoning-to-answer token ratio close to 1:1; if the thought is always far longer than the answer, the model may learn to favor verbose thinking over ever concluding. (A quick ratio check is sketched after the stopping-criteria example below.)
  3. Use a programmatic hammer. Brute-force a stop at </think> using a StoppingCriteria callback during generation. This is a patch, not a fix, but it’s effective.
# Example of stopping criteria implementation
# from transformers import StoppingCriteria, StoppingCriteriaList # Ensure imported

class ThinkStopCriteria(StoppingCriteria):
    def __init__(self, tokenizer, end_token_str):
        # Encode the end token string. add_special_tokens=False is important here
        # as we are looking for the specific token(s) representing end_token_str.
        self.end_token_ids = tokenizer.encode(end_token_str, add_special_tokens=False)
        self.end_token_len = len(self.end_token_ids)
    
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        # Check each sequence in the batch
        for seq_ids in input_ids:
            # Check if the tail of the sequence matches the end_token_ids
            if len(seq_ids) >= self.end_token_len:
                if seq_ids[-self.end_token_len:].tolist() == self.end_token_ids:
                    return True # Stop generation for the whole batch if one sequence meets criteria
        return False

# Usage in generation (ensure THINK_CLOSE is defined):
# stopping_criteria_list = StoppingCriteriaList([ThinkStopCriteria(tokenizer, THINK_CLOSE)])
# output_ids = model.generate(**ids, max_new_tokens=256, stopping_criteria=stopping_criteria_list)
# print(tokenizer.decode(output_ids[0]))
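
And here is the dataset-balance check promised above: a rough estimate of the reasoning-to-answer token ratio over a sample of the training split. This is a sketch; it assumes the ds, tokenizer, and column names from Section 3 are still in scope.

# Estimate the reasoning-to-answer token ratio over the first 1,000 examples.
subset = ds.select(range(1000))
n_reason = sum(len(tokenizer.encode(r["reasoning"], add_special_tokens=False)) for r in subset)
n_answer = sum(len(tokenizer.encode(r["answer"], add_special_tokens=False)) for r in subset)
print(f"reasoning : answer token ratio ≈ {n_reason / n_answer:.2f} : 1")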

6 Anatomy of a Successful Generation

The proof is in the output. Here, the model correctly identifies the flawed premise before committing to an answer.

<think>
The user is asking who owns an egg laid by a rooster.
A rooster is a male chicken.
Male chickens (roosters) do not lay eggs.
The premise of the question is flawed. The egg does not exist.
</think>
Roosters don't lay eggs, so the egg doesn't belong to anyone.

The internal monologue deconstructs the problem, spots the logical trap, and then informs the final, concise answer. This is the behavior we were aiming for.


7 Visual Pipeline Summary

[Pipeline diagram: Data Preparation → …]


8 Limitations & Next Steps

  • Mimicry isn’t understanding. Supervising on CoT outputs teaches the model to imitate a reasoning process. It doesn’t guarantee the model has acquired a generalized, robust reasoning capability.
  • Beware data contamination. The Dolphin-R1 dataset contains parts of common benchmarks like GSM8K. Be honest about this. Don’t claim benchmark-beating performance without acknowledging potential leakage.
  • Privacy is a feature, not an afterthought. The model’s “thoughts” can leak sensitive data present in the prompt. For production systems, you need a strategy to redact or filter this internal monologue before it’s logged or exposed.
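
On that last point, redaction can be as simple as stripping the sentinel-delimited span before anything is logged or shown to the user. A minimal sketch (the function name is mine; a well-formed </think> is assumed, with a conservative fallback for a dangling <think>):

import re

def redact_thoughts(generated_text: str) -> str:
    # Remove complete <think> ... </think> spans.
    redacted = re.sub(r"<think>.*?</think>", "", generated_text, flags=re.DOTALL)
    # If the model never closed its thought, drop everything after the open tag.
    redacted = re.sub(r"<think>.*", "", redacted, flags=re.DOTALL)
    return redacted.strip()

print(redact_thoughts("<think>male chickens don't lay eggs</think>\n\nRoosters don't lay eggs."))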


Closing Thoughts

With little more than two special tokens and a low-rank adapter, we can teach a 2B Gemma 2 to articulate its internal monologue with the discipline of DeepSeek-R1. This experiment is a potent reminder that in the world of AI, algorithmic insight and clever data framing often deliver more value than just throwing more parameters at the problem. The techniques are out there. Go build better thinkers.
