RAG (Retrieval-Augmented Generation) is a technique that combines two key components:
Retrieval: First, it searches a knowledge base (documents, databases, etc.) for information relevant to a given query. This typically involves:
- Converting text into embeddings (numerical vectors that represent meaning)
- Finding similar content using a similarity measure (such as cosine similarity)
- Selecting the most relevant pieces of information
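To make the retrieval step concrete, here is a minimal sketch using the same libraries as the full implementation below (the sample texts are invented for illustration):

```python
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity

# Encode a tiny "knowledge base" and a query into embedding vectors
encoder = SentenceTransformer('paraphrase-MiniLM-L6-v2')
docs = ["The capital of France is Paris.", "The largest planet is Jupiter."]
doc_embeddings = encoder.encode(docs)
query_embedding = encoder.encode(["What is the capital of France?"])

# Rank documents by cosine similarity and pick the best match
scores = cosine_similarity(query_embedding, doc_embeddings)[0]
print(docs[scores.argmax()])  # -> "The capital of France is Paris."
```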
Generation: Then, it uses a language model (such as T5 in our code) to generate a response by:
- Combining the retrieved information with the original question
- Creating a natural-language response based on this context
In the code:
- SentenceTransformer handles the retrieval part by creating embeddings
- The T5 model handles the generation part by creating answers
Benefits of RAG:
- More accurate responses, because they are grounded in specific knowledge
- Reduced hallucinations compared to pure LLM responses
- Access to up-to-date or domain-specific information
- More controllable and transparent than pure generation
System Architecture Overview
The implementation consists of a SimpleQASystem class that orchestrates two main components:
- A semantic search system using Sentence Transformers
- An answer generation system using T5
You can download the latest version of the source code here: github.com/alexander-u…
System diagram
RAG Project Setup Guide
This guide will help you set up a Retrieval-Augmented Generation (RAG) project on macOS and Windows.
Prerequisites
For macOS:
Install Homebrew (if not already installed):
/bin/bash -c "$(curl -fsSL https://raw.githubusercontent.com/Homebrew/install/HEAD/install.sh)"
Install Python 3.8+ using Homebrew:
brew install python@3.10
For Windows:
Download and install Python 3.8+ from python.org
Make sure to check "Add Python to PATH" during installation
Project Setup
Step 1: Create the project directory
macOS:
mkdir RAG_project
cd RAG_project
Windows:
mkdir RAG_project
cd RAG_project
Step 2: Set up a virtual environment
macOS:
python3 -m venv venv
source venv/bin/activate
Windows:
python -m venv venv
venv\Scripts\activate
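One step the guide leaves implicit is installing the Python dependencies. Based on the imports in the source code below, a plausible install command inside the activated virtual environment is (package names are the standard PyPI ones; sentencepiece is required by T5Tokenizer):

```bash
pip install torch transformers sentencepiece sentence-transformers scikit-learn numpy
```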
Core Components
1. Initialization
def __init__(self):
self.model_name = 't5-small'
self.tokenizer = T5Tokenizer.from_pretrained(self.model_name)
self.model = T5ForConditionalGeneration.from_pretrained(self.model_name)
self.encoder = SentenceTransformer('paraphrase-MiniLM-L6-v2')
The system is initialized with two main models:
- T5-small: a smaller version of the T5 model, used for generating answers
- paraphrase-MiniLM-L6-v2: a sentence-transformer model used to encode text into meaningful vectors
2. Dataset Preparation
def prepare_dataset(self, data: List[Dict[str, str]]):
self.answers = [item['answer'] for item in data]
self.answer_embeddings = []
for answer in self.answers:
embedding = self.encoder.encode(answer, convert_to_tensor=True)
self.answer_embeddings.append(embedding)
The dataset preparation stage:
- Extracts the answers from the input data
- Creates an embedding vector for each answer using the sentence transformer (a batched alternative is sketched after this list)
- Stores the answers together with their embeddings for fast retrieval
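A small design note: the loop above encodes answers one at a time; SentenceTransformer.encode also accepts a list and batches internally, which is usually faster. A minimal alternative sketch with the same stored result (assuming the same attributes) would be:

```python
def prepare_dataset(self, data):
    self.answers = [item['answer'] for item in data]
    # One batched encode call; iterating the stacked tensor yields per-answer rows
    self.answer_embeddings = list(self.encoder.encode(self.answers, convert_to_tensor=True))
```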
How the System Works
1. Question Processing
When a user submits a question, the system follows these steps:
Embedding Generation: the question is converted into a vector representation using the same sentence-transformer model used for the answers.
Semantic Search: the system finds the most relevant stored answer by:
- Computing the cosine similarity between the question embedding and all answer embeddings (illustrated after this list)
- Selecting the answer with the highest similarity score
Context Formation: the selected answer becomes the context from which T5 generates the final response.
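Cosine similarity measures the angle between two vectors: cos(a, b) = a·b / (‖a‖ ‖b‖). A tiny self-contained illustration with numpy (the 3-dimensional toy vectors are purely for intuition):

```python
import numpy as np

a = np.array([1.0, 2.0, 0.0])   # toy "question" vector
b = np.array([2.0, 4.0, 0.0])   # toy "answer" vector, same direction as a
c = np.array([0.0, 0.0, 1.0])   # toy "answer" vector, orthogonal to a

def cos_sim(u, v):
    # dot product divided by the product of the vector norms
    return np.dot(u, v) / (np.linalg.norm(u) * np.linalg.norm(v))

print(cos_sim(a, b))  # 1.0 -> identical direction, most similar
print(cos_sim(a, c))  # 0.0 -> orthogonal, unrelated
```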
2. Answer Generation
def get_answer(self, question: str) -> str:
    # ... semantic search logic ...
    input_text = f"Given the context, what is the answer to the question: {question} Context: {context}"
    input_ids = self.tokenizer(input_text, max_length=512, truncation=True,
                               padding='max_length', return_tensors='pt').input_ids
    outputs = self.model.generate(input_ids, max_length=50, num_beams=4,
                                  early_stopping=True, no_repeat_ngram_size=2)
The answer generation process:
- Combines the question and the context into a prompt for T5
- Tokenizes the input text with a maximum length of 512 tokens
- Generates an answer using beam search with the following parameters (alternatives are sketched after this list):
- max_length=50: limits the answer length
- num_beams=4: uses beam search with 4 beams
- early_stopping=True: stops generation once all beams reach the end token
- no_repeat_ngram_size=2: prevents repeated bigrams
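These decoding values are a reasonable default rather than the only choice. As the conclusion suggests, they are worth experimenting with; a variant of the same generate call with different (purely illustrative) settings might look like:

```python
# Illustrative alternative decoding settings inside get_answer:
outputs = self.model.generate(
    input_ids,
    max_length=80,           # allow longer answers
    num_beams=8,             # wider beam search: slower, often more coherent
    length_penalty=1.2,      # values > 1.0 mildly favor longer sequences
    early_stopping=True,
    no_repeat_ngram_size=3   # only block repeated trigrams
)
```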
3. Answer Cleaning
def clean_answer(self, answer: str) -> str:
words = answer.split()
cleaned_words = []
for i, word in enumerate(words):
if i == 0 or word.lower() != words[i-1].lower():
cleaned_words.append(word)
cleaned = ' '.join(cleaned_words)
return cleaned[0].upper() + cleaned[1:] if cleaned else cleaned
- Removes duplicated consecutive words (case-insensitive; see the worked example below)
- Capitalizes the first letter of the answer
- Removes extra whitespace
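A quick worked example of the cleaning logic on a hypothetical raw generation (the input string is invented for illustration):

```python
answer = "paris paris is the capital capital of France"

# Same logic as clean_answer, shown standalone:
words = answer.split()
cleaned_words = [w for i, w in enumerate(words)
                 if i == 0 or w.lower() != words[i - 1].lower()]
cleaned = ' '.join(cleaned_words)
print(cleaned[0].upper() + cleaned[1:])  # -> "Paris is the capital of France"
```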
Complete Source Code
You can download the latest version of the source code here: github.com/alexander-u…
import os
# Set tokenizers parallelism before importing libraries
os.environ["TOKENIZERS_PARALLELISM"] = "false"
import torch
from transformers import T5Tokenizer, T5ForConditionalGeneration
from typing import List, Dict
import numpy as np
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
class SimpleQASystem:
def __init__(self):
"""Initialize QA system using T5"""
try:
# Use T5 for answer generation
self.model_name = 't5-small'
self.tokenizer = T5Tokenizer.from_pretrained(self.model_name, legacy=False)
self.model = T5ForConditionalGeneration.from_pretrained(self.model_name)
# Move model to CPU explicitly to avoid memory issues
self.device = "cpu"
self.model = self.model.to(self.device)
# Initialize storage
self.answers = []
self.answer_embeddings = None
self.encoder = SentenceTransformer('paraphrase-MiniLM-L6-v2')
print("System initialized successfully")
except Exception as e:
print(f"Initialization error: {e}")
raise
def prepare_dataset(self, data: List[Dict[str, str]]):
"""Prepare the dataset by storing answers and their embeddings"""
try:
# Store answers
self.answers = [item['answer'] for item in data]
# Encode answers using SentenceTransformer
self.answer_embeddings = []
for answer in self.answers:
embedding = self.encoder.encode(answer, convert_to_tensor=True)
self.answer_embeddings.append(embedding)
print(f"Prepared {len(self.answers)} answers")
except Exception as e:
print(f"Dataset preparation error: {e}")
raise
def clean_answer(self, answer: str) -> str:
"""Clean up generated answer by removing duplicates and extra whitespace"""
words = answer.split()
cleaned_words = []
for i, word in enumerate(words):
if i == 0 or word.lower() != words[i-1].lower():
cleaned_words.append(word)
cleaned = ' '.join(cleaned_words)
return cleaned[0].upper() + cleaned[1:] if cleaned else cleaned
def get_answer(self, question: str) -> str:
"""Get answer using semantic search and T5 generation"""
try:
if not self.answers or self.answer_embeddings is None:
raise ValueError("Dataset not prepared. Call prepare_dataset first.")
# Encode question using SentenceTransformer
question_embedding = self.encoder.encode(
question,
convert_to_tensor=True,
show_progress_bar=False
)
# Move the question embedding to CPU (if not already)
question_embedding = question_embedding.cpu()
# Find most similar answer using cosine similarity
similarities = cosine_similarity(
question_embedding.numpy().reshape(1, -1), # Use .numpy() for numpy compatibility
np.array([embedding.cpu().numpy() for embedding in self.answer_embeddings]) # Move answer embeddings to CPU
)[0]
best_idx = np.argmax(similarities)
context = self.answers[best_idx]
# Generate the input text for the T5 model
input_text = f"Given the context, what is the answer to the question: {question} Context: {context}"
print(input_text)
# Tokenize input text
input_ids = self.tokenizer(
input_text,
max_length=512,
truncation=True,
padding='max_length',
return_tensors='pt'
).input_ids.to(self.device)
# Generate answer with limited max_length
outputs = self.model.generate(
input_ids,
max_length=50, # Increase length to handle more detailed answers
num_beams=4,
early_stopping=True,
no_repeat_ngram_size=2
)
# Decode the generated answer
answer = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
# Print the raw generated answer for debugging
print(f"Generated answer before cleaning: {answer}")
# Clean up the answer
cleaned_answer = self.clean_answer(answer)
return cleaned_answer
except Exception as e:
print(f"Error generating answer: {e}")
return f"Error: {str(e)}"
def main():
"""Main function with sample usage"""
try:
# Sample data
data = [
{"question": "What is the capital of France?", "answer": "The capital of France is Paris."},
{"question": "What is the largest planet?", "answer": "The largest planet is Jupiter."},
{"question": "Who wrote '1984'?", "answer": "George Orwell wrote '1984'."}
]
# Initialize system
print("Initializing QA system...")
qa_system = SimpleQASystem()
# Prepare dataset
print("Preparing dataset...")
qa_system.prepare_dataset(data)
# Start interactive Q&A session
while True:
# Prompt the user for a question
test_question = input("\nPlease enter your question (or 'exit' to quit): ")
if test_question.lower() == 'exit':
print("Exiting the program.")
break
# Get and print the answer
print(f"\nQuestion: {test_question}")
answer = qa_system.get_answer(test_question)
print(f"Answer: {answer}")
except Exception as e:
print(f"Error in main: {e}")
if __name__ == "__main__":
    main()

Performance Considerations
Memory management:
- The system explicitly runs on the CPU to avoid memory issues
- Embeddings are converted to CPU tensors when needed
- Input length is capped at 512 tokens
Error handling:
- Comprehensive try-except blocks throughout the code
- Meaningful error messages for debugging
- Validation checks for uninitialized components
Usage Example
# Initialize system
qa_system = SimpleQASystem()
# Prepare sample data
data = [
{"question": "What is the capital of France?", "answer": "The capital of France is Paris."},
{"question": "What is the largest planet?", "answer": "The largest planet is Jupiter."}
]
# Prepare dataset
qa_system.prepare_dataset(data)
# Get answer
answer = qa_system.get_answer("What is the capital of France?")
Running It in the Terminal
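Assuming the full script above is saved as main.py (the original does not name the file), starting the interactive session looks like:

```bash
python3 main.py   # macOS/Linux; on Windows use: python main.py
```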
Limitations and Possible Improvements
Scalability:
- The current implementation keeps all embeddings in memory
- Could be improved with a vector database for large-scale applications, as sketched below
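As one possible direction (not part of the original implementation), a vector index such as FAISS could replace the linear cosine-similarity scan. A minimal sketch, assuming faiss-cpu is installed and that embeddings are L2-normalized so inner product equals cosine similarity:

```python
import faiss
import numpy as np

# Build an index over the stored answer embeddings
embeddings = np.array([e.cpu().numpy() for e in qa_system.answer_embeddings], dtype='float32')
faiss.normalize_L2(embeddings)                  # normalize so inner product == cosine
index = faiss.IndexFlatIP(embeddings.shape[1])
index.add(embeddings)

# Encode a query, normalize it the same way, and fetch the best match
query = qa_system.encoder.encode(["What is the capital of France?"]).astype('float32')
faiss.normalize_L2(query)
scores, ids = index.search(query, 1)
print(qa_system.answers[ids[0][0]], float(scores[0][0]))
```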
Answer quality:
- Heavily dependent on the quality of the provided answer dataset
- Limited by the context window of T5-small
- Could benefit from answer validation or confidence scoring, as sketched below
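One lightweight form of confidence scoring, shown here as an illustrative extension rather than part of the original code, is to answer only when the best retrieval similarity clears a threshold (the 0.5 cutoff and the function name are assumptions):

```python
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity

def get_confident_answer(qa_system, question: str, threshold: float = 0.5) -> str:
    """Refuse to answer when retrieval similarity is below a (hypothetical) threshold."""
    q = qa_system.encoder.encode(question, convert_to_tensor=True).cpu().numpy().reshape(1, -1)
    embs = np.array([e.cpu().numpy() for e in qa_system.answer_embeddings])
    if cosine_similarity(q, embs)[0].max() < threshold:
        return "I don't have enough information to answer that question."
    return qa_system.get_answer(question)
```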
Performance:
- CPU-only execution can be slow for large applications
- Could be optimized with batch processing
- Caching could be implemented for common questions, as sketched below
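For the caching idea, Python's functools.lru_cache gives a near-zero-effort starting point; this hypothetical wrapper caches repeated questions verbatim (it will not recognize paraphrases):

```python
from functools import lru_cache

@lru_cache(maxsize=256)
def cached_answer(question: str) -> str:
    # Identical repeat questions hit the cache instead of rerunning retrieval + T5
    return qa_system.get_answer(question)
```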
Conclusion
This implementation provides a solid foundation for a question-answering system, combining the strengths of semantic search and transformer-based text generation. Feel free to experiment with the model parameters (such as max_length, num_beams, early_stopping, no_repeat_ngram_size, etc.) to find settings that yield more coherent and stable answers. While there is room for improvement, the current implementation strikes a good balance between complexity and capability, making it suitable for educational purposes and small to medium-scale applications.