第三章:实现 RAG 核心逻辑
3.1 引言
在前两章中,我们已经完成了 RAG 系统的基础工作:第一章介绍了 RAG 的概念和开发环境配置,第二章实现了知识库的构建、文档向量化以及基本检索功能。本章将进入 RAG 系统的核心部分,详细讲解如何整合检索和生成过程,构建一个完整的 RAG 流程。
本章的目标是帮助 Java 程序员实现以下内容:
- 使用
langchain
框架整合嵌入模型、向量数据库和生成模型。 - 实现 RAG 的完整流程:查询处理、检索、上下文增强和生成。
- 优化 RAG 系统的性能和准确性,处理常见问题。
- 为后续的 RESTful API 封装做好准备。
我们将通过详细的代码示例,带你完成从查询输入到生成回答的完整流程。本章假设你已经按照前两章的说明配置好开发环境,并成功运行了知识库向量化代码。
3.2 RAG 核心流程概述
RAG 的核心流程可以分为以下四个步骤:
- 查询处理:接收用户输入的查询(query),对其进行预处理(如清理、规范化)和向量化。
- 文档检索:使用向量数据库(如 FAISS)查找与查询最相关的文档片段。
- 上下文增强:将检索到的文档与查询组合,形成一个丰富的输入提示(prompt)。
- 生成回答:将增强后的提示输入到生成模型,生成最终的自然语言回答。
为了简化开发,我们将使用 langchain
框架,它提供了一套高层次的 API,能够无缝整合嵌入模型、向量数据库和生成模型。langchain
的模块化设计非常适合 RAG 系统,特别适合没有 AI 开发经验的开发者。
3.3 配置 LangChain 环境
3.3.1 安装 LangChain
确保已安装 langchain
及其相关依赖:
pip install langchain langchain-community langchain-huggingface
langchain
:核心框架,提供 RAG 流程的抽象。langchain-community
:包含社区贡献的工具,如 FAISS 集成。langchain-huggingface
:支持 Hugging Face 的嵌入和生成模型。
3.3.2 选择生成模型
生成模型是 RAG 系统的核心组件,负责生成最终的回答。以下是常见的选择:
- 开源模型:如 Mistral、LLaMA 或 Phi-3,适合本地部署。
- 云端模型:如 OpenAI 的 GPT-4 或 Anthropic 的 Claude,需要 API 密钥。
- 轻量模型:如
google/flan-t5-base
,适合资源受限的场景。
在本教程中,我们将使用 Hugging Face 的 google/flan-t5-base
模型,因为它:
- 体积小(约 250MB),适合本地运行。
- 支持生成任务,性能在小型模型中表现良好。
- 开源免费,无需额外费用。
安装依赖:
pip install transformers torch
测试生成模型:
以下是加载和测试 flan-t5-base
模型的代码:
from transformers import T5Tokenizer, T5ForConditionalGeneration
# 加载模型和分词器
tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-base")
model = T5ForConditionalGeneration.from_pretrained("google/flan-t5-base")
# 示例生成
input_text = "Answer the question: What is the capital of France?"
inputs = tokenizer(input_text, return_tensors="pt")
outputs = model.generate(**inputs, max_length=50)
answer = tokenizer.decode(outputs[0], skip_special_tokens=True)
print(f"Answer: {answer}")
输出:
Answer: The capital of France is Paris.
说明:
flan-t5-base
是一个基于 T5 的模型,适合问答和文本生成任务。- 如果你有 GPU,可以通过
model.to('cuda')
加速推理。
3.4 使用 LangChain 实现 RAG
3.4.1 初始化 RAG 组件
我们将使用 langchain
的以下组件:
- Embedding Model:
sentence-transformers/all-MiniLM-L6-v2
,用于查询和文档向量化。 - Vector Store:FAISS,存储文档嵌入。
- LLM:
google/flan-t5-base
,用于生成回答。 - Prompt Template:定义输入提示的格式。
以下是初始化这些组件的代码:
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain_huggingface import HuggingFacePipeline
from langchain.prompts import PromptTemplate
from transformers import pipeline
# 初始化嵌入模型
embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
# 初始化生成模型
text_generation_pipeline = pipeline(
"text2text-generation",
model="google/flan-t5-base",
tokenizer="google/flan-t5-base",
max_length=512,
device=0 if torch.cuda.is_available() else -1
)
llm = HuggingFacePipeline(pipeline=text_generation_pipeline)
# 初始化提示模板
prompt_template = """
You are a helpful assistant. Use the following context to answer the question.
Context:
{context}
Question: {question}
Answer:
"""
prompt = PromptTemplate(
input_variables=["context", "question"],
template=prompt_template
)
代码说明:
HuggingFaceEmbeddings
:封装sentence-transformers
的嵌入模型。HuggingFacePipeline
:将 Hugging Face 的生成模型适配为 LangChain 的 LLM 接口。PromptTemplate
:定义 RAG 的输入格式,包含上下文和问题。
3.4.2 加载知识库
假设你已经按照第二章的代码生成了 FAISS 索引和向量化文档,我们可以直接加载:
# 加载 FAISS 向量存储
vector_store = FAISS.load_local(
"rag_knowledge_base_index.faiss",
embeddings=embeddings,
allow_dangerous_deserialization=True
)
注意:
allow_dangerous_deserialization=True
是为了加载本地 FAISS 索引,生产环境中需要谨慎使用。- 如果你没有现成的索引,可以参考第二章重新生成。
3.4.3 构建 RAG 链
LangChain 提供了一个 RetrievalQA
链,用于整合检索和生成过程。以下是构建 RAG 链的代码:
from langchain.chains import RetrievalQA
# 创建 RAG 链
rag_chain = RetrievalQA.from_chain_type(
llm=llm,
chain_type="stuff",
retriever=vector_store.as_retriever(search_kwargs={"k": 3}),
chain_type_kwargs={"prompt": prompt}
)
代码说明:
RetrievalQA
:LangChain 的问答链,自动处理检索和生成。chain_type="stuff"
:将所有检索到的文档直接“塞入”提示中,适合小型文档集。search_kwargs={"k": 3}
:检索 top-3 相关文档。chain_type_kwargs
:指定自定义的提示模板。
3.4.4 测试 RAG 系统
现在,我们可以测试完整的 RAG 流程:
# 示例查询
query = "Which operating systems are supported by the product?"
response = rag_chain.invoke({"query": query})
print(f"Question: {query}")
print(f"Answer: {response['result']}")
示例输出:
Question: Which operating systems are supported by the product?
Answer: Our product supports Windows 10/11, macOS 12+, and Ubuntu 20.04+.
分析:
- RAG 系统首先检索包含操作系统信息的 FAQ 文档片段。
- 检索到的文档作为上下文,与查询一起输入到提示模板。
flan-t5-base
模型根据上下文生成准确的回答。
3.5 优化 RAG 系统
3.5.1 提示工程(Prompt Engineering)
提示模板对生成质量有很大影响。以下是优化提示模板的建议:
- 明确指令:在提示中清晰说明模型的角色和任务。
- 结构化上下文:将上下文和问题分开,避免混淆。
- 限制输出长度:在提示中指定回答的简洁性或详细程度。
改进后的提示模板:
prompt_template = """
You are a knowledgeable assistant. Use the provided context to answer the question accurately and concisely. If the context does not contain enough information, say so.
Context:
{context}
Question: {question}
Answer:
"""
prompt = PromptTemplate(
input_variables=["context", "question"],
template=prompt_template
)
重新构建 RAG 链:
rag_chain = RetrievalQA.from_chain_type(
llm=llm,
chain_type="stuff",
retriever=vector_store.as_retriever(search_kwargs={"k": 3}),
chain_type_kwargs={"prompt": prompt}
)
3.5.2 检索优化
检索质量直接影响 RAG 的表现。以下是优化检索的策略:
-
调整 top-k:
- 增大
k
(如 5 或 10)可以包含更多相关文档,但可能引入噪声。 - 减小
k
(如 1 或 2)可以提高精确度,但可能丢失上下文。
- 增大
-
使用更好的嵌入模型:
-
替换为
all-mpnet-base-v2
,它比all-MiniLM-L6-v2
更强大:embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")
-
-
FAISS 索引优化:
-
对于大型知识库,使用
IndexIVFFlat
索引:index = faiss.IndexIVFFlat(index, dimension, nlist=100) index.train(embedding_vectors)
-
3.5.3 生成模型优化
flan-t5-base
是一个轻量模型,可能在复杂任务中表现有限。以下是优化建议:
-
使用更大的模型:
-
替换为
google/flan-t5-large
或mistralai/Mixtral-8x7B
(需要更多内存和 GPU)。 -
示例:
text_generation_pipeline = pipeline( "text2text-generation", model="google/flan-t5-large", tokenizer="google/flan-t5-large", max_length=512 )
-
-
调整生成参数:
-
控制输出的多样性和长度:
text_generation_pipeline = pipeline( "text2text-generation", model="google/flan-t5-base", tokenizer="google/flan-t5-base", max_length=512, do_sample=True, temperature=0.7, top_p=0.9 )
-
-
使用云端模型:
-
如果本地资源有限,可以使用 OpenAI 的 GPT-4:
from langchain_openai import OpenAI llm = OpenAI(api_key="your-openai-api-key", model="gpt-4")
-
3.6 完整 RAG 实现
以下是将所有功能整合到一个完整脚本的代码,包含初始化、RAG 链构建和测试:
import faiss
import torch
from langchain_huggingface import HuggingFaceEmbeddings, HuggingFacePipeline
from langchain_community.vectorstores import FAISS
from langchain.chains import RetrievalQA
from langchain.prompts import PromptTemplate
from transformers import pipeline
class RAGSystem:
def __init__(self, index_path, embedding_model_name="sentence-transformers/all-MiniLM-L6-v2", llm_model_name="google/flan-t5-base"):
# 初始化嵌入模型
self.embeddings = HuggingFaceEmbeddings(model_name=embedding_model_name)
# 加载 FAISS 向量存储
self.vector_store = FAISS.load_local(
index_path,
embeddings=self.embeddings,
allow_dangerous_deserialization=True
)
# 初始化生成模型
self.text_generation_pipeline = pipeline(
"text2text-generation",
model=llm_model_name,
tokenizer=llm_model_name,
max_length=512,
device=0 if torch.cuda.is_available() else -1
)
self.llm = HuggingFacePipeline(pipeline=self.text_generation_pipeline)
# 初始化提示模板
self.prompt_template = """
You are a knowledgeable assistant. Use the provided context to answer the question accurately and concisely. If the context does not contain enough information, say so.
Context:
{context}
Question: {question}
Answer:
"""
self.prompt = PromptTemplate(
input_variables=["context", "question"],
template=self.prompt_template
)
# 创建 RAG 链
self.rag_chain = RetrievalQA.from_chain_type(
llm=self.llm,
chain_type="stuff",
retriever=self.vector_store.as_retriever(search_kwargs={"k": 3}),
chain_type_kwargs={"prompt": self.prompt}
)
def query(self, question):
"""处理用户查询"""
response = self.rag_chain.invoke({"query": question})
return {
"question": question,
"answer": response["result"],
"retrieved_docs": response.get("source_documents", [])
}
# 使用示例
if __name__ == "__main__":
# 初始化 RAG 系统
rag_system = RAGSystem(index_path="rag_knowledge_base_index.faiss")
# 测试查询
queries = [
"Which operating systems are supported by the product?",
"How to contact technical support?",
"What are the installation requirements?"
]
for query in queries:
result = rag_system.query(query)
print(f"Question: {result['question']}")
print(f"Answer: {result['answer']}")
print("Retrieved Documents:")
for doc in result['retrieved_docs']:
print(f"- {doc.page_content[:100]}...")
print("-" * 50)
代码说明:
RAGSystem
类封装了所有 RAG 功能,包括初始化、链构建和查询处理。query
方法返回查询结果和检索到的文档,便于调试。- 主程序测试了多个查询,展示了系统的通用性。
示例输出:
Question: Which operating systems are supported by the product?
Answer: Our product supports Windows 10/11, macOS 12+, and Ubuntu 20.04+.
Retrieved Documents:
- Q: 产品支持哪些操作系统? A: 我们的产品支持 Windows 10/11、macOS 12+ 和 Ubuntu 20.04+。...
- 本产品是一款高性能的智能设备,支持多种功能,包括实时数据处理和远程控制。...
...
--------------------------------------------------
Question: How to contact technical support?
Answer: You can contact technical support via email at support@company.com or by phone at 123-456-7890.
Retrieved Documents:
- Q: 如何联系技术支持? A: 您可以通过邮箱 support@company.com 或电话 123-456-7890 联系我们。...
...
--------------------------------------------------
3.7 常见问题与解决方案
在实现 RAG 核心逻辑时,Java 程序员可能遇到以下问题:
-
生成模型输出不准确:
- 问题:
flan-t5-base
可能生成不自然的回答或忽略上下文。 - 解决方案:使用更大的模型(如
flan-t5-large
)或优化提示模板,确保上下文清晰。
- 问题:
-
检索文档不相关:
-
问题:检索到的文档与查询无关。
-
解决方案:
- 检查嵌入模型的质量,考虑更换为
all-mpnet-base-v2
。 - 调整文档分割粒度,确保片段长度适中(第二章的
max_length
参数)。 - 增加
top_k
值,获取更多候选文档。
- 检查嵌入模型的质量,考虑更换为
-
-
性能瓶颈:
-
问题:大型知识库或复杂模型导致查询速度慢。
-
解决方案:
- 使用 FAISS 的
IndexIVFFlat
或IndexHNSW
索引。 - 在 GPU 上运行生成模型。
- 缓存频繁查询的结果。
- 使用 FAISS 的
-
-
提示模板过于复杂:
- 问题:复杂的提示可能导致模型生成不一致的回答。
- 解决方案:简化提示,明确任务目标,避免冗余指令。
3.8 本章总结
本章详细讲解了 RAG 系统的核心逻辑实现,包括查询处理、检索、上下文增强和生成。你现在应该掌握了以下内容:
- 使用
langchain
框架整合嵌入模型、向量数据库和生成模型。 - 实现完整的 RAG 流程,基于
flan-t5-base
和 FAISS。 - 优化 RAG 系统的提示模板、检索和生成性能。
- 常见问题的解决方案,如生成质量、检索相关性和性能瓶颈。
在下一章,我们将使用 FastAPI 封装 RAG 系统为 RESTful API,暴露查询和检索功能,为与 Spring Boot 项目的集成做好准备。