RAG-Tuning: Towards Generalized Retrieval-Augmented Generation with Data Synthesis
Paper: will be released soon
Abstract: Retrieval-Augmented Generation (RAG) extends the generative capabilities of large language models by integrating external document retrieval. Most existing studies train and evaluate RAG models predominantly on Knowledge-Intensive Question Answering (KIQA) scenarios. This narrow focus often leads to overfitting on KIQA tasks and poor performance in other RAG scenarios, i.e., limited generalization. To address this challenge, we introduce RAG-Tuning, a novel approach that improves the generalization of RAG models through data synthesis. We use GPT-4 to generate, from scratch, large-scale training data covering a wide range of RAG scenarios with diverse query topics and formats. To better mimic real-world RAG conditions, we retrieve the training passages with different dense and sparse retrievers for different queries and inject random noise into the retrieved passages. Unlike previous RAG studies that evaluate primarily on KIQA tasks, we emphasize evaluating RAG-Tuning on open RAG tasks in addition to traditional in-domain and out-of-domain KIQA benchmarks. Experimental results show significant improvements across all these benchmarks. Additionally, RAG-Tuning generalizes better when paired with different retrievers or when the retrieved documents are noisy.
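As a rough illustration of the passage construction described in the abstract, the per-query retriever mixing and noise injection might look like the sketch below. All helper names and parameters here are hypothetical (the paper describes the actual pipeline); one plausible reading of "different retrievers for different queries" is a random per-query retriever choice, which is what this sketch assumes.

import random

def build_training_passages(query, dense_retriever, sparse_retriever,
                            noise_pool, k=5, noise_ratio=0.2):
    """Sketch: retrieve passages for a training query with a randomly
    chosen retriever, then inject random noise passages.

    All names and parameters here are illustrative, not from the paper.
    dense_retriever / sparse_retriever: callables returning k ranked passages.
    noise_pool: a list of unrelated passages to sample noise from.
    """
    # Pick a dense or sparse retriever per query so the training data is
    # not tied to a single retriever's ranking behavior.
    retriever = random.choice([dense_retriever, sparse_retriever])
    passages = retriever(query, k)

    # Replace a fraction of the retrieved passages with unrelated "noise"
    # passages to mimic imperfect real-world retrieval.
    n_noise = max(1, int(len(passages) * noise_ratio))
    for pos, noisy in zip(random.sample(range(len(passages)), n_noise),
                          random.sample(noise_pool, n_noise)):
        passages[pos] = noisy
    return passages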
Models: We release the RT-KS and RT-S models in this Huggingface Repo. Please use the following code to load the models (remember to download the original Mistral-7B model first):
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel  # , PeftConfig

def load_fine_tuned_model(model_name, adapters_name=None):
    """
    model_name: original Mistral-7B path
    adapters_name: RT-KS or RT-S path
    """
    def set_tokenizer(tokenizer):
        # Mistral-7B special tokens: <s> = 1, </s> = 2.
        tokenizer.bos_token_id = 1
        tokenizer.eos_token_id = 2
        # Mistral has no dedicated pad token; reuse EOS for padding.
        tokenizer.pad_token_id = tokenizer.eos_token_id
        # Left-side padding/truncation keeps the end of the prompt intact
        # for batched generation.
        tokenizer.padding_side = "left"
        tokenizer.truncation_side = "left"
        return tokenizer

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    tokenizer = set_tokenizer(tokenizer)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.bfloat16,
        device_map={"": 0},  # place the whole model on GPU 0
    )
    if adapters_name:
        # Load the LoRA adapter (RT-KS or RT-S) and merge it into the base weights.
        model = PeftModel.from_pretrained(model, adapters_name)
        model = model.merge_and_unload()
    return model, tokenizer
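For example, to load RT-KS on top of a local Mistral-7B checkpoint and generate an answer (the paths below are placeholders; substitute your own local copies):

model, tokenizer = load_fine_tuned_model(
    "path/to/Mistral-7B",  # original base model (placeholder path)
    "path/to/RT-KS",       # or "path/to/RT-S" (placeholder path)
)
prompt = "..."  # your RAG prompt containing the retrieved passages
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, max_new_tokens=256)
# Decode only the newly generated tokens.
print(tokenizer.decode(outputs[0][inputs["input_ids"].shape[1]:],
                       skip_special_tokens=True))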