
April 1, 2025

Teaching Llama New Tricks: Memory-Efficient Trainable Tokens with LoRA

Key Takeaways

  • LoRA adapters are a scalpel for fine-tuning billion-parameter LLMs without touching the frozen backbone, but they cannot teach the model new words on their own.
  • Trainable tokens perform microsurgery, patching only the specific rows in the embedding table (and optionally the language-model head) for new tokens instead of retraining the entire matrix, slashing VRAM requirements by gigabytes.
  • Hugging Face PEFT offers two primary entry points: TrainableTokensConfig for sparse embedding updates on their own, and LoraConfig.trainable_token_indices to combine them with LoRA adapters.
  • My benchmarks on Llama 3 show that updating just four chat-template tokens added a trivial ~16k–33k trainable parameters (tens of kilobytes of weights) but saved ~5 GB of VRAM compared to a full embedding retrain, with little to no hit on quality.
  • The non-obvious win: Retraining the language model (LM) head for the same token indices almost always improves perplexity, though it requires materializing the full head matrix, costing a couple of extra gigabytes on 8B models.
  • A word of caution: Use a conservative learning rate for embeddings (e.g., 1-3 × 10⁻⁵) and watch for NaNs. Newly initialized special token vectors sitting at near-zero can easily destabilize training.

Why Fresh Embeddings Matter

Large language models build a sophisticated, statistical map of their vocabulary during pre-training. When we introduce special tokens (delimiters like <|user|> or <|assistant|>, domain-specific terms, or new code keywords), their corresponding embedding vectors are often uninitialized or meaningless. Consequently, a frozen model, or even one fine-tuned with standard LoRA, will treat these tokens as noise because they carry no learned semantic weight.

The brute-force solutions, full fine-tuning or LoRA combined with a full embedding-layer retrain, are memory hogs. They require back-propagating gradients through the entire embedding matrix, which can be massive. The Llama 3 8B model, for instance, has a vocabulary size (V) of 128,000 and an embedding dimension (D) of 4096, which means an embedding matrix of roughly 524 million parameters.

The Memory Arithmetic

Let’s look at the brutal math:

$$
\begin{aligned}
\textrm{params} &= V \times D \\[2pt]
\textrm{weights}_{16\textrm{-bit}} &= 2 \times \textrm{params} \;\textrm{bytes}\\[2pt]
\textrm{Adam states}_{32\textrm{-bit}} &= (2 \times 4) \times \textrm{params}\;\textrm{bytes}\\[2pt]
\textrm{total memory for embeddings} &\approx 10 \times \textrm{params}\;\textrm{bytes}
\end{aligned}
$$

With V = 128,000 and D = 4096 for Llama 3 8B, the cost is:

$$
\textrm{total memory} \approx 128{,}000 \times 4{,}096 \times 10 \;\textrm{bytes} \approx 5.2 \textrm{ GB}
$$

That’s over 5 GB of GPU memory torched just for the embedding layer and its optimizer states. A non-starter on most reasonable hardware.
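
If you prefer to sanity-check that arithmetic in code, here is a quick back-of-the-envelope script (a rough sketch; real allocators and gradient buffers add overhead on top of this):

# Rough training-memory estimate for the embedding layer of Llama 3 8B.
# Assumes bf16 weights (2 bytes/param) and two fp32 Adam moments (8 bytes/param).
V, D = 128_000, 4_096
params = V * D                          # ~524M parameters

bytes_per_param = 2 + 8                 # weights + optimizer states
full_retrain_gb = params * bytes_per_param / 1e9
print(f"full embedding retrain: {full_retrain_gb:.1f} GB")          # ~5.2 GB

patched = 4 * D                         # patch only 4 token rows
print(f"4-token patch: {patched * bytes_per_param / 1e3:.0f} kB")   # ~164 kB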

Enter Trainable Tokens

The fix, it turns out, is elegant. Instead of sledgehammering the entire dense (V, D) parameter matrix, we create a sparse patch, a surgical graft of size (n, D), where n is the small number of new tokens we need to train (often just a handful), and n \ll V.

During the forward pass, this tiny, trainable implant is effectively overlaid onto the otherwise frozen base embedding table. Gradients are calculated and back-propagated only through these n rows.

Conceptually, the data flow looks like this:

[Figure: forward pass with trainable tokens. The frozen embedding lookup runs as usual, and the trainable (n, D) patch overrides the rows for the selected token IDs.]

Because gradients only flow through these specified rows, the optimizer only needs to maintain states for kilobytes of parameters, not gigabytes, yielding massive memory savings.
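
To make the idea concrete, here is a minimal PyTorch sketch of the overlay. This is not PEFT's actual implementation (which is more efficient and handles saving and loading); it only illustrates the frozen-lookup-plus-trainable-patch mechanics:

import torch
import torch.nn as nn

class PatchedEmbedding(nn.Module):
    """Frozen base embedding with a small trainable overlay for selected token IDs."""
    def __init__(self, base_embedding: nn.Embedding, token_indices: list[int]):
        super().__init__()
        self.base = base_embedding
        self.base.weight.requires_grad_(False)                 # backbone stays frozen
        self.register_buffer("token_ids", torch.tensor(token_indices))
        # (n, D) trainable patch, initialized from the current embedding rows
        self.patch = nn.Parameter(self.base.weight[self.token_ids].detach().clone())

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        out = self.base(input_ids)                              # frozen lookup, shape (..., D)
        for i, tok_id in enumerate(self.token_ids):             # overlay the trainable rows
            mask = (input_ids == tok_id).unsqueeze(-1)
            out = torch.where(mask, self.patch[i], out)
        return out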

Two Ways to Activate the Feature in PEFT

The Hugging Face Parameter-Efficient Fine-Tuning (PEFT) library gives us two clean entry points for this.

Sparse embeddings only

If your only goal is to train new token embeddings without applying LoRA to other layers, TrainableTokensConfig is the most direct path. It targets specific modules (like embed_tokens) and updates only the specified token_indices.

from peft import TrainableTokensConfig, get_peft_model

# Example: Llama 3 chat templates often add new special tokens.
# Assuming token IDs 128006-128009 are the new chat-template tokens.
peft_cfg = TrainableTokensConfig(
    target_modules=['embed_tokens'],  # patch only the input embedding layer
    token_indices=[128006, 128007, 128008, 128009]
)
# base_model is your pre-trained Hugging Face model
# model = get_peft_model(base_model, peft_cfg)
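
After wrapping the model with get_peft_model, a quick sanity check confirms that only the patched rows are trainable. Assuming base_model is a Llama 3 8B you have already loaded, the trainable count should land in the tens of thousands, not millions:

model = get_peft_model(base_model, peft_cfg)
model.print_trainable_parameters()
# Expect roughly 16,384 trainable params (4 tokens x 4096 dims) out of ~8B total.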

Sparse embeddings + LoRA adapters

The more common scenario, and the one I typically use, is to bolt this onto a standard LoRA setup. The LoraConfig in PEFT lets you specify trainable_token_indices directly. You can choose to update these tokens in the input embedding layer (embed_tokens) and optionally in the output LM head (lm_head).

from peft import LoraConfig, get_peft_model

# Continuing with the Llama 3 chat token example
peft_cfg = LoraConfig(
    r=16,
    lora_alpha=16,
    lora_dropout=0.05,
    target_modules=['k_proj', 'q_proj', 'v_proj', 'o_proj',
                    'gate_proj', 'down_proj', 'up_proj'], # Standard LoRA targets
    trainable_token_indices={
        'embed_tokens': [128006, 128007, 128008, 128009], # Train these in the input embeddings
        'lm_head': [128006, 128007, 128008, 128009]      # Optionally, train them in the output LM head
    }
)
# base_model is your pre-trained Hugging Face model
# model = get_peft_model(base_model, peft_cfg)
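
A nice side effect of both setups: the saved adapter contains only the LoRA matrices plus the patched token rows, so checkpoints stay orders of magnitude smaller than the full backbone. A sketch of the round trip, assuming base_model is already loaded:

from peft import PeftModel

model = get_peft_model(base_model, peft_cfg)

# ... training loop ...

# Saves only the LoRA matrices and the patched token rows, not the backbone.
model.save_pretrained("llama3-chat-tokens-adapter")

# Later, reload the adapter on top of the same base model:
# restored = PeftModel.from_pretrained(base_model, "llama3-chat-tokens-adapter")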

A Walk-Through in Code

Let’s ground this in a practical code structure I use. This function sets up the model and tokenizer, then applies the PEFT configuration for LoRA with trainable tokens.

import os
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset
from peft import LoraConfig, get_peft_model
from trl import SFTTrainer

# A simplified function structure
def fine_tune_with_trainable_tokens(
        model_name: str,
        dataset_name: str = "HuggingFaceH4/ultrachat_200k",
        new_chat_tokens: list[int] | None = None, # e.g., [128006, 128007, 128008, 128009]
        train_lm_head_for_new_tokens: bool = True,
        # ... other training parameters
):
    # Load model and tokenizer
    model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto")
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    # Example: Add a padding token if not present (common for fine-tuning)
    # Ensure this new pad token ID is NOT one of your trainable_tokens unless intended.
    if tokenizer.pad_token is None:
        tokenizer.add_special_tokens({'pad_token': '<|pad|>'}) # Or use an existing unused special token
        model.resize_token_embeddings(len(tokenizer))
        # Note: If you add tokens this way, their embeddings are new and might also need training.
        # For this article's focus, we assume new_chat_tokens are already in the tokenizer's vocab
        # but need their embeddings specifically trained.


    # Load and prepare dataset
    # This is a placeholder for actual dataset loading and preprocessing
    dataset = load_dataset(dataset_name, split="train_sft[:1%]") # Using a small slice for example
    
    # Define instruct_template based on your model and task
    def instruct_template(messages):
        # Simplified example: concat messages. Real templates are more complex.
        # Ensure your new_chat_tokens are used in this template if they are chat delimiters.
        return "".join(msg['content'] for msg in messages)

    processed_dataset = dataset.map(
        lambda row: {"text": instruct_template(row["messages"])},
        num_proc=os.cpu_count()
    )

    # PEFT Configuration
    trainable_token_config = {}
    if new_chat_tokens:
        trainable_token_config['embed_tokens'] = new_chat_tokens
        if train_lm_head_for_new_tokens:
            trainable_token_config['lm_head'] = new_chat_tokens
    
    peft_config = LoraConfig(
        r=16,
        lora_alpha=16,
        lora_dropout=0.05,
        target_modules=['k_proj','q_proj','v_proj','o_proj', 'gate_proj','down_proj','up_proj'],
        trainable_token_indices=trainable_token_config if trainable_token_config else None,
        modules_to_save=None  # deliberately unset; ['embed_tokens', 'lm_head'] here would retrain the full matrices
    )

    # Initialize Trainer
    # Note: argument names vary across trl versions; in recent releases,
    # dataset_text_field and max_seq_length are set on SFTConfig (passed via args=).
    trainer = SFTTrainer(
        model=model,
        tokenizer=tokenizer,
        train_dataset=processed_dataset,
        peft_config=peft_config,
        dataset_text_field="text",
        max_seq_length=1024,
        # ... other SFTTrainer arguments (TrainingArguments, etc.)
    )

    # Start training (uncomment to run)
    # trainer.train()
    print("PEFT config ready for training.")
    return trainer

# Example call (load a tokenizer first to look up the IDs of your special tokens):
# tok = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B")  # replace with your model
# fine_tune_with_trainable_tokens(
#     model_name="meta-llama/Meta-Llama-3-8B",
#     new_chat_tokens=tok.convert_tokens_to_ids(["<|user|>", "<|assistant|>"])  # example tokens
# )

The entire mechanism hinges on trainable_token_indices in the LoraConfig. That’s the lever you pull. If you were to use modules_to_save instead, you’d be back to retraining the entire embedding layer, which is precisely the memory-hogging operation we’re trying to avoid.
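
For contrast, the memory-hogging alternative looks roughly like this. With modules_to_save, PEFT clones the listed modules in full and trains every row, which is exactly the multi-gigabyte bill computed earlier:

from peft import LoraConfig

# Full-retrain alternative (not recommended here): the entire embed_tokens and
# lm_head matrices become trainable copies, each with hundreds of millions of parameters.
full_retrain_cfg = LoraConfig(
    r=16,
    lora_alpha=16,
    target_modules=['k_proj', 'q_proj', 'v_proj', 'o_proj',
                    'gate_proj', 'down_proj', 'up_proj'],
    modules_to_save=['embed_tokens', 'lm_head'],
)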

Experiments on UltraChat-10k

To move beyond theory, I ran this through its paces, fine-tuning Llama 3 models on a 10,000-sample subset of the UltraChat dataset. I focused on updating four special tokens for a chat template. The results for the 8B parameter model were telling:

| Configuration | Extra trainable parameters | Peak VRAM (Llama 3 8B) | Memory Δ vs. LoRA baseline | Validation loss |
|---|---|---|---|---|
| LoRA, frozen embeddings (baseline) | 0 | 23.8 GB | baseline | 1.65 |
| Trainable tokens (4 tokens in embed_tokens only) | +16.3k (4 × 4096) | 23.9 GB | +0.1 GB | 1.78 |
| Trainable tokens + LM head (4 tokens in embed_tokens and lm_head) | +32.7k (2 × 4 × 4096) | 25.9 GB | +2.1 GB | 1.60 |
| Full embedding + LM head retrain (LoRA + modules_to_save) | +524 M | 30.9 GB | +7.1 GB | 1.63 |

Note: Parameter counts are approximate for the embedding/LM head parts. VRAM measured during training with bfloat16, batch size 1, gradient accumulation 32.

Observations from my experiments:

  • The Cheap Path Has a Toll: Training only the input embedding vectors is seductively memory-efficient, but my results show a performance hit (higher validation loss). It appears the model struggles to propagate the meaning of these new tokens all the way through if the LM head remains unaware. It learns a new word but doesn’t know how to say it.
  • The Sweet Spot: The real magic happens when you train the LM head in tandem. Allowing the LM head’s weights for these specific tokens to train as well delivered the best validation loss, outperforming both the baseline and the full-retrain behemoth.
  • The Payoff: The VRAM savings are substantial. Even with the ~2 GB penalty for patching the LM head (the full matrix is temporarily materialized for the optimizer), this approach still saves roughly 5 GB compared to a full retrain. The trainable parameter count itself is minuscule.

Training Speed & Gotchas

Here’s the counter-intuitive part: fewer parameters don’t necessarily mean faster epochs. The process of overlaying sparse patches involves scatter-gather operations that aren’t free. When training the LM head for specific tokens, PEFT may need to materialize the full dense head matrix before the optimizer step. In my experience, this can add 5-10% to your step times compared to standard LoRA with frozen embeddings. A small price to pay for the memory savings.

Important Tip: My non-negotiable advice here is to use differential learning rates. Keep the LoRA adapter learning rate standard (e.g., 1 × 10⁻⁴), but apply a more conservative rate for the embeddings and LM head rows (e.g., 1 × 10⁻⁵ to 3 × 10⁻⁵). This prevents the new token embeddings from going rogue and destabilizing training. New, near-zero vectors are particularly sensitive to high learning rates and can easily trigger NaN loss.
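
PEFT won't set this up for you automatically. One way to get differential learning rates is to build the optimizer yourself with two parameter groups and hand it to the Trainer. A sketch, assuming model is the PEFT-wrapped model from the walk-through (in this setup, every trainable parameter that isn't a LoRA matrix is a patched embedding or LM-head row):

import torch

lora_params, token_params = [], []
for name, param in model.named_parameters():
    if not param.requires_grad:
        continue
    (lora_params if "lora_" in name else token_params).append(param)

optimizer = torch.optim.AdamW([
    {"params": lora_params, "lr": 1e-4},   # standard LoRA learning rate
    {"params": token_params, "lr": 2e-5},  # conservative rate for new token rows
])

# With Hugging Face Trainer / SFTTrainer, pass it in via optimizers=(optimizer, None).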

When (Not) to Use Trainable Tokens

This isn’t a silver bullet. It’s a scalpel, and knowing when to use it is key.

| Use case | My recommendation |
|---|---|
| Adding chat delimiters or a few <reserved_special_token_*> tokens | Ideal. A small patch (2-10 tokens) is highly effective and memory-efficient. |
| Introducing domain-specific jargon (e.g., clinical, legal) where coverage is poor | Yes, especially when the number of new essential terms is modest (e.g., ≲ 100). |
| Fine-tuning code models that require new operators or keywords | Strongly recommended. This can be combined effectively with LoRA on other layers. |
| Adding a massive amount of new vocabulary (e.g., >1,000 tokens for extensive slang or new language support) | Maybe not. For huge vocabulary expansions, a full embedding retrain or another round of pre-training on a relevant corpus is likely more appropriate. |

Putting It All Together

For my money, LoRA with trainable tokens is the optimal trade-off. It keeps my GPU memory footprint lean, almost on par with classic LoRA, while giving me the power to teach the model essential new vocabulary. The option to also train the corresponding LM head rows, with a simple dictionary entry, is an easy lever to pull for a significant quality boost.

For practitioners fine-tuning LLMs on commodity hardware, the playbook is simple:

  1. Identify the exact token IDs you need to make trainable via tokenizer.convert_tokens_to_ids() (see the snippet after this list).
  2. Decide if updating the LM head will benefit your task. For most instruction-tuning and chat applications, the answer is yes.
  3. Update your LoraConfig by setting trainable_token_indices for embed_tokens and lm_head.
  4. Run your training job, confident that your memory usage won’t explode.
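
For step 1, a minimal lookup sketch (the token strings and model name here are illustrative; use whatever special tokens your chat template actually defines):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
special_tokens = ["<|start_header_id|>", "<|end_header_id|>", "<|eot_id|>"]

token_ids = tokenizer.convert_tokens_to_ids(special_tokens)
print(dict(zip(special_tokens, token_ids)))

# Sanity check: every token should resolve to a real ID, not the unknown-token ID.
assert tokenizer.unk_token_id not in token_ids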

Conclusion

As we drag these models out of the research lab and into specialized, messy reality, the ability to efficiently adapt their vocabulary becomes non-negotiable. By selectively training only the necessary embedding rows for new tokens, we overlay sparse, trainable slices onto a frozen backbone. This approach reclaims both GPU memory and fine-grained control over the model’s lexicon, often without a discernible performance hit, and sometimes, as my tests show, with a clear improvement.

I’ve found this technique indispensable for rapid iteration on chat systems, specialized instruction-following models, and code assistants. Given its efficiency and the ease of implementation with libraries like PEFT, trainable tokens should be a standard tool in every LLM fine-tuner’s kit. It’s a small trick, but it unlocks a surprising amount of flexibility. Go build.
