LoReFT and pyreft for surgical fine-tuning

Olivier (omills) · Posted on April 7, 2024

Here’s my attempt to understand and summarise the paper "ReFT: Representation Finetuning for Language Models":

https://arxiv.org/abs/2404.03592
https://github.com/stanfordnlp/pyreft

The paper proposes Representation Finetuning (ReFT) as a more parameter-efficient alternative to weight-based PEFT methods like adapters and LoRA for adapting large language models to downstream tasks. Key ideas:

  • ReFT methods train task-specific interventions that modify hidden representations of a frozen base model, leveraging the insight from interpretability work that representations encode rich semantics.

  • The Low-rank Linear Subspace ReFT (LoReFT) variant uses 10-50x fewer parameters than leading PEFTs by intervening on representations in a learned low-dimensional subspace (see the sketch after this list).

  • LoReFT provides state-of-the-art efficiency-performance trade-offs on commonsense reasoning, arithmetic reasoning, instruction following, and natural language understanding tasks. When instruction-tuning Llama-2 7B, it nearly matches GPT-3.5 while adding only 0.004% extra parameters.

  • Editing representations may therefore be more powerful than modifying weights, as PEFTs do. That said, LoReFT struggled more on arithmetic reasoning, possibly because those tasks require long generated outputs.
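Concretely, LoReFT edits a hidden vector h as Φ(h) = h + Rᵀ(Wh + b − Rh), where R is a low-rank projection with orthonormal rows and W, b form a learned linear map: the edit swaps the representation's coordinates inside the r-dimensional subspace for the learned target Wh + b, and leaves everything orthogonal to the subspace untouched. Here is a minimal PyTorch sketch of that edit; it illustrates the math rather than pyreft's actual implementation, and the class and variable names are mine:

import torch
import torch.nn as nn

class LowRankSubspaceEdit(nn.Module):
    """Illustrative LoReFT-style edit: h <- h + R^T (W h + b - R h)."""

    def __init__(self, hidden_dim: int, rank: int):
        super().__init__()
        # R: (rank, hidden_dim), initialized with orthonormal rows.
        # pyreft keeps R orthonormal throughout training; this sketch
        # only initializes it that way.
        q = torch.linalg.qr(torch.randn(hidden_dim, rank)).Q
        self.R = nn.Parameter(q.T.contiguous())
        self.proj = nn.Linear(hidden_dim, rank)  # computes W h + b

    def forward(self, h: torch.Tensor) -> torch.Tensor:
        # Replace the subspace coordinates R h with the learned target
        # W h + b; directions outside the subspace pass through unchanged.
        return h + (self.proj(h) - h @ self.R.T) @ self.R

At rank 1, R and W each hold only hidden_dim parameters, which is why the trainable-parameter counts reported below are so small.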

To make ReFT easy to use, the authors also introduce pyreft, a library that enables:

  • Fine-tuning any HuggingFace pretrained LM with ReFT
  • Configuring ReFT hyperparameters via config files
  • Easily sharing fine-tuned models on HuggingFace
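Jumping ahead a little on that last point: once the reft_model from the walkthrough below is trained, the pyreft README shows saving and sharing it roughly like this (the directory and repo names here are placeholders):

# Sketch based on the pyreft README; directory and repo names are placeholders.
reft_model.set_device("cpu")  # move interventions back to CPU before saving
reft_model.save(
    save_directory="./trained_loreft",
    save_to_hf_hub=True,
    hf_repo_name="your-username/llama-7b-loreft",
)

Only the intervention weights get saved, so the shared artifact is tiny compared to a full model checkpoint.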

Here is some pseudo-code, lightly adapted from the pyreft examples, showing basic usage of pyreft to fine-tune a Llama-7B model:

from transformers import AutoModelForCausalLM, TrainingArguments
from pyreft import (
    ReftConfig,
    LoreftIntervention,
    get_reft_model,
    ReftTrainerForCausalLM,
)

# Load the pretrained LM (kept frozen; only interventions are trained)
model = AutoModelForCausalLM.from_pretrained("llama-7b-hf")

# Configure a rank-1 LoReFT intervention on the block output of layer 15
reft_config = ReftConfig(
    representations={
        "layer": 15,
        "component": "block_output",
        "low_rank_dimension": 1,
        "intervention": LoreftIntervention(
            embed_dim=model.config.hidden_size, low_rank_dimension=1
        ),
    }
)

# Wrap the base model with the ReFT intervention
reft_model = get_reft_model(model, reft_config)

# Load training data (data-loading helper elided in this pseudo-code)
data_module = make_supervised_data_module(...)

# Configure training
trainer = ReftTrainerForCausalLM(
    model=reft_model,
    args=TrainingArguments(...),
    **data_module,
)

# Train the intervention
trainer.train()

This fine-tunes a Llama-7B model by training a rank-1 LoReFT intervention at layer 15, modifying only 0.00006% of the model parameters. Once trained, the reft_model can be used for inference on downstream tasks.
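At inference time, the wrapper needs to know which positions to intervene on. Following the pyreft README, generation looks roughly like this; it is a sketch that assumes a single last-prompt-position intervention and an already-loaded tokenizer:

# Sketch following the pyreft README: intervene on the last prompt token.
prompt = tokenizer("What is the capital of France?", return_tensors="pt").to(model.device)

last_position = prompt["input_ids"].shape[-1] - 1
_, response = reft_model.generate(
    prompt,
    unit_locations={"sources->base": (None, [[[last_position]]])},
    intervene_on_prompt=True,
    max_new_tokens=128,
)
print(tokenizer.decode(response[0], skip_special_tokens=True))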

By establishing ReFT as a promising new paradigm for LM adaptation, this work points to representation editing as a powerful lever for both model efficiency and interpretability. The pyreft library enables researchers and practitioners to easily experiment with and build on these ideas.
